# activelearner.py
  1. # Active Learner object class
  2. # author: Kevin Miller, ksmill327@gmail.com
  3. #
  4. # Need to implement acquisition_values function more efficiently, record values if debug is on.
  5. import numpy as np
  6. #from .dijkstra import *
  7. from .al_util import *
  8. from .acquisitions import *
  9. from .rkhs import *
  10. ACQUISITIONS = ['mc', 'uncertainty', 'rand', 'vopt', 'sopt', 'mbr', 'mcgreedy', 'db', 'mcavg', 'mcavgf', 'mcf']
  11. # MODELS = ['gr', 'probit-log', 'probit-norm', 'softmax', 'log', 'probitnorm']
  12. CANDIDATES = ['rand', 'full', 'dijkstra']
  13. SELECTION_METHODS = ['top', 'prop', '']
  14. def sgn(x):
  15. if x >= 0:
  16. return 1.
  17. else:
  18. return -1.
def acquisition_values(acq, Cand, model, mcavg_beta=0.0):
    """Compute acquisition values for the candidate indices in ``Cand``.

    Parameters:
        acq        : str, acquisition-function name (expected to be one of
                     ACQUISITIONS; anything else raises ValueError).
        Cand       : candidate indices into the model's points (array-like of ints).
        model      : fitted model object. This function reads model.m,
                     model.modelname, model.gamma, model.full_storage, and either
                     model.C (full storage) or model.C_a / model.alpha / model.v
                     (reduced storage); 'hf' and 'rkhs' models supply their own
                     methods (vopt_vals, sopt_vals, look_ahead_db_norms).
        mcavg_beta : float, forwarded as ``beta`` to the 'mcavg'/'mcavgf' helpers.

    Returns:
        Array of acquisition values, one per candidate, in a "larger is better"
        formulation (see the 'uncertainty' branch, which negates so callers can
        take an argmax).
        NOTE(review): several paths implicitly return None -- 'mcavg'/'mcavgf'/
        'mcf' when model.full_storage is True, and the 'vopt'/'sopt' reduced
        branches for unrecognized model names. Confirm callers never hit these
        combinations.
    """
    if acq == "mc":
        # Model-change acquisition; full- vs reduced-storage implementations.
        if model.full_storage:
            return mc_full(Cand, model.m, model.C, model.modelname, gamma=model.gamma)
        else:
            return mc_reduced(model.C_a, model.alpha, model.v[Cand,:], model.modelname, uks=model.m[Cand], gamma=model.gamma)
    elif acq == "mcgreedy":
        # Greedy variant of the reduced model-change acquisition.
        return mc_reduced(model.C_a, model.alpha, model.v[Cand,:], model.modelname, uks=model.m[Cand], gamma=model.gamma, greedy=True)
    elif acq == "mcavg":
        # Only implemented for reduced storage (falls through to None otherwise).
        if not model.full_storage:
            return mc_avg_reduced(model, Cand, beta=mcavg_beta)
    elif acq == "mcavgf":
        if not model.full_storage:
            return mcavg_app_full_red(model, Cand, beta=mcavg_beta)
    elif acq == "mcf":
        if not model.full_storage:
            return mc_app_full_red(model, Cand)
    elif acq == "uncertainty":
        if len(model.m.shape) > 1:  # multiclass: entropy of softmax(model.m) per candidate
            probs = np.exp(model.m[Cand])
            probs /= np.sum(probs, axis=1)[:, np.newaxis]
            return -np.sum(probs*np.log(probs), axis=1)
        else:
            # Binary case: distance of m from the decision boundary, negated
            # to ensure a "max" formulation for acquisition values.
            return -np.absolute(model.m[Cand])
    elif acq == "rand":
        # Uniform random scores -> random query selection.
        return np.random.rand(len(Cand))
    elif acq == "vopt":
        # Variance-reduction style criterion (presumably V-optimality).
        if model.modelname == 'hf':
            return model.vopt_vals(Cand)
        else:
            if model.full_storage:
                # ips[k] = ||C[k,:]||^2 for each candidate k.
                ips = np.array([np.inner(model.C[k,:], model.C[k,:]) for k in Cand]).flatten()
                if model.modelname in ["gr", "mgr"]:
                    return ips/(model.gamma**2. + np.diag(model.C)[Cand])
                if model.modelname == "probit-norm":
                    # hess_calc: probit (normal-CDF) second-derivative term -- defined in project helpers.
                    return ips * np.array([hess_calc(model.m[k], sgn(model.m[k]), model.gamma)/ \
                            (hess_calc(model.m[k], sgn(model.m[k]), model.gamma)*model.C[k,k] + 1.) for k in Cand])
                if model.modelname == "probit-log":
                    return ips * np.array([hess_calc2(model.m[k], sgn(model.m[k]), model.gamma)/ \
                            (hess_calc2(model.m[k], sgn(model.m[k]), model.gamma)*model.C[k,k] + 1.) for k in Cand])
            else:
                # Reduced (spectral) storage: work with C_a and eigenvector rows v.
                uks = model.m[Cand]
                C_a_vk = model.C_a @ (model.v[Cand,:].T)
                ips = np.array([np.inner(C_a_vk[:,i],C_a_vk[:,i]) for i in range(len(Cand))])
                if model.modelname in ["gr", "mgr"]:
                    return ips / (model.gamma**2. + np.array([np.inner(model.v[Cand[i],:], C_a_vk[:,i]) for i in range(len(Cand))]))
                if model.modelname == 'probit-log' or model.modelname == 'log':
                    return ips * np.array([hess_calc2(model.m[k], sgn(model.m[k]), model.gamma)/ \
                            (hess_calc2(model.m[k], sgn(model.m[k]), model.gamma)*np.inner(model.v[Cand,:][i,:], C_a_vk[:,i]) + 1.) for i,k in enumerate(Cand)])
                if model.modelname == 'probit-norm' or model.modelname == 'probitnorm':
                    return ips * np.array([hess_calc(model.m[k], sgn(model.m[k]), model.gamma)/ \
                            (hess_calc(model.m[k], sgn(model.m[k]), model.gamma)*np.inner(model.v[Cand,:][i,:], C_a_vk[:,i]) + 1.) for i,k in enumerate(Cand)])
    elif acq == "sopt":
        # Sum-style criterion (presumably Sigma-optimality); mirrors the 'vopt'
        # structure but uses row sums / ones-vector inner products.
        if model.modelname == 'hf':
            return model.sopt_vals(Cand)
        else:
            if model.full_storage:
                sums = np.sum(model.C[Cand,:], axis=1).flatten()**2.
                if model.modelname in ["gr", "mgr"]:
                    return sums/(model.gamma**2. + np.diag(model.C)[Cand])
                if model.modelname == "probit-norm":
                    return sums * np.array([hess_calc(model.m[k], sgn(model.m[k]), model.gamma)/ \
                            (hess_calc(model.m[k], sgn(model.m[k]), model.gamma)*model.C[k,k] + 1.) for k in Cand])
                if model.modelname == "probit-log":
                    return sums * np.array([hess_calc2(model.m[k], sgn(model.m[k]), model.gamma)/ \
                            (hess_calc2(model.m[k], sgn(model.m[k]), model.gamma)*model.C[k,k] + 1.) for k in Cand])
            else:
                uks = model.m[Cand]
                C_a_vk = model.C_a @ (model.v[Cand,:].T)
                # VTones = V^T 1, i.e. column sums of the eigenvector matrix.
                VTones = np.sum(model.v, axis=0).flatten()
                tops = np.array([np.inner(VTones, C_a_vk[:,i]) for i in range(len(Cand))])**2.
                if model.modelname in ['gr', 'mgr']:
                    return tops / (model.gamma**2. + np.array([np.inner(model.v[Cand[i],:], C_a_vk[:,i]) for i in range(len(Cand))]))
                if model.modelname == 'probit-log' or model.modelname == 'log':
                    return tops * np.array([hess_calc2(model.m[k], sgn(model.m[k]), model.gamma)/ \
                            (hess_calc2(model.m[k], sgn(model.m[k]), model.gamma)*np.inner(model.v[Cand,:][i,:], C_a_vk[:,i]) + 1.) for i,k in enumerate(Cand)])
                if model.modelname == 'probit-norm' or model.modelname == 'probitnorm':
                    return tops * np.array([hess_calc(model.m[k], sgn(model.m[k]), model.gamma)/ \
                            (hess_calc(model.m[k], sgn(model.m[k]), model.gamma)*np.inner(model.v[Cand,:][i,:], C_a_vk[:,i]) + 1.) for i,k in enumerate(Cand)])
    elif acq == "mbr":
        raise NotImplementedError()
    elif acq == "db":
        # Data-based norm acquisition; only defined for the RKHS model.
        if model.modelname != 'rkhs':
            raise NotImplementedError("Databased norm is for RKHS model only")
        else:
            return model.look_ahead_db_norms(Cand)
    else:
        raise ValueError("Acquisition function %s not yet implemented" % str(acq))
    return
  109. class ActiveLearner(object):
  110. def __init__(self, acquisition='mc', candidate='full', candidate_frac=0.1, W=None, r=None):
  111. if acquisition not in ACQUISITIONS:
  112. raise ValueError("Acquisition function name %s not valid, must be in %s" % (str(acquisition), str(ACQUISITIONS)))
  113. self.acquisition = acquisition
  114. if candidate not in CANDIDATES:
  115. raise ValueError("Candidate Set Selection name %s not valid, must be in %s" % (str(candidate), str(CANDIDATES)))
  116. self.candidate = candidate
  117. if (candidate_frac < 0. or candidate_frac > 1. ) and self.candidate == 'rand':
  118. print("WARNING: Candidate fraction must be between 0 and 1 for 'rand' candidate selection, setting to default 0.1")
  119. self.candidate_frac = 0.1
  120. else:
  121. self.candidate_frac = candidate_frac
  122. # if modelname not in MODELS:
  123. # raise ValueError("Model name %s not valid, must be in %s" % (str(modelname), str(MODELS)))
  124. # self.modelname = modelname
  125. if self.candidate == 'dijkstra':
  126. self.W = W
  127. if self.W is None:
  128. raise ValueError("Candidate set selection %s requires W to be non-empty" % candidate)
  129. self.DIST = {}
  130. if r is None:
  131. self.dijkstra_r = 5.0
  132. else:
  133. self.dijkstra_r = r
  134. # else:
  135. # # If weight matrix is passed to ActiveLearner but not doing Dijkstra, ignore it
  136. # if self.W is not None:
  137. # self.W = None
  138. def select_query_points(self, model, B=1, method='top', prop_func=None, prop_sigma=0.8, mcavg_beta=0.0, debug=False, verbose=False):
  139. if method not in SELECTION_METHODS:
  140. raise ValueError("Selection method %s not valid, must be one of %s" % (method, SELECTION_METHODS))
  141. if verbose:
  142. print("Active Learner settings:")
  143. print("\tacquisition function = %s" % self.acquisition)
  144. print("\tB = %d" % B)
  145. print("\tcandidate set = %s" % self.candidate)
  146. print("\tselection method = %s" % method)
  147. # Define the candidate set
  148. if self.candidate is "rand":
  149. Cand = np.random.choice(model.unlabeled, size=int(self.candidate_frac * len(model.unlabeled)), replace=False)
  150. elif self.candidate is "dijkstra":
  151. raise NotImplementedError("Have not implemented the dikstra candidate selection for this class")
  152. else:
  153. Cand = model.unlabeled
  154. if debug:
  155. self.Cand = Cand
  156. # Compute acquisition values -- save as object attribute for later plotting
  157. self.acq_vals = acquisition_values(self.acquisition, Cand, model, mcavg_beta=mcavg_beta)
  158. if len(self.acq_vals.shape) > 1:
  159. print("WARNING: acq_vals is of shape %s, should be one-dimensional. MIGHT CAUSE PROBLEM" % str(self.acq_vals.shape))
  160. # based on selection method, choose query points
  161. if B == 1:
  162. if method != 'top':
  163. print("Warning : B = 1 but election method is not 'top'. Overriding selection method and selecting top choice for query point.")
  164. return [Cand[np.argmax(self.acq_vals)]]
  165. else:
  166. if method == 'top':
  167. return [Cand[k] for k in (-self.acq_vals).argsort()[:B]]
  168. elif method == 'prop':
  169. if prop_func is None:
  170. # if not given a customized proportionality sampling function, use this default.
  171. # (1) normalize to be 0 to 1
  172. acq_vals = (self.acq_vals - np.min(self.acq_vals))/(np.max(self.acq_vals) - np.min(self.acq_vals))
  173. p = np.exp(acq_vals/prop_sigma)
  174. p /= np.sum(p)
  175. else:
  176. p = prop_func(self.acq_vals)
  177. if debug:
  178. return list(np.random.choice(Cand, B, replace=False, p=p)), p, self.acq_vals, Cand
  179. return list(np.random.choice(Cand, B, replace=False, p=p))
  180. else:
  181. raise ValueError("Have not implemented this selection method, %s. Somehow got passed other parameter checks..." % method)