acc_figures.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import matplotlib
  4. import argparse
  5. import os
  6. import sys
  7. sys.path.append('..')
  8. from util.mlflow_util import load_uri
  9. import mlflow
  10. from cycler import cycler
  11. print(plt.style.available)
  12. # MATLPLOTLIB settings
  13. plt.style.use('default')
  14. plt.style.use('seaborn-deep')
  15. matplotlib.rcParams['text.usetex'] = True
  16. matplotlib.rcParams['text.latex.unicode'] = True
  17. matplotlib.rcParams['axes.prop_cycle'] = cycler(color=['r', 'g', 'b', 'y', 'cyan', 'brown', 'k', 'gray', 'orange','purple', 'pink'])
  18. SMALL_SIZE = 16
  19. MEDIUM_SIZE = 22
  20. BIGGER_SIZE = 24
  21. plt.rc('font', size=SMALL_SIZE) # controls default text sizes
  22. plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title
  23. plt.rc('axes', labelsize=26) # fontsize of the x and y labels
  24. plt.rc('xtick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
  25. plt.rc('ytick', labelsize=MEDIUM_SIZE) # fontsize of the tick labels
  26. plt.rc('legend', fontsize=18) # legend fontsize
  27. plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
  28. acq_model2label_marker_color = {
  29. 'vopt-mgr' : ('VOPT-MGR', '+', 'k'),
  30. 'sopt-mgr' : ('SOPT-MGR', 'x', 'brown'),
  31. 'sopt-hf' : ('SOPT-HF', 'v', 'pink'),
  32. 'vopt-hf' : ('VOPT-HF', '^', 'm'),
  33. 'vopt-gr' : ('VOPT-GR', '+', 'k'),
  34. 'sopt-gr' : ('SOPT-GR', 'x', 'brown'),
  35. 'mc-mgr' : ('MC-MGR', '*', 'g'),
  36. 'mcgreedy-ce' : ('MCG-CE', 'v', 'cyan'),
  37. 'mc-ce' : ('MC-CE', 'v', 'b'),
  38. 'mc-probitnorm' : ('MC-P', 's', 'y'),
  39. 'mc-gr' : ('MC-GR', '*', 'g'),
  40. 'mc-log' : ('MC-LOG', 'o', 'b'),
  41. 'rand-mgr' : ('RAND-MGR', '<', 'purple'),
  42. 'rand-log' : ('RAND-LOG', '<', 'purple'),
  43. 'rand-ce' : ('RAND-CE', '<', 'purple'),
  44. 'rand-gr' : ('RAND-GR', '<', 'purple'),
  45. 'db-rkhs' : ('DB-RKHS', '>', 'r'),
  46. 'uncertainty-mgr' : ('UNC-MGR', 'P', 'orange'),
  47. 'uncertainty-gr' : ('UNC-GR', 'P', 'orange'),
  48. 'uncertainty-ce' : ('UNC-CE', 'X', 'gray'),
  49. 'uncertainty-log' : ('UNC-LOG', 'X', 'gray'),
  50. 'uncertainty-probitnorm' : ('UNC-P', 'd', 'gray'),
  51. }
  52. mlflow.set_tracking_uri('../mlruns')
  53. client = mlflow.tracking.MlflowClient()
  54. save_root = './for-paper/'
  55. # ## Binary Clusters
  56. # +
  57. exp_name = 'sequential-binary-clusters'
  58. exp_save_root = os.path.join(save_root, exp_name)
  59. if not os.path.exists(exp_save_root):
  60. os.makedirs(exp_save_root)
  61. experiment = client.get_experiment_by_name(exp_name)
  62. query = 'attributes.status = "FINISHED"'
  63. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  64. print(len(all_runs))
  65. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  66. print(setup_names)
  67. # +
  68. # Plot figure
  69. skip = 2
  70. first = True
  71. not_plot = ['rand-gr', 'rand-probitnorm', 'uncertainty-gr', 'uncertainty-probitnorm', 'vopt-gr', 'sopt-gr']
  72. plt.figure(figsize=(8,6))
  73. for setup_name in setup_names:
  74. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  75. continue
  76. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  77. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  78. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  79. al_iters = int(runs[0].data.params['al_iters'])
  80. ACC = np.zeros(al_iters + 1)
  81. print(len(runs), setup_name)
  82. for r in runs:
  83. acc = np.array([r.data.metrics['init_acc']])
  84. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  85. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  86. ACC /= float(len(runs))
  87. if first:
  88. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  89. B = int(runs[0].data.params['B'])
  90. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  91. first = False
  92. plt.scatter(dom[:50:skip], ACC[:50:skip], marker=mrkr, label=lbl, s=50, c=clr)
  93. plt.plot(dom[:50:skip], ACC[:50:skip], linewidth=0.9, c=clr)
  94. plt.legend()
  95. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  96. plt.ylabel("Accuracy")
  97. plt.tight_layout()
  98. # plt.savefig('{}/acc.pdf'.format(exp_save_root))
  99. plt.show()
  100. # -
  101. checkdata = np.load('../data/binary_clusters2/X_labels.npz', allow_pickle=True)
  102. X = checkdata['X']
  103. labels = checkdata['labels']
  104. clrs = np.array(X.shape[0]*['r'])
  105. clrs[labels == 0] = 'b'
  106. for r in all_runs:
  107. print(r.data.tags['mlflow.runName'])
  108. if r.data.tags['mlflow.runName'][-1] != '0':
  109. continue
  110. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  111. choices = iter_stats['al_choices'].flatten()
  112. init_labeled = load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy'))
  113. choices = np.concatenate((init_labeled, choices))[:350]
  114. if np.max(choices) > 1999:
  115. print('found checker run with more than 2000 nodes')
  116. continue
  117. plt.figure(figsize=(5,5))
  118. plt.scatter(X[:,0], X[:,1], c=clrs)
  119. plt.scatter(X[choices[:50],0], X[choices[:50],1], marker='*', s=90, c='gold', linewidths=0.6, edgecolors='k')
  120. #plt.title(r.data.tags['mlflow.runName'])
  121. plt.xticks([], [])
  122. plt.yticks([], [])
  123. plt.savefig('{}/{}.pdf'.format(exp_save_root, r.data.tags['mlflow.runName']))
  124. plt.show()
  125. # +
  126. # Redo of experiment -- changed tau = 0.001 and gamma = 0.5
  127. exp_name = 'sequential-binary-clusters3'
  128. exp_save_root = os.path.join(save_root, exp_name)
  129. if not os.path.exists(exp_save_root):
  130. os.makedirs(exp_save_root)
  131. experiment = client.get_experiment_by_name(exp_name)
  132. query = 'attributes.status = "FINISHED"'
  133. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  134. print(len(all_runs))
  135. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  136. print(setup_names)
  137. # +
  138. # Plot figure
  139. skip = 2
  140. first = True
  141. not_plot = ['rand-gr', 'rand-probitnorm', 'uncertainty-gr', 'uncertainty-probitnorm', 'vopt-gr', 'sopt-gr']
  142. plt.figure(figsize=(8,6))
  143. for setup_name in setup_names:
  144. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  145. continue
  146. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  147. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  148. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  149. al_iters = int(runs[0].data.params['al_iters'])
  150. ACC = np.zeros(al_iters + 1)
  151. print(len(runs), setup_name)
  152. for r in runs:
  153. acc = np.array([r.data.metrics['init_acc']])
  154. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  155. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  156. ACC /= float(len(runs))
  157. if first:
  158. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  159. B = int(runs[0].data.params['B'])
  160. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  161. first = False
  162. plt.scatter(dom[:50:skip], ACC[:50:skip], marker=mrkr, label=lbl, s=50, c=clr)
  163. plt.plot(dom[:50:skip], ACC[:50:skip], linewidth=0.9, c=clr)
  164. plt.legend()
  165. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  166. plt.ylabel("Accuracy")
  167. plt.tight_layout()
  168. plt.savefig('{}/acc.pdf'.format(exp_save_root))
  169. plt.show()
  170. # -
  171. checkdata = np.load('../data/binary_clusters2/X_labels.npz', allow_pickle=True)
  172. X = checkdata['X']
  173. labels = checkdata['labels']
  174. clrs = np.array(X.shape[0]*['r'])
  175. clrs[labels == 0] = 'b'
  176. for r in all_runs:
  177. print(r.data.tags['mlflow.runName'])
  178. if r.data.tags['mlflow.runName'][-1] != '0':
  179. continue
  180. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  181. choices = iter_stats['al_choices'].flatten()
  182. init_labeled = load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy'))
  183. choices = np.concatenate((init_labeled, choices))[:350]
  184. if np.max(choices) > 1999:
  185. print('found checker run with more than 2000 nodes')
  186. continue
  187. plt.figure(figsize=(5,5))
  188. plt.scatter(X[:,0], X[:,1], c=clrs)
  189. plt.scatter(X[choices[:50],0], X[choices[:50],1], marker='*', s=90, c='gold', linewidths=0.6, edgecolors='k')
  190. #plt.title(r.data.tags['mlflow.runName'])
  191. plt.xticks([], [])
  192. plt.yticks([], [])
  193. plt.savefig('{}/{}.pdf'.format(exp_save_root, r.data.tags['mlflow.runName']))
  194. plt.show()
  195. # # Binary Clusters 3 - Sequential
  196. # +
  197. # Redo of experiment -- changed tau = 0.001 and gamma = 0.5
  198. exp_name = 'binary-clusters3-sequential'
  199. exp_save_root = os.path.join(save_root, exp_name)
  200. if not os.path.exists(exp_save_root):
  201. os.makedirs(exp_save_root)
  202. experiment = client.get_experiment_by_name(exp_name)
  203. query = 'attributes.status = "FINISHED"'
  204. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  205. print(len(all_runs))
  206. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  207. print(setup_names)
  208. # +
  209. # Plot figure
  210. skip = 2
  211. first = True
  212. not_plot = ['rand-gr', 'rand-probitnorm', 'uncertainty-gr', 'uncertainty-probitnorm', 'vopt-gr', 'sopt-gr']
  213. plt.figure(figsize=(8,6))
  214. for setup_name in setup_names:
  215. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  216. continue
  217. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  218. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  219. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  220. al_iters = int(runs[0].data.params['al_iters'])
  221. ACC = np.zeros(al_iters + 1)
  222. print(len(runs), setup_name)
  223. for r in runs:
  224. acc = np.array([r.data.metrics['init_acc']])
  225. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  226. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  227. ACC /= float(len(runs))
  228. if first:
  229. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  230. B = int(runs[0].data.params['B'])
  231. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  232. first = False
  233. plt.scatter(dom[:100:skip], ACC[:100:skip], marker=mrkr, label=lbl, s=50, c=clr)
  234. plt.plot(dom[:100:skip], ACC[:100:skip], linewidth=0.9, c=clr)
  235. plt.legend()
  236. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  237. plt.ylabel("Accuracy")
  238. plt.tight_layout()
  239. plt.savefig('{}/acc.pdf'.format(exp_save_root))
  240. plt.show()
  241. # -
  242. checkdata = np.load('../data/binary_clusters3/X_labels.npz', allow_pickle=True)
  243. X = checkdata['X']
  244. labels = checkdata['labels']
  245. clrs = np.array(X.shape[0]*['r'])
  246. clrs[labels == 0] = 'b'
  247. for r in all_runs:
  248. print(r.data.tags['mlflow.runName'])
  249. if r.data.tags['mlflow.runName'][-1] != '0':
  250. continue
  251. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  252. choices = iter_stats['al_choices'].flatten()
  253. init_labeled = load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy'))
  254. choices = np.concatenate((init_labeled, choices))[:350]
  255. if np.max(choices) > 1999:
  256. print('found checker run with more than 2000 nodes')
  257. continue
  258. plt.figure(figsize=(5,5))
  259. plt.scatter(X[:,0], X[:,1], c=clrs)
  260. plt.scatter(X[choices[:50],0], X[choices[:50],1], marker='*', s=90, c='gold', linewidths=0.6, edgecolors='k')
  261. #plt.title(r.data.tags['mlflow.runName'])
  262. plt.xticks([], [])
  263. plt.yticks([], [])
  264. plt.savefig('{}/{}.pdf'.format(exp_save_root, r.data.tags['mlflow.runName']))
  265. plt.show()
  266. plt.figure(figsize=(5,5))
  267. plt.scatter(X[:,0], X[:,1], c=clrs)
  268. plt.xticks([], [])
  269. plt.yticks([], [])
  270. plt.savefig('{}/{}.pdf'.format(exp_save_root, 'bc-gt'))
  271. plt.show()
  272. # # Binary Clusters3 - Batch
  273. # +
  274. # Redo of experiment -- changed tau = 0.001 and gamma = 0.5
  275. exp_name = 'binary-clusters3-batch'
  276. exp_save_root = os.path.join(save_root, exp_name)
  277. if not os.path.exists(exp_save_root):
  278. os.makedirs(exp_save_root)
  279. experiment = client.get_experiment_by_name(exp_name)
  280. query = 'attributes.status = "FINISHED"'
  281. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  282. print(len(all_runs))
  283. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  284. print(setup_names)
  285. # +
  286. # Plot figure
  287. skip = 1
  288. first = True
  289. not_plot = ['rand-gr', 'rand-probitnorm', 'uncertainty-gr', 'uncertainty-probitnorm', 'vopt-gr', 'sopt-gr']
  290. plt.figure(figsize=(8,6))
  291. for setup_name in setup_names:
  292. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  293. continue
  294. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  295. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  296. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  297. al_iters = int(runs[0].data.params['al_iters'])
  298. ACC = np.zeros(al_iters + 1)
  299. print(len(runs), setup_name)
  300. for r in runs:
  301. acc = np.array([r.data.metrics['init_acc']])
  302. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  303. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  304. ACC /= float(len(runs))
  305. if first:
  306. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  307. B = int(runs[0].data.params['B'])
  308. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  309. first = False
  310. plt.scatter(dom[:20:skip], ACC[:20:skip], marker=mrkr, label=lbl, s=50, c=clr)
  311. plt.plot(dom[:20:skip], ACC[:20:skip], linewidth=0.9, c=clr)
  312. plt.legend()
  313. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  314. plt.ylabel("Accuracy")
  315. plt.tight_layout()
  316. plt.savefig('{}/acc.pdf'.format(exp_save_root))
  317. plt.show()
  318. # -
  319. plt.scatter(X[:,0], X[:,1], c=labels)
  320. plt.show()
  321. checkdata2 = np.load('../data/binary_clusters_check/X_labels.npz', allow_pickle=True)
  322. X2 = checkdata2['X']
  323. labels2 = checkdata2['labels']
  324. print(np.allclose(X2, X))
  325. # ## Checker 2
  326. # +
  327. exp_name = 'checker2'
  328. exp_save_root = os.path.join(save_root, exp_name)
  329. if not os.path.exists(exp_save_root):
  330. os.makedirs(exp_save_root)
  331. experiment = client.get_experiment_by_name(exp_name)
  332. query = 'attributes.status = "FINISHED"'
  333. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  334. print(len(all_runs))
  335. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  336. print(setup_names)
  337. # +
  338. # Plot figure
  339. skip = 2
  340. first = True
  341. not_plot = ['rand-gr', 'rand-probitnorm', 'uncertainty-gr', 'uncertainty-probitnorm', 'vopt-gr', 'sopt-gr']
  342. plt.figure(figsize=(8,6))
  343. for setup_name in setup_names:
  344. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  345. continue
  346. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  347. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  348. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  349. al_iters = int(runs[0].data.params['al_iters'])
  350. ACC = np.zeros(al_iters + 1)
  351. print(len(runs), setup_name)
  352. for r in runs:
  353. acc = np.array([r.data.metrics['init_acc']])
  354. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  355. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  356. ACC /= float(len(runs))
  357. if first:
  358. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  359. B = int(runs[0].data.params['B'])
  360. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  361. first = False
  362. plt.scatter(dom[::skip], ACC[::skip], marker=mrkr, label=lbl, s=50, c=clr)
  363. plt.plot(dom[::skip], ACC[::skip], linewidth=0.9, c=clr)
  364. plt.legend()
  365. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  366. plt.ylabel("Accuracy")
  367. plt.tight_layout()
  368. plt.savefig('{}/acc.pdf'.format(exp_save_root))
  369. plt.show()
  370. # -
  371. checkdata = np.load('../data/checker2/X_labels.npz', allow_pickle=True)
  372. X = checkdata['X']
  373. labels = checkdata['labels']
  374. clrs = np.array(X.shape[0]*['r'])
  375. clrs[labels == 0] = 'b'
  376. for r in all_runs:
  377. if r.data.tags['mlflow.runName'][-1] != '0':
  378. continue
  379. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  380. choices = iter_stats['al_choices'].flatten()
  381. init_labeled = load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy'))
  382. choices = np.concatenate((init_labeled, choices))[:350]
  383. if np.max(choices) > 1999:
  384. print('found checker run with more than 2000 nodes')
  385. continue
  386. plt.figure(figsize=(5,5))
  387. plt.scatter(X[:,0], X[:,1], c=clrs)
  388. plt.scatter(X[choices,0], X[choices,1], marker='*', s=90, c='gold', linewidths=0.6, edgecolors='k')
  389. #plt.title(r.data.tags['mlflow.runName'])
  390. plt.xticks([], [])
  391. plt.yticks([], [])
  392. plt.savefig('{}/{}.pdf'.format(exp_save_root, r.data.tags['mlflow.runName']))
  393. plt.show()
  394. # # Sequential Checker2
  395. # +
  396. exp_name = 'sequential-checker2'
  397. exp_save_root = os.path.join(save_root, exp_name)
  398. if not os.path.exists(exp_save_root):
  399. os.makedirs(exp_save_root)
  400. experiment = client.get_experiment_by_name(exp_name)
  401. query = 'attributes.status = "FINISHED"'
  402. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  403. print(len(all_runs))
  404. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  405. print(setup_names)
  406. # +
  407. # Plot figure
  408. skip = 5
  409. tot = 55
  410. first = True
  411. not_plot = ['rand-gr', 'rand-probitnorm', 'uncertainty-gr', 'uncertainty-probitnorm', 'vopt-gr', 'sopt-gr']
  412. plt.figure(figsize=(7,6))
  413. for setup_name in setup_names:
  414. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  415. continue
  416. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  417. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  418. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  419. al_iters = int(runs[0].data.params['al_iters'])
  420. ACC = np.zeros(al_iters + 1)
  421. print(len(runs), setup_name)
  422. for r in runs:
  423. acc = np.array([r.data.metrics['init_acc']])
  424. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  425. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  426. ACC /= float(len(runs))
  427. if first:
  428. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  429. B = int(runs[0].data.params['B'])
  430. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  431. first = False
  432. plt.scatter(dom[::skip][:tot], ACC[::skip][:tot], marker=mrkr, label=lbl, s=40, c=clr)
  433. plt.plot(dom[::skip][:tot], ACC[::skip][:tot], linewidth=0.9, c=clr)
  434. plt.legend()
  435. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  436. plt.ylabel("Accuracy")
  437. plt.tight_layout()
  438. plt.savefig('{}/acc.pdf'.format(exp_save_root))
  439. plt.show()
  440. # -
  441. checkdata = np.load('../data/checker2/X_labels.npz', allow_pickle=True)
  442. X = checkdata['X']
  443. labels = checkdata['labels']
  444. clrs = np.array(X.shape[0]*['r'])
  445. clrs[labels == 0] = 'b'
  446. for r in all_runs:
  447. if r.data.tags['mlflow.runName'][-1] != '0':
  448. continue
  449. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  450. choices = iter_stats['al_choices'].flatten()
  451. init_labeled = load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy'))
  452. choices = np.concatenate((init_labeled, choices))[:350]
  453. if np.max(choices) > 1999:
  454. print('found checker run with more than 2000 nodes')
  455. continue
  456. print(r.data.tags['mlflow.runName'])
  457. plt.figure(figsize=(5,5))
  458. plt.scatter(X[:,0], X[:,1], c=clrs)
  459. plt.scatter(X[choices,0], X[choices,1], marker='*', s=110, c='yellow', linewidths=0.2, edgecolors='k')
  460. #plt.title(r.data.tags['mlflow.runName'])
  461. plt.xticks([], [])
  462. plt.yticks([], [])
  463. plt.savefig('{}/{}.pdf'.format(exp_save_root, r.data.tags['mlflow.runName']),bbox_inches = 'tight',
  464. pad_inches = 0)
  465. plt.show()
  466. # +
  467. checkdata = np.load('../data/checker2/X_labels.npz', allow_pickle=True)
  468. X = checkdata['X']
  469. labels = checkdata['labels']
  470. clrs = np.array(X.shape[0]*['r'])
  471. clrs[labels == 0] = 'b'
  472. i1, i2 = None, None
  473. i = 0
  474. while (i1 is None or i2 is None) and i < 2000:
  475. x, y = X[i,0], X[i,1]
  476. if i1 is None:
  477. if 0.55 <= x <= 0.7 and 0.3 <= y <= 0.45:
  478. i1 = i
  479. if i2 is None:
  480. if 0.3 <= x <= 0.45 and 0.3 <= y <= 0.45:
  481. i2 = i
  482. i += 1
  483. plt.figure(figsize=(5,5))
  484. plt.scatter(X[:,0], X[:,1], c=clrs)
  485. # plt.scatter(X[i1,0], X[i1,1], marker='*', s=90, c='gold', linewidths=0.4, edgecolors='k')
  486. # plt.scatter(X[i2,0], X[i2,1], marker='*', s=90, c='gold', linewidths=0.4, edgecolors='k')
  487. plt.xticks([], [])
  488. plt.yticks([], [])
  489. plt.savefig('{}/gt.pdf'.format(exp_save_root), bbox_inches = 'tight',
  490. pad_inches = 0)
  491. plt.show()
  492. # -
  493. labeled_handchosen = [i1, i2]
  494. print(labeled_handchosen)
  495. # +
  496. exp_name = 'handchosen-checker2'
  497. exp_save_root = os.path.join(save_root, exp_name)
  498. if not os.path.exists(exp_save_root):
  499. os.makedirs(exp_save_root)
  500. experiment = client.get_experiment_by_name(exp_name)
  501. query = 'attributes.status = "FINISHED"'
  502. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  503. print(len(all_runs))
  504. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  505. print(setup_names)
  506. # +
  507. # Plot figure
  508. skip = 5
  509. tot = -1
  510. first = True
  511. not_plot = ['rand-gr', 'rand-probitnorm', 'uncertainty-gr', 'uncertainty-probitnorm', 'vopt-gr', 'sopt-gr']
  512. plt.figure(figsize=(7,6))
  513. for setup_name in setup_names:
  514. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  515. continue
  516. # runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  517. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName'] and r.data.params['cand'] == 'full']
  518. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  519. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  520. al_iters = int(runs[0].data.params['al_iters'])
  521. ACC = np.zeros(al_iters + 1)
  522. print(len(runs), setup_name)
  523. for r in runs:
  524. acc = np.array([r.data.metrics['init_acc']])
  525. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  526. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  527. ACC /= float(len(runs))
  528. if first:
  529. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  530. B = int(runs[0].data.params['B'])
  531. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  532. first = False
  533. plt.scatter(dom[::skip][:tot], ACC[::skip][:tot], marker=mrkr, label=lbl, s=40, c=clr)
  534. plt.plot(dom[::skip][:tot], ACC[::skip][:tot], linewidth=0.9, c=clr)
  535. plt.legend()
  536. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  537. plt.ylabel("Accuracy")
  538. plt.tight_layout()
  539. #plt.savefig('{}/acc.pdf'.format(exp_save_root))
  540. plt.show()
  541. # -
  542. checkdata = np.load('../data/checker2/X_labels.npz', allow_pickle=True)
  543. X = checkdata['X']
  544. labels = checkdata['labels']
  545. clrs = np.array(X.shape[0]*['r'])
  546. clrs[labels == 0] = 'b'
  547. for r in all_runs:
  548. if r.data.tags['mlflow.runName'][-1] != '1' or r.data.params['cand'] != 'full':
  549. continue
  550. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  551. choices = iter_stats['al_choices'].flatten()
  552. init_labeled = load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy'))
  553. choices = np.concatenate((init_labeled, choices))
  554. if np.max(choices) > 1999:
  555. print('found checker run with more than 2000 nodes')
  556. continue
  557. print(r.data.tags['mlflow.runName'])
  558. plt.figure(figsize=(5,5))
  559. plt.scatter(X[:,0], X[:,1], c=clrs)
  560. plt.scatter(X[choices,0], X[choices,1], marker='*', s=100, c='gold', linewidths=0.4, edgecolors='k')
  561. #plt.title(r.data.tags['mlflow.runName'])
  562. plt.xticks([], [])
  563. plt.yticks([], [])
  564. #plt.savefig('{}/{}.pdf'.format(exp_save_root, r.data.tags['mlflow.runName']))
  565. plt.show()
  566. # # Binary MNIST
  567. # +
  568. exp_name = 'sequential-binary-mnist'
  569. exp_save_root = os.path.join(save_root, exp_name)
  570. if not os.path.exists(exp_save_root):
  571. os.makedirs(exp_save_root)
  572. experiment = client.get_experiment_by_name(exp_name)
  573. query = 'attributes.status = "FINISHED"'
  574. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  575. print(len(all_runs))
  576. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  577. print(setup_names)
  578. # +
  579. # Plot figure
  580. skip = 1
  581. tot = 100
  582. first = True
  583. not_plot = ['rand-log', 'rand-probitnorm', 'uncertainty-log', 'uncertainty-probitnorm']
  584. plt.figure(figsize=(10,6))
  585. for setup_name in setup_names:
  586. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  587. continue
  588. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  589. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  590. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  591. al_iters = int(runs[0].data.params['al_iters'])
  592. ACC = np.zeros(al_iters + 1)
  593. print(len(runs), setup_name)
  594. for r in runs:
  595. acc = np.array([r.data.metrics['init_acc']])
  596. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  597. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  598. ACC /= float(len(runs))
  599. if first:
  600. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  601. B = int(runs[0].data.params['B'])
  602. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  603. first = False
  604. plt.scatter(dom[::skip][:tot], ACC[::skip][:tot], marker=mrkr, label=lbl, s=50, c=clr)
  605. plt.plot(dom[::skip][:tot], ACC[::skip][:tot], linewidth=0.9, c=clr)
  606. plt.legend()
  607. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  608. plt.ylabel("Accuracy")
  609. plt.tight_layout()
  610. plt.savefig('{}/acc.pdf'.format(exp_save_root))
  611. plt.show()
  612. # -
  613. # # Checker 3
  614. # +
  615. exp_name = 'checker3'
  616. exp_save_root = os.path.join(save_root, exp_name)
  617. if not os.path.exists(exp_save_root):
  618. os.makedirs(exp_save_root)
  619. experiment = client.get_experiment_by_name(exp_name)
  620. query = 'attributes.status = "FINISHED"'
  621. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  622. print(len(all_runs))
  623. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  624. print(setup_names)
  625. # +
  626. # Plot figure
  627. skip = 3
  628. first = True
  629. not_plot = ['rand-mgr', 'uncertainty-mgr']
  630. plt.figure(figsize=(10,6))
  631. for setup_name in setup_names:
  632. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  633. continue
  634. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  635. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  636. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  637. al_iters = int(runs[0].data.params['al_iters'])
  638. ACC = np.zeros(al_iters + 1)
  639. print(len(runs), setup_name)
  640. for r in runs:
  641. acc = np.array([r.data.metrics['init_acc']])
  642. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  643. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  644. ACC /= float(len(runs))
  645. if first:
  646. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  647. B = int(runs[0].data.params['B'])
  648. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  649. first = False
  650. plt.scatter(dom[::skip], ACC[::skip], marker=mrkr, label=lbl, s=50, c=clr)
  651. plt.plot(dom[::skip], ACC[::skip], linewidth=0.5, c=clr)
  652. plt.legend()
  653. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  654. plt.ylabel("Accuracy")
  655. plt.tight_layout()
  656. #plt.savefig('{}/acc.pdf'.format(exp_save_root))
  657. plt.show()
  658. # +
  659. # Plot figure
  660. skip = 3
  661. first = True
  662. for modelname in ['mgr', 'ce']:
  663. plt.figure(figsize=(7,5))
  664. for setup_name in setup_names:
  665. if modelname not in setup_name.split('-')[1]:
  666. continue
  667. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  668. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  669. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  670. al_iters = int(runs[0].data.params['al_iters'])
  671. ACC = np.zeros(al_iters + 1)
  672. print(len(runs), setup_name)
  673. for r in runs:
  674. acc = np.array([r.data.metrics['init_acc']])
  675. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  676. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  677. ACC /= float(len(runs))
  678. if first:
  679. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  680. B = int(runs[0].data.params['B'])
  681. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  682. first = False
  683. plt.scatter(dom[::skip], ACC[::skip], marker=mrkr, label=lbl, s=50, c=clr)
  684. plt.plot(dom[::skip], ACC[::skip], linewidth=1.5, c=clr)
  685. plt.legend()
  686. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  687. plt.ylabel("Accuracy")
  688. plt.tight_layout()
  689. plt.savefig('{}/acc-{}.pdf'.format(exp_save_root, modelname))
  690. plt.show()
  691. # -
  692. checkdata = np.load('../data/checker3/X_labels.npz', allow_pickle=True)
  693. X = checkdata['X']
  694. labels = checkdata['labels']
  695. clrs = np.array(X.shape[0]*['r'])
  696. clrs[labels == 0] = 'b'
  697. clrs[labels== 1] = 'g'
  698. for r in all_runs:
  699. if r.data.tags['mlflow.runName'][-1] != '1':
  700. continue
  701. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  702. #print(list(iter_stats.keys()))
  703. choices = iter_stats['al_choices'].flatten()
  704. init_labeled = load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy'))
  705. choices = np.concatenate((init_labeled, choices))[:]
  706. if np.max(choices) > 2999:
  707. continue
  708. plt.figure(figsize=(5,5))
  709. plt.scatter(X[:,0], X[:,1], c=clrs)
  710. plt.scatter(X[choices,0], X[choices,1], marker='*', s=90, c='gold', linewidths=0.6, edgecolors='k')
  711. #plt.title(r.data.tags['mlflow.runName'])
  712. plt.xticks([], [])
  713. plt.yticks([], [])
  714. plt.savefig('{}/{}.pdf'.format(exp_save_root, r.data.tags['mlflow.runName']))
  715. plt.show()
  716. # # MNIST
  717. # +
  718. exp_name = 'mnist'
  719. exp_save_root = os.path.join(save_root, exp_name)
  720. if not os.path.exists(exp_save_root):
  721. os.makedirs(exp_save_root)
  722. experiment = client.get_experiment_by_name(exp_name)
  723. query = 'attributes.status = "FINISHED"'
  724. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  725. print(len(all_runs))
  726. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  727. print(setup_names)
  728. # +
  729. # Plot figure
  730. skip = 3
  731. first = True
  732. not_plot = ['rand-mgr', 'uncertainty-mgr']
  733. plt.figure(figsize=(7,5))
  734. for setup_name in setup_names:
  735. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  736. continue
  737. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  738. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  739. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  740. al_iters = int(runs[0].data.params['al_iters'])
  741. ACC = np.zeros(al_iters + 1)
  742. print(len(runs), setup_name)
  743. for r in runs:
  744. acc = np.array([r.data.metrics['init_acc']])
  745. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  746. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  747. ACC /= float(len(runs))
  748. if first:
  749. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  750. B = int(runs[0].data.params['B'])
  751. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  752. first = False
  753. plt.scatter(dom[::skip], ACC[::skip], marker=mrkr, label=lbl, s=50, c=clr)
  754. plt.plot(dom[::skip], ACC[::skip], linewidth=0.5, c=clr)
  755. plt.legend()
  756. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  757. plt.ylabel("Accuracy")
  758. plt.tight_layout()
  759. #plt.savefig('{}/acc.pdf'.format(exp_save_root))
  760. plt.show()
  761. # +
  762. # Plot figure
  763. skip = 3
  764. first = True
  765. for modelname in ['mgr', 'ce']:
  766. plt.figure(figsize=(7,5))
  767. mm = 1.0
  768. for setup_name in setup_names:
  769. if modelname not in setup_name.split('-')[1]:
  770. continue
  771. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  772. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  773. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  774. al_iters = int(runs[0].data.params['al_iters'])
  775. ACC = np.zeros(al_iters + 1)
  776. print(len(runs), setup_name)
  777. for r in runs:
  778. acc = np.array([r.data.metrics['init_acc']])
  779. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  780. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  781. ACC /= float(len(runs))
  782. if first:
  783. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  784. B = int(runs[0].data.params['B'])
  785. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  786. first = False
  787. plt.scatter(dom[::skip], ACC[::skip], marker=mrkr, label=lbl, s=50, c=clr)
  788. plt.plot(dom[::skip], ACC[::skip], linewidth=1.5, c=clr)
  789. if min(ACC[::skip]) < mm:
  790. mm = min(ACC[::skip])
  791. plt.legend()
  792. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  793. plt.ylabel("Accuracy")
  794. plt.ylim([mm,1.0])
  795. plt.xticks([i*100 for i in range(6)])
  796. plt.tight_layout()
  797. plt.savefig('{}/acc-{}2.pdf'.format(exp_save_root, modelname))
  798. plt.show()
  799. # -
  800. # # Salinas
  801. # +
  802. exp_name = 'salinas'
  803. exp_save_root = os.path.join(save_root, exp_name)
  804. if not os.path.exists(exp_save_root):
  805. os.makedirs(exp_save_root)
  806. experiment = client.get_experiment_by_name(exp_name)
  807. query = 'attributes.status = "FINISHED"'
  808. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  809. print(len(all_runs))
  810. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  811. print(setup_names)
  812. # +
  813. # Plot figure
  814. skip = 3
  815. first = True
  816. not_plot = [] #['rand-mgr', 'uncertainty-mgr']
  817. plt.figure(figsize=(10,6))
  818. for setup_name in setup_names:
  819. if '-'.join(setup_name.split('-')[:2]) in not_plot:
  820. continue
  821. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  822. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  823. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  824. al_iters = int(runs[0].data.params['al_iters'])
  825. ACC = np.zeros(al_iters + 1)
  826. print(len(runs), setup_name)
  827. for r in runs:
  828. acc = np.array([r.data.metrics['init_acc']])
  829. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  830. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  831. ACC /= float(len(runs))
  832. if first:
  833. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  834. B = int(runs[0].data.params['B'])
  835. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  836. first = False
  837. plt.scatter(dom[::skip], ACC[::skip], marker=mrkr, label=lbl, s=50, c=clr)
  838. plt.plot(dom[::skip], ACC[::skip], linewidth=0.5, c=clr)
  839. plt.legend()
  840. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  841. plt.ylabel("Accuracy")
  842. plt.tight_layout()
  843. #plt.savefig('{}/acc.pdf'.format(exp_save_root))
  844. plt.show()
  845. # +
  846. # Plot figure
  847. skip = 3
  848. first = True
  849. for modelname in ['mgr', 'ce']:
  850. plt.figure(figsize=(7,5))
  851. mm = 1.0
  852. for setup_name in setup_names:
  853. if modelname not in setup_name.split('-')[1]:
  854. continue
  855. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  856. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  857. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  858. al_iters = int(runs[0].data.params['al_iters'])
  859. ACC = np.zeros(al_iters + 1)
  860. print(len(runs), setup_name)
  861. for r in runs:
  862. acc = np.array([r.data.metrics['init_acc']])
  863. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  864. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  865. ACC /= float(len(runs))
  866. if first:
  867. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  868. B = int(runs[0].data.params['B'])
  869. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  870. first = False
  871. plt.scatter(dom[::skip], ACC[::skip], marker=mrkr, label=lbl, s=50, c=clr)
  872. plt.plot(dom[::skip], ACC[::skip], linewidth=1.5, c=clr)
  873. if min(ACC[::skip]) < mm:
  874. mm = min(ACC[::skip])
  875. plt.legend()
  876. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  877. plt.ylabel("Accuracy")
  878. plt.ylim([mm, 0.9])
  879. plt.xticks([i*100 for i in range(6)])
  880. plt.tight_layout()
  881. plt.savefig('{}/acc-{}2.pdf'.format(exp_save_root, modelname))
  882. plt.show()
  883. # -
  884. # # Urban
  885. # +
  886. exp_name = 'urban'
  887. exp_save_root = os.path.join(save_root, exp_name)
  888. if not os.path.exists(exp_save_root):
  889. os.makedirs(exp_save_root)
  890. experiment = client.get_experiment_by_name(exp_name)
  891. query = 'attributes.status = "FINISHED"'
  892. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  893. print(len(all_runs))
  894. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  895. print(setup_names)
  896. # +
  897. # Plot figure
  898. skip = 3
  899. first = True
  900. not_plot = ['mc-mgr', 'mc-ce']
  901. plt.figure(figsize=(7,5))
  902. for setup_name in setup_names:
  903. if '-'.join(setup_name.split('-')[:2]) not in plot:
  904. continue
  905. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  906. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  907. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  908. al_iters = int(runs[0].data.params['al_iters'])
  909. ACC = np.zeros(al_iters + 1)
  910. print(len(runs), setup_name)
  911. for r in runs:
  912. acc = np.array([r.data.metrics['init_acc']])
  913. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  914. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  915. ACC /= float(len(runs))
  916. if first:
  917. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  918. B = int(runs[0].data.params['B'])
  919. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  920. first = False
  921. plt.scatter(dom[::skip], ACC[::skip], marker=mrkr, label=lbl, s=50, c=clr)
  922. plt.plot(dom[::skip], ACC[::skip], linewidth=0.5, c=clr)
  923. plt.legend()
  924. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  925. plt.ylabel("Accuracy")
  926. plt.tight_layout()
  927. #plt.savefig('{}/acc.pdf'.format(exp_save_root))
  928. plt.show()
  929. # +
  930. # Plot figure
  931. skip = 3
  932. first = True
  933. for modelname in ['mgr', 'ce']:
  934. plt.figure(figsize=(7,5))
  935. mm = 1.0
  936. for setup_name in setup_names:
  937. if modelname not in setup_name.split('-')[1]:
  938. continue
  939. runs = [r for r in all_runs if setup_name in r.data.tags['mlflow.runName']]
  940. acq_model = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  941. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  942. al_iters = int(runs[0].data.params['al_iters'])
  943. ACC = np.zeros(al_iters + 1)
  944. print(len(runs), setup_name)
  945. for r in runs:
  946. acc = np.array([r.data.metrics['init_acc']])
  947. iter_stats = load_uri(os.path.join(r.info.artifact_uri, 'iter_stats.npz'))
  948. ACC += np.concatenate((acc, iter_stats['iter_acc']))
  949. ACC /= float(len(runs))
  950. if first:
  951. num_init_labeled = len(load_uri(os.path.join(r.info.artifact_uri, 'init_labeled.npy')))
  952. B = int(runs[0].data.params['B'])
  953. dom = [num_init_labeled + B*i for i in range(al_iters+1)]
  954. first = False
  955. plt.scatter(dom[::skip], ACC[::skip], marker=mrkr, label=lbl, s=50, c=clr)
  956. plt.plot(dom[::skip], ACC[::skip], linewidth=1.5, c=clr)
  957. if min(ACC[::skip]) < mm:
  958. mm = min(ACC[::skip])
  959. plt.legend()
  960. plt.xlabel("Number of labeled points, $|\mathcal{L}|$")
  961. plt.ylabel("Accuracy")
  962. plt.ylim([mm, 1.0])
  963. plt.xticks([i*100 for i in range(6)])
  964. plt.tight_layout()
  965. plt.savefig('{}/acc-{}2.pdf'.format(exp_save_root, modelname))
  966. plt.show()
  967. # -
  968. # # Timing
  969. from collections import defaultdict
  970. mlflow.set_tracking_uri('../mlruns-old')
  971. client = mlflow.tracking.MlflowClient()
  972. # +
  973. TIMES = defaultdict(list)
  974. sizes = [2000, 5000, 10000, 20000]
  975. for i, exp_name in enumerate(['checker2', '5K-checker2', '10K-c2', '20K-c2']):
  976. experiment = client.get_experiment_by_name(exp_name)
  977. query = 'attributes.status = "FINISHED"'
  978. query += ' and params.B = "{}"'.format(B)
  979. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  980. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  981. print(setup_names)
  982. for setup_name in setup_names:
  983. if 'rand' in setup_name or 'uncertainty' in setup_name:
  984. continue
  985. runs = [r for r in all_runs if (setup_name in r.data.tags['mlflow.runName'] and r.data.tags['mlflow.runName'][-1] == '0')]
  986. print(len(runs))
  987. for r in runs:
  988. a_uri_split = str(r.info.artifact_uri).split('/')
  989. a_uri_split[7] += '-old'
  990. a_uri = '/'.join(a_uri_split)
  991. iter_stats = load_uri(os.path.join(a_uri, 'iter_stats.npz'))
  992. choices = iter_stats['al_choices'].flatten()
  993. init_labeled = load_uri(os.path.join(a_uri, 'init_labeled.npy'))
  994. choices = np.concatenate((init_labeled, choices))
  995. k = '-'.join(r.data.tags['mlflow.runName'].split('-')[:2])
  996. times = iter_stats['iter_time']
  997. if i == 0:
  998. if np.max(choices) < 2000:
  999. TIMES[k].append((2, np.average(times)))
  1000. break
  1001. else:
  1002. TIMES[k].append((sizes[i]//1000, np.average(times)))
  1003. break
  1004. for k in TIMES:
  1005. print(k)
  1006. print(TIMES[k])
  1007. # +
  1008. TIMESMC = defaultdict(list)
  1009. mlflow.set_tracking_uri('../mlruns')
  1010. client = mlflow.tracking.MlflowClient()
  1011. B = 5
  1012. for exp_name in ['checker3', 'salinas', 'mnist', 'urban']:
  1013. exp_save_root = os.path.join(save_root, exp_name)
  1014. if not os.path.exists(exp_save_root):
  1015. os.makedirs(exp_save_root)
  1016. experiment = client.get_experiment_by_name(exp_name)
  1017. query = 'attributes.status = "FINISHED"'
  1018. print(query)
  1019. all_runs = client.search_runs(experiment.experiment_id, filter_string=query)
  1020. setup_names = sorted(set(['-'.join(r.data.tags['mlflow.runName'].split('-')[:-1]) for r in all_runs]))
  1021. print(setup_names)
  1022. for setup_name in setup_names:
  1023. if 'rand' in setup_name or 'uncertainty' in setup_name:
  1024. continue
  1025. runs = [r for r in all_runs if (setup_name in r.data.tags['mlflow.runName'] and r.data.tags['mlflow.runName'][-1] == '0')]
  1026. iter_stats = load_uri(os.path.join(runs[0].info.artifact_uri, 'iter_stats.npz'))
  1027. choices = iter_stats['al_choices'].flatten()
  1028. init_labeled = load_uri(os.path.join(runs[0].info.artifact_uri, 'init_labeled.npy'))
  1029. choices = np.concatenate((init_labeled, choices))
  1030. k = '-'.join(runs[0].data.tags['mlflow.runName'].split('-')[:2])
  1031. times = iter_stats['iter_time']
  1032. TIMESMC[k].append((exp_name, np.average(times)))
  1033. for k in TIMESMC:
  1034. print(k)
  1035. print(TIMESMC[k])
  1036. # +
  1037. name2size = {
  1038. 'checker3' : 3000,
  1039. 'salinas' : 7148,
  1040. 'mnist' : 70000,
  1041. 'urban' : 94129
  1042. }
  1043. fig, ax = plt.subplots(1,1, figsize=(8,5))
  1044. lines = []
  1045. Names = []
  1046. for k in TIMES:
  1047. if 'sopt' in k: continue
  1048. digs, times = zip(*TIMES[k])
  1049. acq_model = '-'.join(k.split("-")[:2]).lower()
  1050. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  1051. p1, = ax.loglog(1000*np.array(digs), times, c=clr)
  1052. p2 = ax.scatter(1000*np.array(digs), times, marker=mrkr, c=clr, s=50)
  1053. lines.append((p1, p2))
  1054. Names.append(lbl)
  1055. # plt.legend()
  1056. # plt.show()
  1057. for k in TIMESMC:
  1058. if 'sopt' in k: continue
  1059. names, times = zip(*TIMESMC[k])
  1060. acq_model = '-'.join(k.split("-")[:2]).lower()
  1061. lbl, mrkr, clr = acq_model2label_marker_color[acq_model]
  1062. numbers = np.array([name2size[name] for name in names])
  1063. p1, = ax.loglog(numbers, times, '--', c=clr)
  1064. p2 = ax.scatter(numbers, times, marker=mrkr, c=clr, s=50)
  1065. lines.append((p1, p2))
  1066. Names.append(lbl)
  1067. print(lines)
  1068. print(Names)
  1069. ax.legend(lines, Names, bbox_to_anchor=(1.01, 1.0))
  1070. ax.set_xlabel("Size of Dataset, $N$")
  1071. ax.set_ylabel("Avg. AL Query Time")
  1072. plt.savefig(os.path.join(save_root, "timing.pdf"), bbox_inches = "tight")
  1073. plt.tight_layout()
  1074. plt.show()
  1075. # -
  1076. 70000./20.
  1077. 20./70000.