Cluster.java 8.2 KB

package com.mooctest.cluster;

import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import java.util.Random;

public class Cluster<K> implements Comparable<Cluster<K>>
{
    List<Document<K>> documents_;          ///< documents
    SparseVector composite_;               ///< a composite SparseVector
    SparseVector centroid_;                ///< a centroid SparseVector
    List<Cluster<K>> sectioned_clusters_;  ///< sectioned clusters
    double sectioned_gain_;                ///< a sectioned gain
    Random random;

    public Cluster()
    {
        this(new ArrayList<Document<K>>());
    }

    public Cluster(List<Document<K>> documents)
    {
        this.documents_ = documents;
        composite_ = new SparseVector();
        random = new Random();
    }

    /**
     * Add the vectors of all documents to a composite vector.
     */
    void set_composite_vector()
    {
        composite_.clear();
        for (Document<K> document : documents_)
        {
            composite_.add_vector(document.feature());
        }
    }

    /**
     * Clear status.
     */
    void clear()
    {
        documents_.clear();
        composite_.clear();
        if (centroid_ != null)
            centroid_.clear();
        if (sectioned_clusters_ != null)
            sectioned_clusters_.clear();
        sectioned_gain_ = 0.0;
    }

    /**
     * Get the size.
     *
     * @return the size of this cluster
     */
    int size()
    {
        return documents_.size();
    }
    /**
     * Get the centroid vector (the normalized composite vector).
     *
     * @return the centroid vector of this cluster
     */
    SparseVector centroid_vector()
    {
        if (documents_.size() > 0 && composite_.size() == 0)
            set_composite_vector();
        centroid_ = (SparseVector) composite_vector().clone();
        centroid_.normalize();
        return centroid_;
    }

    /**
     * Get the composite vector.
     *
     * @return the composite vector of this cluster
     */
    SparseVector composite_vector()
    {
        return composite_;
    }

    /**
     * Get the documents in this cluster.
     *
     * @return documents in this cluster
     */
    List<Document<K>> documents()
    {
        return documents_;
    }
    /**
     * Add a document and fold its feature vector into the composite vector.
     *
     * @param doc a document object
     */
    void add_document(Document<K> doc)
    {
        doc.feature().normalize();
        documents_.add(doc);
        composite_.add_vector(doc.feature());
    }
    /**
     * Remove a document from this cluster.
     * The slot is set to null and purged later by refresh().
     *
     * @param index the index of the document in the internal list
     */
    void remove_document(int index)
    {
        ListIterator<Document<K>> listIterator = documents_.listIterator(index);
        Document<K> document = listIterator.next();
        listIterator.set(null);
        composite_.sub_vector(document.feature());
    }

    /**
     * Remove a document from this cluster.
     * The slot is set to null and purged later by refresh().
     *
     * @param doc the document object to remove
     */
    void remove_document(Document<K> doc)
    {
        ListIterator<Document<K>> listIterator = documents_.listIterator();
        while (listIterator.hasNext())
        {
            Document<K> document = listIterator.next();
            if (document != null && document.equals(doc))
            {
                listIterator.set(null);
                composite_.sub_vector(document.feature());
                return;
            }
        }
    }
    /**
     * Delete removed documents from the internal container.
     */
    void refresh()
    {
        ListIterator<Document<K>> listIterator = documents_.listIterator();
        while (listIterator.hasNext())
        {
            if (listIterator.next() == null)
                listIterator.remove();
        }
    }

    /**
     * Get the gain obtained when this cluster is sectioned.
     *
     * @return the sectioned gain
     */
    double sectioned_gain()
    {
        return sectioned_gain_;
    }
    /**
     * Compute and store the gain obtained when this cluster is sectioned:
     * the sum of the norms of the sectioned clusters' composite vectors
     * minus the norm of this cluster's composite vector.
     */
    void set_sectioned_gain()
    {
        double gain = 0.0;
        if (sectioned_gain_ == 0 && sectioned_clusters_.size() > 1)
        {
            for (Cluster<K> cluster : sectioned_clusters_)
            {
                gain += cluster.composite_vector().norm();
            }
            gain -= composite_.norm();
        }
        sectioned_gain_ = gain;
    }
    /**
     * Get sectioned clusters.
     *
     * @return sectioned clusters
     */
    List<Cluster<K>> sectioned_clusters()
    {
        return sectioned_clusters_;
    }

    // /**
    //  * Choose documents randomly.
    //  */
    // void choose_randomly(int ndocs, List<Document> docs)
    // {
    //     HashMap<int, bool>.type choosed;
    //     int siz = size();
    //     init_hash_map(siz, choosed, ndocs);
    //     if (siz < ndocs)
    //         ndocs = siz;
    //     int count = 0;
    //     while (count < ndocs)
    //     {
    //         int index = myrand(seed_) % siz;
    //         if (choosed.find(index) == choosed.end())
    //         {
    //             choosed.insert(std.pair<int, bool>(index, true));
    //             docs.push_back(documents_[index]);
    //             ++count;
    //         }
    //     }
    // }
    /**
     * Choose initial centroids (k-means++-style seeding: documents far from the
     * centers chosen so far are more likely to be picked next).
     *
     * @param ndocs the number of centroids to choose
     * @param docs  the list that the chosen documents are appended to
     */
    void choose_smartly(int ndocs, List<Document<K>> docs)
    {
        int siz = size();
        double[] closest = new double[siz];
        if (siz < ndocs)
            ndocs = siz;
        int index, count = 0;

        // pick the first center uniformly at random
        index = random.nextInt(siz);
        docs.add(documents_.get(index));
        ++count;

        // distance of every document to its closest chosen center
        double potential = 0.0;
        for (int i = 0; i < documents_.size(); i++)
        {
            double dist = 1.0 - SparseVector.inner_product(documents_.get(i).feature(), documents_.get(index).feature());
            potential += dist;
            closest[i] = dist;
        }

        // choose each remaining center with probability proportional to its distance
        while (count < ndocs)
        {
            double randval = random.nextDouble() * potential;
            for (index = 0; index < documents_.size(); index++)
            {
                double dist = closest[index];
                if (randval <= dist)
                    break;
                randval -= dist;
            }
            if (index == documents_.size())
                index--;
            docs.add(documents_.get(index));
            ++count;

            // update the closest-center distances and the total potential
            double new_potential = 0.0;
            for (int i = 0; i < documents_.size(); i++)
            {
                double dist = 1.0 - SparseVector.inner_product(documents_.get(i).feature(), documents_.get(index).feature());
                double min = closest[i];
                if (dist < min)
                {
                    closest[i] = dist;
                    min = dist;
                }
                new_potential += min;
            }
            potential = new_potential;
        }
    }
    /**
     * Section (split) this cluster into nclusters sub-clusters.
     *
     * @param nclusters the number of sub-clusters
     */
    void section(int nclusters)
    {
        if (size() < nclusters)
            return;
        sectioned_clusters_ = new ArrayList<Cluster<K>>(nclusters);
        List<Document<K>> centroids = new ArrayList<Document<K>>(nclusters);
        // choose_randomly(nclusters, centroids);
        choose_smartly(nclusters, centroids);
        for (int i = 0; i < centroids.size(); i++)
        {
            Cluster<K> cluster = new Cluster<K>();
            sectioned_clusters_.add(cluster);
        }
        // assign each document to the most similar centroid
        for (Document<K> d : documents_)
        {
            double max_similarity = -1.0;
            int max_index = 0;
            for (int j = 0; j < centroids.size(); j++)
            {
                double similarity = SparseVector.inner_product(d.feature(), centroids.get(j).feature());
                if (max_similarity < similarity)
                {
                    max_similarity = similarity;
                    max_index = j;
                }
            }
            sectioned_clusters_.get(max_index).add_document(d);
        }
    }
    /**
     * Order clusters by sectioned gain, largest gain first.
     */
    @Override
    public int compareTo(Cluster<K> o)
    {
        return Double.compare(o.sectioned_gain(), sectioned_gain());
    }
}
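
The descending order in compareTo suggests this class is meant to be driven by a repeated-bisection loop that always splits the cluster promising the largest gain next. Below is a minimal sketch of such a driver, assuming the Document and SparseVector APIs referenced by Cluster.java; the class ClusteringSketch, the method names bisect and sectionAndScore, and the parameter limit are illustrative and not part of the original code.

package com.mooctest.cluster;

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

public class ClusteringSketch
{
    /**
     * Repeatedly bisect the documents until roughly `limit` clusters exist.
     * Clusters are prioritised by sectioned gain, which is why Cluster.compareTo
     * sorts in descending gain order.
     */
    public static <K> List<Cluster<K>> bisect(List<Document<K>> docs, int limit)
    {
        Cluster<K> root = new Cluster<K>();
        for (Document<K> d : docs)
            root.add_document(d);

        // poll() returns the cluster with the largest sectioned gain first
        PriorityQueue<Cluster<K>> queue = new PriorityQueue<Cluster<K>>();
        sectionAndScore(root);
        queue.add(root);

        List<Cluster<K>> done = new ArrayList<Cluster<K>>(); // clusters that cannot be split further
        while (!queue.isEmpty() && done.size() + queue.size() < limit)
        {
            Cluster<K> cluster = queue.poll();
            if (cluster.sectioned_clusters() == null || cluster.sectioned_clusters().size() < 2)
            {
                done.add(cluster);
                continue;
            }
            // replace the cluster by its two halves
            for (Cluster<K> child : cluster.sectioned_clusters())
            {
                sectionAndScore(child);
                queue.add(child);
            }
        }
        done.addAll(queue);
        return done;
    }

    /** Split a cluster in two and record the gain, guarding against clusters too small to split. */
    private static <K> void sectionAndScore(Cluster<K> cluster)
    {
        cluster.section(2);
        if (cluster.sectioned_clusters() != null && cluster.sectioned_clusters().size() > 1)
            cluster.set_sectioned_gain();
    }
}

The sketch keeps unsplittable clusters in a separate list so the loop terminates even when some clusters are too small to section again.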