123456789101112131415161718192021222324252627282930313233343536373839 |
def train(self, save=False, save_dir=None):
    """Fit a MiniBatchKMeans texton dictionary on every file in ``self.path_train``.

    Each file matched by ``self.path_train + "/*"`` is read with ``io.imread``,
    converted with ``color.rgb2lab``, and passed to
    ``self.extract_texton_feature(img, self.fb, self.nb_features)``. The
    features from all images are concatenated and clustered into
    ``self.nb_clusters`` clusters. The fitted model is stored on
    ``self.cluster``.

    Args:
        save: If truthy, pickle the fitted model to ``save_dir``.
        save_dir: Destination *file* path for the pickle. It is opened
            directly with ``open(save_dir, "wb")``, not treated as a
            directory. Required when ``save`` is truthy.

    Raises:
        ValueError: If ``save`` is requested without ``save_dir``, if no
            files match in ``self.path_train``, or if no features were
            extracted from them.
    """
    # Validate up front so a missing destination fails before the
    # (potentially long) feature extraction and clustering, not after it.
    if save and save_dir is None:
        raise ValueError("save requires a save_dir file path")

    # glob returns files in arbitrary, filesystem-dependent order; sort so the
    # feature matrix is built in a reproducible order.
    train_img_list = sorted(glob.glob(self.path_train + "/*"))
    print(train_img_list)
    if not train_img_list:
        raise ValueError("no training images found in %r" % (self.path_train,))

    train_features = []
    for img_file in train_img_list:
        img = color.rgb2lab(io.imread(img_file))
        # NOTE(review): assumes extract_texton_feature returns an iterable of
        # equal-length feature vectors (one row per sample) -- confirm.
        train_features.extend(
            self.extract_texton_feature(img, self.fb, self.nb_features)
        )
    train_features = np.array(train_features)
    print(train_features.shape)
    if train_features.size == 0:
        raise ValueError("no texton features were extracted from the training images")

    kmeans_cluster = MiniBatchKMeans(n_clusters=self.nb_clusters, verbose=1, max_iter=300)
    kmeans_cluster.fit(train_features)
    print(kmeans_cluster.cluster_centers_)
    print(kmeans_cluster.cluster_centers_.shape)
    self.cluster = kmeans_cluster

    # Persist the fitted k-means model (truthy check, not `is True`).
    if save:
        with open(save_dir, "wb") as f:
            pickle.dump(self.cluster, f)
def save(self, event):
    """Write the buffer to its current file, or fall back to ``save_as``.

    When ``self.filename`` is set, the buffer is written with
    ``self.writefile``. On success, the buffer is marked saved and the editor
    window is asked to persist its breakpoints. Without a filename, the call is
    delegated to ``self.save_as(event)``. In every case focus returns to
    ``self.text``.

    Returns:
        The string ``"break"``. Presumably this method is bound as a Tk event
        handler, where ``"break"`` stops further handlers from running.
    """
    if self.filename:
        written = self.writefile(self.filename)
        if written:
            self.set_saved(True)
            # Best effort: a shell window (PyShell) may lack
            # store_file_breaks, so a missing attribute is ignored.
            try:
                self.editwin.store_file_breaks()
            except AttributeError:
                pass
    else:
        self.save_as(event)
    self.text.focus_set()
    return "break"
|