# runs_util.py
import os
import sys
import time

import mlflow
import numpy as np
from sklearn.datasets import make_moons, make_blobs
from sklearn.preprocessing import OneHotEncoder

from .activelearner import *
from .gbssl import *
  9. BMODELNAMES = ['gr', 'log', 'probitnorm']
  10. MMODELNAMES = ['mgr', 'ce']
  11. OTHERMODELNAMES = ['rkhs', 'hf']
  12. ACQS = ['mc', 'uncertainty', 'rand', 'vopt', 'sopt', 'mbr', 'mcgreedy', 'mcavg', 'mcavgf', 'mcf']
  13. def create_checkerboard2(N):
  14. X = np.random.rand(N,2)
  15. labels = []
  16. for x in X:
  17. i, j = 0,0
  18. if 0.25 <= x[0] and x[0] < 0.5:
  19. i = 1
  20. elif 0.5 <= x[0] and x[0] < 0.75:
  21. i = 2
  22. elif 0.75 <= x[0]:
  23. i = 3
  24. if 0.25 <= x[1] and x[1] < 0.5:
  25. j = 1
  26. elif 0.5 <= x[1] and x[1] < 0.75:
  27. j = 2
  28. elif 0.75 <= x[1]:
  29. j = 3
  30. labels.append((i+j) % 2)
  31. return X, np.array(labels)
  32. def create_binary_clusters():
  33. np.random.seed(4)
  34. Xm, labelsm = make_moons(200, shuffle=False, noise=0.12)
  35. X1, labels1 = make_blobs([50,60, 40, 30, 40], 2, shuffle=False, centers=[[1.6,-1.3],[1.3,1.7], [0.5, 2.4], [0.2,-1.], [-1.7,2.2]], cluster_std=[.26, .23, .23, .26, .23])
  36. labels1 = labels1 % 2
  37. X2 = np.random.randn(100,2) @ np.array([[.4, 0.],[0.,.3]]) + np.array([-1.5,-.8])
  38. X3 = np.random.randn(70,2) @ np.array([[.4, 0.],[0.,.3]]) + np.array([2.5,2.8])
  39. x11, x12 = np.array([-2., 0.8])[np.newaxis, :], np.array([-.2,2.])[np.newaxis, :]
  40. l1 = (x11 + np.linspace(0,1, 80)[:, np.newaxis] @ (x12 - x11)) + np.random.randn(80, 2)*0.18
  41. x21, x22 = np.array([2.5, -1.5])[np.newaxis, :], np.array([2.5, 2.])[np.newaxis, :]
  42. l2 = (x21 + np.linspace(0,1, 90)[:, np.newaxis] @ (x22 - x21)) + np.random.randn(90, 2)*0.2
  43. X = np.concatenate((Xm, X1, X2, X3, l1, l2))
  44. labels = np.concatenate((labelsm, labels1, np.zeros(100), np.ones(70), np.ones(80), np.zeros(90)))
  45. return X, labels
  46. def create_checkerboard3(N):
  47. X = np.random.rand(N,2)
  48. labels = []
  49. for x in X:
  50. i, j = 0,0
  51. if 0.33333 <= x[0] and x[0] < 0.66666:
  52. i = 1
  53. elif 0.66666 <= x[0]:
  54. i = 2
  55. if 0.33333 <= x[1] and x[1] < 0.66666:
  56. j = 1
  57. elif 0.66666 <= x[1]:
  58. j = 2
  59. labels.append(3*j + i)
  60. labels = np.array(labels)
  61. labels[labels == 4] = 0
  62. labels[labels == 8] = 0
  63. labels[labels == 5] = 1
  64. labels[labels == 6] = 1
  65. labels[labels == 3] = 2
  66. labels[labels == 7] = 2
  67. return X, labels
  68. def run_binary(w, v, tau, gamma, oracle, init_labeled, num_al_iters, B_per_al_iter, modelname='gr', acq='mc',
  69. cand='rand', select_method='top', full=False,
  70. verbose=False):
  71. '''
  72. Inputs:
  73. w = eigenvalue numpy array
  74. v = eigenvectors numpy array (columns)
  75. oracle = "labels" ground truth numpy array, in {0, 1, ..., n_c} or {-1, 1}
  76. init_labeled = list of indices that are initially labeled, per ordering in oracle and rows of v
  77. num_al_iters = total number of active learning iterations to perform
  78. B_per_al_iter = batch size B that will be done on each iteration
  79. acq = string that refers to the acquisition function to be tried in this experiment
  80. Outputs:
  81. labeled : list of indices of labeled points chosen throughout whole active learning process
  82. acc : list of length (num_al_iters + 1) corresponding to the accuracies of the current classifer at each AL iteration
  83. '''
  84. if modelname not in BMODELNAMES:
  85. raise ValueError("modelname %s not in list of possible modelnames : \n%s" % (
  86. modelname, str(BMODELNAMES)))
  87. if acq not in ACQS:
  88. raise ValueError(
  89. "acq = %s is not a valid acquisition function currently implemented:\n\t%s" % (acq, str(ACQS)))
  90. N, M = v.shape
  91. if M < N:
  92. truncated = True
  93. else:
  94. truncated = False
  95. if -1 not in np.unique(oracle):
  96. oracle[oracle == 0] = -1
  97. if truncated and not full:
  98. print("Binary %s Reduced Model -- i.e. not storing full C covariance matrix" % modelname)
  99. model = BinaryGraphBasedSSLModelReduced(
  100. modelname, gamma, tau, w=w, v=v)
  101. elif truncated and full:
  102. print("Binary %s FULL Model, but Truncated eigenvalues" % modelname)
  103. model = BinaryGraphBasedSSLModel(modelname, gamma, tau, w=w, v=v)
  104. else:
  105. print("Binary %s FULL Model, with ALL eigenvalues" % modelname)
  106. model = BinaryGraphBasedSSLModel(modelname, gamma, tau, w=w, v=v)
  107. # train the initial model, record accuracy
  108. model.calculate_model(labeled=init_labeled[:], y=list(oracle[init_labeled]))
  109. acc = get_acc(model.m, oracle, unlabeled=model.unlabeled)[1]
  110. mlflow.log_metric('init_acc', acc)
  111. # instantiate ActiveLearner object
  112. print("ActiveLearner Settings:\n\tacq = \t%s\n\tcand = \t%s" % (acq, cand))
  113. print("\tselect_method = %s, B = %d" % (select_method, B_per_al_iter))
  114. AL = ActiveLearner(acquisition=acq, candidate=cand)
  115. iter_acc = []
  116. iter_time = []
  117. al_choices = []
  118. for al_iter in range(num_al_iters):
  119. if verbose or (al_iter % 10 == 0):
  120. print("AL Iteration %d, acc=%1.6f" % (al_iter + 1, acc))
  121. # select query points via active learning
  122. tic = time.perf_counter()
  123. Q = AL.select_query_points(
  124. model, B_per_al_iter, method=select_method, verbose=verbose)
  125. toc = time.perf_counter()
  126. # query oracle
  127. yQ = list(oracle[Q])
  128. # update model, and calculate updated model's accuracy
  129. model.update_model(Q, yQ)
  130. acc = get_acc(model.m, oracle, unlabeled=model.unlabeled)[1]
  131. iter_acc.append(acc)
  132. iter_time.append(toc - tic)
  133. al_choices.append(Q)
  134. np.savez('tmp/iter_stats.npz', al_choices=np.array(al_choices), iter_acc=np.array(iter_acc), iter_time=np.array(iter_time))
  135. mlflow.log_artifact('tmp/iter_stats.npz')
  136. return
  137. def run_rkhs_hf(oracle, init_labeled, num_al_iters, B_per_al_iter, modelname='rkhs', h=0.1, delta=0.1, X=None, L=None,
  138. cand='rand', select_method='top', acq='db', verbose=False):
  139. '''
  140. Inputs:
  141. X = dataset
  142. oracle = "labels" ground truth numpy array, in {0, 1, ..., n_c} or {-1, 1}
  143. init_labeled = list of indices that are initially labeled, per ordering in oracle and rows of v
  144. num_al_iters = total number of active learning iterations to perform
  145. B_per_al_iter = batch size B that will be done on each iteration
  146. Outputs:
  147. labeled : list of indices of labeled points chosen throughout whole active learning process
  148. acc : list of length (num_al_iters + 1) corresponding to the accuracies of the current classifer at each AL iteration
  149. '''
  150. if modelname == 'rkhs':
  151. assert X is not None
  152. model = RKHSClassifier(X, sigma=h) # bandwidth from Karzand paper
  153. else:
  154. assert L is not None
  155. model = HFGraphBasedSSLModel(delta, L)
  156. # train the initial model, record accuracy
  157. if len(np.unique(oracle)) > 2:
  158. # calculate one-hot labels for oracle
  159. enc = OneHotEncoder()
  160. enc.fit(oracle.reshape((-1, 1)))
  161. oracle_onehot = enc.transform(oracle.reshape((-1, 1))).todense()
  162. y_init = oracle_onehot[init_labeled]
  163. else:
  164. # binary case
  165. if -1 not in np.unique(oracle):
  166. oracle[oracle == 0] = -1
  167. y_init = list(oracle[init_labeled])
  168. model.calculate_model(labeled=init_labeled[:], y=y_init)
  169. if model.nc > 2:
  170. acc = get_acc_multi(np.argmax(model.f, axis=1),
  171. oracle, unlabeled=model.unlabeled)[1]
  172. else:
  173. acc = get_acc(model.f, oracle, unlabeled=model.unlabeled)[1]
  174. mlflow.log_metric('init_acc', acc)
  175. # instantiate ActiveLearner object
  176. print("ActiveLearner Settings:\n\t{} {}".format(modelname.upper(), acq.upper()))
  177. print("\tselect_method = %s, B = %d" % (select_method, B_per_al_iter))
  178. AL = ActiveLearner(acquisition=acq, candidate=cand)
  179. iter_acc = []
  180. iter_time = []
  181. al_choices = []
  182. for al_iter in range(num_al_iters):
  183. if verbose or (al_iter % 10 == 0):
  184. print("AL Iteration %d, acc=%1.6f" % (al_iter + 1, acc))
  185. # select query points via active learning
  186. tic = time.perf_counter()
  187. Q = AL.select_query_points(
  188. model, B_per_al_iter, method=select_method, verbose=verbose)
  189. toc = time.perf_counter()
  190. # query oracle
  191. if model.nc > 2:
  192. yQ = oracle_onehot[Q]
  193. else:
  194. yQ = list(oracle[Q])
  195. # update model, and calculate updated model's accuracy
  196. model.update_model(Q, yQ)
  197. if model.nc > 2:
  198. acc = get_acc_multi(np.argmax(model.f, axis=1),
  199. oracle, unlabeled=model.unlabeled)[1]
  200. else:
  201. acc = get_acc(model.f, oracle, unlabeled=model.unlabeled)[1]
  202. iter_acc.append(acc)
  203. iter_time.append(toc - tic)
  204. al_choices.append(Q)
  205. np.savez('tmp/iter_stats.npz', al_choices=np.array(al_choices), iter_acc=np.array(iter_acc), iter_time=np.array(iter_time))
  206. mlflow.log_artifact('tmp/iter_stats.npz')
  207. return
  208. def run_multi(w, v, tau, gamma, oracle, init_labeled, num_al_iters, B_per_al_iter,
  209. modelname='mgr', acq='mc', cand='rand', select_method='top', full=False,
  210. verbose=False):
  211. '''
  212. Inputs:
  213. w = eigenvalue numpy array
  214. v = eigenvectors numpy array (columns)
  215. oracle = "labels" ground truth numpy array, in {0, 1, ..., n_c} or {-1, 1}
  216. init_labeled = list of indices that are initially labeled, per ordering in oracle and rows of v
  217. num_al_iters = total number of active learning iterations to perform
  218. B_per_al_iter = batch size B that will be done on each iteration
  219. acq = string that refers to the acquisition function to be tried in this experiment
  220. Outputs:
  221. labeled : list of indices of labeled points chosen throughout whole active learning process
  222. acc : list of length (num_al_iters + 1) corresponding to the accuracies of the current classifer at each AL iteration
  223. '''
  224. if modelname not in MMODELNAMES:
  225. raise ValueError("modelname %s not in list of possible modelnames : \n%s" % (
  226. modelname, str(MMODELNAMES)))
  227. if acq not in ACQS:
  228. raise ValueError(
  229. "acq = %s is not a valid acquisition function currently implemented:\n\t%s" % (acq, str(ACQS)))
  230. N, M = v.shape
  231. if M < N:
  232. truncated = True
  233. else:
  234. truncated = False
  235. if modelname == 'mgr': # GR is implemented in the Binary model since it requires same storage structure
  236. if truncated and not full:
  237. print(
  238. "Multi %s Reduced Model -- i.e. not storing full C covariance matrix" % modelname)
  239. model = BinaryGraphBasedSSLModelReduced(
  240. modelname, gamma, tau, w=w, v=v)
  241. elif truncated and full:
  242. print("Multi %s FULL Model, but Truncated eigenvalues" % modelname)
  243. model = BinaryGraphBasedSSLModel(modelname, gamma, tau, w=w, v=v)
  244. else:
  245. print("Multi %s FULL Model, with ALL eigenvalues" % modelname)
  246. model = BinaryGraphBasedSSLModel(modelname, gamma, tau, w=w, v=v)
  247. else:
  248. print("Multi %s Reduced Model -- i.e. not storing full C covariance matrix" % modelname)
  249. model = CrossEntropyGraphBasedSSLModelReduced(gamma, tau, w=w, v=v)
  250. # calculate one-hot labels for oracle
  251. enc = OneHotEncoder()
  252. enc.fit(oracle.reshape((-1, 1)))
  253. oracle_onehot = enc.transform(oracle.reshape((-1, 1))).todense()
  254. # train the initial model, record accuracy
  255. model.calculate_model(
  256. labeled=init_labeled[:], y=oracle_onehot[init_labeled])
  257. acc = get_acc_multi(np.argmax(model.m, axis=1),
  258. oracle, unlabeled=model.unlabeled)[1]
  259. mlflow.log_metric('init_acc', acc)
  260. # instantiate ActiveLearner object
  261. print("ActiveLearner Settings:\n\tacq = \t%s\n\tcand = \t%s" % (acq, cand))
  262. print("\tselect_method = %s, B = %d" % (select_method, B_per_al_iter))
  263. AL = ActiveLearner(acquisition=acq, candidate=cand)
  264. iter_acc = []
  265. iter_time = []
  266. al_choices = []
  267. beta = 0.
  268. for al_iter in range(num_al_iters):
  269. if verbose or (al_iter % 1 == 0):
  270. print("AL Iteration %d, acc=%1.6f" % (al_iter + 1, acc))
  271. if acq in ['mcavg', 'mcavgf']:
  272. #beta = 1./(1. + al_iter // 10)
  273. beta = (1. - (al_iter/float(num_al_iters)))
  274. if beta < 0:
  275. beta = 0.0
  276. # if al_iter < 8:
  277. # beta = 1.0
  278. # else:
  279. # beta = 0.0
  280. print("\tbeta = {:.3f}".format(beta))
  281. # select query points via active learning
  282. tic = time.perf_counter()
  283. Q = AL.select_query_points(
  284. model, B_per_al_iter, method=select_method, verbose=verbose, mcavg_beta=beta)
  285. toc = time.perf_counter()
  286. # query oracle
  287. yQ = oracle_onehot[Q]
  288. # update model, and calculate updated model's accuracy
  289. model.update_model(Q, yQ)
  290. acc = get_acc_multi(np.argmax(model.m, axis=1),
  291. oracle, unlabeled=model.unlabeled)[1]
  292. iter_acc.append(acc)
  293. iter_time.append(toc - tic)
  294. al_choices.append(Q)
  295. np.savez('tmp/iter_stats.npz', al_choices=np.array(al_choices), iter_acc=np.array(iter_acc), iter_time=np.array(iter_time))
  296. mlflow.log_artifact('tmp/iter_stats.npz')
  297. return
  298. #
  299. #
  300. # def run_test(oracle, init_labeled, num_al_iters, B_per_al_iter, modelname='gr', acq='mc',
  301. # cand='rand', select_method='top', w=None, v=None, tau=0.1, gamma=0.1,
  302. # X=None, L=None, h=0.1, delta=0.1,full=False, verbose=False):
  303. #
  304. # # if modelname not in BMODELNAMES:
  305. # # raise ValueError("modelname %s not in list of possible modelnames : \n%s" % (
  306. # # modelname, str(BMODELNAMES)))
  307. # # if acq not in ACQS:
  308. # # raise ValueError(
  309. # # "acq = %s is not a valid acquisition function currently implemented:\n\t%s" % (acq, str(ACQS)))
  310. #
  311. # if v is not None:
  312. # N, M = v.shape
  313. # if M < N:
  314. # truncated = True
  315. # else:
  316. # truncated = False
  317. #
  318. # if modelname in BMODELNAMES:
  319. # assert v is not None
  320. # assert w is not None
  321. # if -1 not in np.unique(oracle):
  322. # oracle[oracle == 0] = -1
  323. # if truncated and not full:
  324. # print("Binary %s Reduced Model -- i.e. not storing full C covariance matrix" % modelname)
  325. # model = BinaryGraphBasedSSLModelReduced(
  326. # modelname, gamma, tau, w=w, v=v)
  327. # elif truncated and full:
  328. # print("Binary %s FULL Model, but Truncated eigenvalues" % modelname)
  329. # model = BinaryGraphBasedSSLModel(modelname, gamma, tau, w=w, v=v)
  330. # else:
  331. # print("Binary %s FULL Model, with ALL eigenvalues" % modelname)
  332. # model = BinaryGraphBasedSSLModel(modelname, gamma, tau, w=w, v=v)
  333. #
  334. # ylab = list(oracle[init_labeled])
  335. #
  336. # elif modelname in MMODELNAMES:
  337. # assert v is not None
  338. # assert w is not None
  339. # if modelname == 'mgr': # GR is implemented in the Binary model since it requires same storage structure
  340. # if truncated and not full:
  341. # print(
  342. # "Multi %s Reduced Model -- i.e. not storing full C covariance matrix" % modelname)
  343. # model = BinaryGraphBasedSSLModelReduced(
  344. # modelname, gamma, tau, w=w, v=v)
  345. # elif truncated and full:
  346. # print("Multi %s FULL Model, but Truncated eigenvalues" % modelname)
  347. # model = BinaryGraphBasedSSLModel(modelname, gamma, tau, w=w, v=v)
  348. # else:
  349. # print("Multi %s FULL Model, with ALL eigenvalues" % modelname)
  350. # model = BinaryGraphBasedSSLModel(modelname, gamma, tau, w=w, v=v)
  351. # else:
  352. # print("Multi %s Reduced Model -- i.e. not storing full C covariance matrix" % modelname)
  353. # model = CrossEntropyGraphBasedSSLModelReduced(gamma, tau, w=w, v=v)
  354. #
  355. # enc = OneHotEncoder()
  356. # enc.fit(oracle.reshape((-1, 1)))
  357. # oracle_onehot = enc.transform(oracle.reshape((-1, 1))).todense()
  358. # ylab = oracle_onehot[init_labeled]
  359. #
  360. # elif modelname in OTHERMODELNAMES:
  361. # if modelname == 'rkhs':
  362. # assert X is not None
  363. # assert acq == 'db'
  364. # model = RKHSClassifier(X, sigma=h) # bandwidth from Karzand paper
  365. # else:
  366. # assert L is not None
  367. # assert acq in ['vopt', 'sopt']
  368. # model = HFGraphBasedSSLModel(delta, L)
  369. #
  370. # ylab = list(oracle[init_labeled])
  371. # else:
  372. # raise ValueError("{} is not a valid model name")
  373. #
  374. #
  375. #
  376. # # train the initial model, record accuracy
  377. # model.calculate_model(labeled=init_labeled[:], y=ylab[:])
  378. # acc = get_acc(model.m, oracle, unlabeled=model.unlabeled)[1]
  379. # mlflow.log_metric('init_acc', acc)
  380. #
  381. #
  382. # # instantiate ActiveLearner object
  383. # print("ActiveLearner Settings:\n\tacq = \t%s\n\tcand = \t%s" % (acq, cand))
  384. # print("\tselect_method = %s, B = %d" % (select_method, B_per_al_iter))
  385. # AL = ActiveLearner(acquisition=acq, candidate=cand)
  386. #
  387. #
  388. # iter_acc = []
  389. # iter_time = []
  390. # al_choices = []
  391. # for al_iter in range(num_al_iters):
  392. # if verbose or (al_iter % 10 == 0):
  393. # print("AL Iteration %d, acc=%1.6f" % (al_iter + 1, acc))
  394. # # select query points via active learning
  395. # tic = time.perf_counter()
  396. # Q = AL.select_query_points(
  397. # model, B_per_al_iter, method=select_method, verbose=verbose)
  398. # toc = time.perf_counter()
  399. #
  400. # # query oracle
  401. # yQ = list(oracle[Q])
  402. #
  403. #
  404. # # update model, and calculate updated model's accuracy
  405. # model.update_model(Q, yQ)
  406. # acc = get_acc(model.m, oracle, unlabeled=model.unlabeled)[1]
  407. # iter_acc.append(acc)
  408. # iter_time.append(toc - tic)
  409. # al_choices.append(Q)
  410. #
  411. # np.savez('tmp/iter_stats.npz', al_choices=np.array(al_choices), iter_acc=np.array(iter_acc), iter_time=np.array(iter_time))
  412. # mlflow.log_artifact('tmp/iter_stats.npz')
  413. #
  414. # return
  415. #
  416. #
  417. #
  418. #
  419. #
  420. # def get_data_from_runs(acq, modelname, M, tau, gamma, cand, select_method, B, num_al_iters, runs=[1], root_filename='./'):
  421. # parent_filename = root_filename + "%s-%s-%d-%s-%s/" % (acq, modelname, M, str(tau), str(gamma))
  422. # if not os.path.exists(parent_filename):
  423. # raise ValueError("data at %s does not exist..." % parent_filename)
  424. # RUNS = {}
  425. # for run in runs:
  426. # experiment_name = "%s-%s-%d-%d-%d.txt" % (cand, select_method, B, num_al_iters, run)
  427. # if not os.path.exists(parent_filename + experiment_name):
  428. # print('Run #%d that you requested does not exist at %s, skipping' % (run, parent_filename + experiment_name))
  429. # else:
  430. # with open(parent_filename + experiment_name, 'r') as f:
  431. # for i, line in enumerate(f.readlines()):
  432. # # read in init_labeled, and initial accuracy
  433. # if i == 0:
  434. # line = line.split(',')
  435. # RUNS[run] = {'init_labeled': [int(x) for x in line[:-2]], 'acc':[float(line[-1])], 'times':[], 'choices':[]}
  436. # else:
  437. # line = line.split(',')
  438. # RUNS[run]['acc'].append(float(line[-1]))
  439. # RUNS[run]['choices'].extend(int(x) for x in line[:-2])
  440. # RUNS[run]['times'].append(float(line[-2]))
  441. #
  442. # return RUNS
  443. #
  444. # def get_avg_acc_from_runs_dict(RUNS, runs=[1]):
  445. # count = len(runs)
  446. # accs = []
  447. # for run in runs:
  448. # if run not in RUNS:
  449. # print("Run #%d not in RUNS dictionary given, skipping..." % run)
  450. # else:
  451. # accs.append(RUNS[run]['acc'])
  452. # if len(accs) == 0:
  453. # print("No valid runs found, returning None")
  454. # return
  455. # accs = np.array(accs)
  456. # return np.average(accs, axis=0), np.std(accs, axis=0)