# network_run.py
  1. # mlflow utilization
  2. import numpy as np
  3. import scipy.sparse as sps
  4. import argparse
  5. import time
  6. from sklearn.model_selection import train_test_split
  7. import copy
  8. import os
  9. import math
  10. import sys
  11. import matplotlib.pyplot as plt
  12. from util.graph_manager import GraphManager
  13. from util.runs_util import *
  14. import mlflow
  15. from util.mlflow_util import *
# Acquisition--model pairs to run, each formatted as '<acquisition>--<model>'
# (split on '--' below); model is one of 'mgr', 'ce', or 'hf'.
# ACQ_MODELS = ['vopt--mgr', 'sopt--mgr', 'sopt--hf', 'mc--mgr', 'mcgreedy--ce', 'mc--ce', \
# 'rand--ce', 'rand--mgr', 'vopt--hf', 'uncertainty--mgr', 'uncertainty--ce']
#ACQ_MODELS = ['mcgreedy--ce']
ACQ_MODELS = ['vopt--mgr']
  20. if __name__ == "__main__":
  21. parser = argparse.ArgumentParser(description = 'Run Active Learning experiment on Network datasets, defaults to PubMed (named pubmed)')
  22. parser.add_argument('--data_root', default='./data/pubmed/', type=str, help='Location of data X with labels (X_labels.npz), eigendata(eig.npz) and weight matrix W (W.npz).')
  23. parser.add_argument('--num_eigs', default=50, dest='M', type=int, help='Number of eigenvalues for spectral truncation')
  24. parser.add_argument('--tau-gr', default=0.01, dest='tau_gr', type=float, help='value of diagonal perturbation and scaling of MGR (not HF)')
  25. parser.add_argument('--gamma-gr', default=0.1, dest='gamma_gr', type=float, help='value of noise parameter of MGR (not HF)')
  26. parser.add_argument('--tau-ce', default=0.01, dest='tau_ce', type=float, help='value of diagonal perturbation and scaling of CE model')
  27. parser.add_argument('--gamma-ce', default=0.1, dest='gamma_ce', type=float, help='value of noise parameter of CE model')
  28. parser.add_argument('--delta', default=0.01, type=float, help='value of diagonal perturbation of unnormalized graph Laplacian for HF model.')
  29. parser.add_argument('--B', default=5, type=int, help='batch size for AL iterations')
  30. parser.add_argument('--al_iters', default=11, type=int, help='number of active learning iterations to perform.')
  31. parser.add_argument('--candidate-method', default='rand', type=str, dest='cand', help='candidate set selection method name ["rand", "full"]')
  32. parser.add_argument('--candidate-percent', default=0.1, type=float, dest='cand_perc', help='if --candidate-method == "rand", then this is the percentage of unlabeled data to consider')
  33. parser.add_argument('--select_method', default='top', type=str, help='how to select which points to query from the acquisition values. in ["top", "prop"]')
  34. parser.add_argument('--lab-start', default=3, type=int, dest='lab_start', help='size of initially labeled set.')
  35. parser.add_argument('--runs', default=5, type=int, help='Number of trials to run')
  36. parser.add_argument('--name', default='pubmed', dest='experiment_name', help='Name for this dataset/experiment run ')
  37. args = parser.parse_args()
  38. print("data_root is {} and experiment name is {} --------".format(args.data_root, args.experiment_name))
  39. if not os.path.exists('tmp/'):
  40. os.makedirs('tmp/')
  41. # Load in the Network Dataset
  42. if not os.path.exists(args.data_root + 'X_labels.npz'):
  43. raise ValueError("Cannot find previously saved data at {}".format(args.data_root + 'X_labels.npz'))
  44. print("Loading data at {}".format(args.data_root + 'X_labels.npz'))
  45. data = np.load(args.data_root + 'X_labels.npz', allow_pickle=True)
  46. labels = data['labels']
  47. N = labels.shape[0]
  48. if args.lab_start < len(np.unique(labels)):
  49. print("Number of initial points specified ({}) not enough to at least represent each class ({}), increasing the number of initially labeled points...".format(args.lab_start, len(np.unique(labels))))
  50. args.lab_start = len(np.unique(labels))
  51. if not os.path.exists(args.data_root + 'eig.npz'):
  52. raise ValueError("Cannot find previously saved data at {}".format(args.data_root + 'eig.npz'))
  53. print("Loading graph data at {}".format(args.data_root + 'eig.npz'))
  54. eig_data = np.load(args.data_root + 'eig.npz', allow_pickle=True)
  55. evals, evecs = eig_data['evals'], eig_data['evecs']
  56. evals, evecs = evals[:args.M], evecs[:,:args.M]
  57. # If we are doing a run with the HF model, we need the unnormalized graph Laplacian
  58. L = None
  59. if 'hf' in ''.join(ACQ_MODELS):
  60. W = sps.load_npz(args.data_root + 'W.npz')
  61. gm = GraphManager()
  62. L = sps.csr_matrix(gm.compute_laplacian(W, normalized=False)) + args.delta**2. * sps.eye(N)
  63. # Run the experiments
  64. print("--------------- Parameters for the Run of Experiments -----------------------")
  65. print("\tacq_models = %s" % str(ACQ_MODELS))
  66. print("\tal_iters = %d, B = %d, M = %d" % (args.al_iters, args.B, args.M))
  67. print("\tcand=%s, select_method=%s" % (args.cand, args.select_method))
  68. print("\tnum_init_labeled = %d" % (args.lab_start))
  69. print("\ttau = %1.6f, gamma = %1.6f, tau_ce = %1.6f, gamma_ce = %1.6f" % (args.tau_gr, args.gamma_gr, args.tau_ce, args.gamma_ce))
  70. print("\tdelta = {:.6f}".format(args.delta))
  71. print("\tnumber of runs = {}".format(args.runs))
  72. print("\n\n")
  73. ans = input("Do you want to proceed with this test?? [y/n] ")
  74. while ans not in ['y','n']:
  75. ans = input("Sorry, please input either 'y' or 'n'")
  76. if ans == 'n':
  77. print("Not running test, exiting...")
  78. else:
  79. client = mlflow.tracking.MlflowClient()
  80. mlflow.set_experiment(args.experiment_name)
  81. experiment = client.get_experiment_by_name(args.experiment_name)
  82. if experiment is not None:
  83. print("Looks like you've already run this experiment name previously... Are you sure you want to continue? [y/n]")
  84. ans = input()
  85. if ans not in ['y', 'yes']:
  86. raise ValueError("Exited test")
  87. for i, seed in enumerate(j**2 + 3 for j in range(args.runs)):
  88. np.random.seed(seed)
  89. init_labeled, unlabeled = train_test_split(np.arange(N), train_size=args.lab_start, stratify=labels)#list(np.random.choice(range(N), 10, replace=False))
  90. init_labeled, unlabeled = list(init_labeled), list(unlabeled)
  91. params_shared = {
  92. 'init_labeled': init_labeled,
  93. 'run': i,
  94. 'al_iters' : args.al_iters,
  95. 'B' : args.B,
  96. 'cand' : args.cand,
  97. 'select' : args.select_method
  98. }
  99. query = 'attributes.status = "FINISHED"'
  100. for key, val in params_shared.items():
  101. query += ' and params.{} = "{}"'.format(key, val)
  102. already_completed = [run.data.tags['mlflow.runName'] for run in client.search_runs([experiment.experiment_id], filter_string=query)]
  103. if len(already_completed) > 0:
  104. print("Run {} already completed:".format(i))
  105. for thing in sorted(already_completed, key= lambda x : x[0]):
  106. print("\t", thing)
  107. print()
  108. np.save('tmp/init_labeled', init_labeled)
  109. for acq, model in (am.split('--') for am in ACQ_MODELS):
  110. if model == 'hf':
  111. run_name = "{}-{}-{:.2f}-{}".format(acq, model, args.delta, i)
  112. elif model == 'ce':
  113. run_name = "{}-{}-{:.2f}-{:.2f}-{}-{}".format(acq, model, args.tau_ce, args.gamma_ce, args.M, i)
  114. else:
  115. run_name = "{}-{}-{:.2f}-{:.2f}-{}-{}".format(acq, model, args.tau_gr, args.gamma_gr, args.M, i)
  116. if run_name not in already_completed:
  117. labeled = copy.deepcopy(init_labeled)
  118. with mlflow.start_run(run_name=run_name) as run:
  119. # run AL test
  120. mlflow.log_params(params_shared)
  121. mlflow.log_artifact('tmp/init_labeled.npy')
  122. if model == 'ce':
  123. mlflow.log_params({
  124. 'tau' : args.tau_ce,
  125. 'gamma' : args.gamma_ce,
  126. 'M' : args.M
  127. })
  128. run_multi(evals, evecs, args.tau_ce, args.gamma_ce, labels, labeled, args.al_iters, args.B,
  129. modelname=model, acq=acq, cand=args.cand, select_method=args.select_method, verbose=False)
  130. elif model == 'mgr':
  131. mlflow.log_params({
  132. 'tau' : args.tau_gr,
  133. 'gamma' : args.gamma_gr,
  134. 'M' : args.M
  135. })
  136. run_multi(evals, evecs, args.tau_gr, args.gamma_gr, labels, labeled, args.al_iters, args.B,
  137. modelname=model, acq=acq, cand=args.cand, select_method=args.select_method, verbose=False)
  138. elif model == 'hf':
  139. mlflow.log_param('delta', args.delta)
  140. run_rkhs_hf(labels, labeled, args.al_iters, args.B, delta=args.delta, L=L,
  141. modelname=model, acq=acq, cand=args.cand, select_method=args.select_method, verbose=False)
  142. else:
  143. raise ValueError("{} is not a valid multiclass model".format(model))
  144. # Clean up tmp file
  145. print("Cleaning up files in ./tmp/")
  146. if os.path.exists('tmp/init_labeled.npy'):
  147. os.remove('tmp/init_labeled.npy')
  148. if os.path.exists('tmp/iter_stats.npz'):
  149. os.remove('tmp/iter_stats.npz')