# binary_clusters_run.py — Active Learning experiments on the Binary Clusters dataset.
  1. # mlflow utilization
  2. import numpy as np
  3. import scipy.sparse as sps
  4. import argparse
  5. import time
  6. from sklearn.model_selection import train_test_split
  7. import copy
  8. import os
  9. import math
  10. import sys
  11. import matplotlib.pyplot as plt
  12. from util.graph_manager import GraphManager
  13. from util.runs_util import *
  14. import mlflow
  15. from util.mlflow_util import *
  16. ACQ_MODELS = ['vopt--gr', 'sopt--gr', 'db--rkhs', 'mc--gr', 'mc--log', 'mc--probitnorm', 'sopt--hf', 'vopt--hf',
  17. 'uncertainty--gr', 'uncertainty--log', 'uncertainty--probitnorm', 'rand--gr', 'rand--log', 'rand--probitnorm']
  18. GRAPH_PARAMS = {
  19. 'knn' :10,
  20. 'sigma' : 3.,
  21. 'normalized' : True,
  22. 'zp_k' : 5
  23. }
  24. if __name__ == "__main__":
  25. parser = argparse.ArgumentParser(description = 'Run Active Learning experiment on Binary Clusters dataset')
  26. parser.add_argument('--data-root', default='./data/binary_clusters/', dest='data_root', type=str, help='Location of data X with labels.')
  27. parser.add_argument('--num-eigs', default=50, dest='M', type=int, help='Number of eigenvalues for spectral truncation')
  28. parser.add_argument('--tau', default=0.005, type=float, help='value of diagonal perturbation and scaling of GBSSL models (minus HF)')
  29. parser.add_argument('--gamma', default=0.1, type=float, help='value of noise parameter to be shared across all GBSSL models (minus HF)')
  30. parser.add_argument('--delta', default=0.01, type=float, help='value of diagonal perturbation of unnormalized graph Laplacian for HF model.')
  31. parser.add_argument('--h', default=0.1, type=float, help='kernel width for RKHS model.')
  32. parser.add_argument('--B', default=5, type=int, help='batch size for AL iterations')
  33. parser.add_argument('--al-iters', default=100, type=int, dest='al_iters', help='number of active learning iterations to perform.')
  34. parser.add_argument('--candidate-method', default='rand', type=str, dest='cand', help='candidate set selection method name ["rand", "full"]')
  35. parser.add_argument('--candidate-percent', default=0.1, type=float, dest='cand_perc', help='if --candidate-method == "rand", then this is the percentage of unlabeled data to consider')
  36. parser.add_argument('--select-method', default='top', type=str, dest='select_method', help='how to select which points to query from the acquisition values. in ["top", "prop"]')
  37. parser.add_argument('--runs', default=5, type=int, help='Number of trials to run')
  38. parser.add_argument('--lab-start', default=2, dest='lab_start', type=int, help='Number of initially labeled points.')
  39. parser.add_argument('--metric', default='euclidean', type=str, help='metric name ("euclidean" or "cosine") for graph construction')
  40. parser.add_argument('--name', default='binary-clusters', dest='experiment_name', help='Name for this dataset/experiment run ')
  41. args = parser.parse_args()
  42. GRAPH_PARAMS['n_eigs'] = args.M
  43. GRAPH_PARAMS['metric'] = args.metric
  44. if not os.path.exists('tmp/'):
  45. os.makedirs('tmp/')
  46. # Load in or Create the Dataset
  47. if not os.path.exists(args.data_root + 'X_labels.npz'):
  48. print("Cannot find previously saved data at {}".format(args.data_root + 'X_labels.npz'))
  49. print("so creating the dataset and labels")
  50. X, labels = create_binary_clusters()
  51. N = X.shape[0]
  52. os.makedirs(args.data_root)
  53. np.savez(args.data_root + 'X_labels.npz', X=X, labels=labels)
  54. else:
  55. data = np.load(args.data_root + 'X_labels.npz')
  56. X, labels = data['X'], data['labels']
  57. N = X.shape[0]
  58. labels[labels == 0] = -1
  59. # Load in or calculate eigenvectors, using mlflow IN Graph_manager
  60. gm = GraphManager()
  61. evals, evecs = gm.from_features(X, knn=GRAPH_PARAMS['knn'], sigma=GRAPH_PARAMS['sigma'],
  62. normalized=GRAPH_PARAMS['normalized'], n_eigs=GRAPH_PARAMS['n_eigs'],
  63. zp_k=GRAPH_PARAMS['zp_k'], metric=GRAPH_PARAMS['metric']) # runs mlflow logging in this function call
  64. print(evals[:6])
  65. # If we are doing a run with the HF model, we need the unnormalized graph Laplacian
  66. L = None
  67. if 'hf' in ''.join(ACQ_MODELS):
  68. prev_run = get_prev_run('GraphManager.from_features',
  69. GRAPH_PARAMS,
  70. tags={"X":str(X), "N":str(X.shape[0])},
  71. git_commit=None)
  72. url_data = urllib.parse.urlparse(os.path.join(prev_run.info.artifact_uri,
  73. 'W.npz'))
  74. path = urllib.parse.unquote(url_data.path)
  75. W = sps.load_npz(path)
  76. L = sps.csr_matrix(gm.compute_laplacian(W, normalized=False)) + args.delta**2. * sps.eye(N)
  77. # Run the experiments
  78. print("--------------- Parameters for the Run of Experiments -----------------------")
  79. print("\tacq_models = %s" % str(ACQ_MODELS))
  80. print("\tal_iters = %d, B = %d, M = %d" % (args.al_iters, args.B, args.M))
  81. print("\tcand=%s, select_method=%s" % (args.cand, args.select_method))
  82. print("\tnum_init_labeled = %d" % (args.lab_start))
  83. print("\ttau = %1.6f, gamma = %1.6f, delta = %1.6f, h = %1.6f" % (args.tau, args.gamma, args.delta, args.h))
  84. print("\tnumber of runs = {}".format(args.runs))
  85. print("\n\n")
  86. ans = input("Do you want to proceed with this test?? [y/n] ")
  87. while ans not in ['y','n']:
  88. ans = input("Sorry, please input either 'y' or 'n'")
  89. if ans == 'n':
  90. print("Not running test, exiting...")
  91. else:
  92. client = mlflow.tracking.MlflowClient()
  93. mlflow.set_experiment(args.experiment_name)
  94. experiment = client.get_experiment_by_name(args.experiment_name)
  95. for i, seed in enumerate(j**2 + 3 for j in range(args.runs)):
  96. print("=======================================")
  97. print("============= Run {}/{} ===============".format(i+1, args.runs))
  98. print("=======================================")
  99. np.random.seed(seed)
  100. init_labeled, unlabeled = train_test_split(np.arange(N), train_size=2, stratify=labels)#list(np.random.choice(range(N), 10, replace=False))
  101. init_labeled, unlabeled = list(init_labeled), list(unlabeled)
  102. params_shared = {
  103. 'init_labeled': init_labeled,
  104. 'run': i,
  105. 'al_iters' : args.al_iters,
  106. 'B' : args.B,
  107. 'cand' : args.cand,
  108. 'select' : args.select_method
  109. }
  110. query = 'attributes.status = "FINISHED"'
  111. for key, val in params_shared.items():
  112. query += ' and params.{} = "{}"'.format(key, val)
  113. already_completed = [run.data.tags['mlflow.runName'] for run in client.search_runs([experiment.experiment_id], filter_string=query)]
  114. if len(already_completed) > 0:
  115. print("Run {} already completed:".format(i))
  116. for thing in sorted(already_completed, key= lambda x : x[0]):
  117. print("\t", thing)
  118. print()
  119. np.save('tmp/init_labeled', init_labeled)
  120. for acq, model in (am.split('--') for am in ACQ_MODELS):
  121. if model == 'hf':
  122. run_name = "{}-{}-{:.2f}-{}".format(acq, model, args.delta, i)
  123. elif model == 'rkhs':
  124. run_name = "{}-{}-{:.2}-{}".format(acq, model, args.h, i)
  125. else:
  126. run_name = "{}-{}-{:.3f}-{:.3f}-{}-{}".format(acq, model, args.tau, args.gamma, args.M, i)
  127. if run_name not in already_completed:
  128. labeled = copy.deepcopy(init_labeled)
  129. with mlflow.start_run(run_name=run_name) as run:
  130. # run AL test
  131. mlflow.log_params(params_shared)
  132. mlflow.log_artifact('tmp/init_labeled.npy')
  133. if model not in ['hf', 'rkhs']:
  134. mlflow.log_params({
  135. 'tau' : args.tau,
  136. 'gamma' : args.gamma,
  137. 'M' : args.M
  138. })
  139. run_binary(evals, evecs, args.tau, args.gamma, labels, labeled, args.al_iters, args.B,
  140. modelname=model, acq=acq, cand=args.cand, select_method=args.select_method, verbose=False)
  141. else:
  142. if model == 'hf':
  143. mlflow.log_param('delta', args.delta)
  144. else:
  145. mlflow.log_param('h', args.h)
  146. run_rkhs_hf(labels, labeled, args.al_iters, args.B, h=args.h, delta=args.delta, X=X, L=L,
  147. modelname=model, acq=acq, cand=args.cand, select_method=args.select_method, verbose=False)
  148. # Clean up tmp file
  149. print("Cleaning up files in ./tmp/")
  150. if os.path.exists('tmp/init_labeled.npy'):
  151. os.remove('tmp/init_labeled.npy')
  152. if os.path.exists('tmp/iter_stats.npz'):
  153. os.remove('tmp/iter_stats.npz')