# rkhs.py
# Python implementation of the Laplace Kernel RKHS data-based norm of Karzand and Nowak
# @author : Kevin Miller
'''
Currently only implemented for binary classification (as is the case in the Karzand and Nowak paper).
The current implementation doesn't scale well, because we are calculating the Gram matrix on the fly.
* Try doing Nystrom on the Gram matrix, and use this to compare?
* Could use the full matrix, and show how long it takes then? This will probably do better than our method.
THOUGHT - Should we just phrase my method as a generalization and speed improvement of Nowak and Karzand?
* Then their better performance in all the tests can be put into perspective.
'''
import pickle
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import cdist
  16. class RKHSClassifierOld(object):
  17. def __init__(self, X, sigma):
  18. self.X = X
  19. self.N = self.X.shape[0]
  20. self.sigma = sigma
  21. self.f = None
  22. def calculate_model(self, labeled, y):
  23. self.labeled = labeled
  24. self.unlabeled = list(filter(lambda x: x not in self.labeled, range(self.N)))
  25. self.y = y
  26. self.K_inv = np.linalg.inv(np.exp(-cdist(self.X[self.labeled, :], self.X[self.labeled,:]) / self.sigma))
  27. self.f = np.empty(self.N)
  28. self.f[self.labeled] = self.y
  29. self.K_ul = np.exp(-cdist(self.X[self.unlabeled, :], self.X[self.labeled,:]) / self.sigma)
  30. self.f[self.unlabeled] = self.K_ul @ self.K_inv @ np.array(self.y) # Kul @ Kll^{-1} y
  31. def update_model_old(self, Q, yQ):
  32. len_lab_old = len(self.labeled)
  33. if self.f is None:
  34. print("No previous model calculated, so we will do initial calculation")
  35. self.calculate_model(Q, yQ)
  36. return
  37. self.labeled += Q
  38. self.y += yQ
  39. aQ = np.exp(-cdist(self.X[self.labeled,:], self.X[Q, :])/ self.sigma)
  40. aQ1 = aQ[:len_lab_old,:]
  41. aQ2 = aQ[-len(Q):,:]
  42. Z = np.linalg.inv(aQ2 - aQ1.T @ self.K_inv @ aQ1)
  43. A12 = self.K_inv @ aQ1 @ Z
  44. A11 = A12 @ aQ1.T
  45. A11.ravel()[::len_lab_old+1] += 1.
  46. A11 = A11 @ self.K_inv
  47. self.K_inv = np.hstack((A11, -A12))
  48. self.K_inv = np.vstack((self.K_inv, np.hstack((-A12.T, Z))))
  49. unl_Qi, unl_Q = zip(*list(filter(lambda x: x[1] not in Q, enumerate(self.unlabeled))))
  50. K_Qu_Q = np.exp(-cdist(self.X[unl_Q,:], self.X[Q,:])/self.sigma)
  51. self.K_ul = np.hstack((self.K_ul[unl_Qi, :], K_Qu_Q))
  52. self.unlabeled = list(filter(lambda x: x not in self.labeled, range(self.N)))
  53. self.f[Q] = np.array(yQ)
  54. self.f[self.unlabeled] = self.K_ul @ self.K_inv @ np.array(self.y) # Kul @ Kll^{-1} y
  55. return
  56. def update_model(self, Q, yQ):
  57. unl_Q = [self.unlabeled.index(k) for k in Q]
  58. unl_notQ = list(filter(lambda x: x not in unl_Q, range(len(self.unlabeled))))
  59. notQ = list(filter(lambda x: x not in Q, self.unlabeled))
  60. # update f
  61. SQ_inv = np.linalg.inv(np.exp(-cdist(self.X[Q,:], self.X[Q,:])/self.sigma))
  62. K_notQ_Q = np.exp(-cdist(self.X[notQ,:], self.X[Q,:])/self.sigma)
  63. sc = K_notQ_Q - self.K_ul[unl_notQ,:] @ self.K_inv @ self.K_ul[unl_Q,:].T
  64. self.f[notQ] += sc @ SQ_inv @ (np.array(yQ) - self.f[Q])
  65. self.f[Q] = yQ
  66. # update self.K_inv, self.K_ul for future calculations
  67. # calculate self.K_inv new by block matrix inversion formula
  68. Mat = self.K_ul[unl_Q,:].T @ SQ_inv @ self.K_ul[unl_Q,:] @ self.K_inv
  69. Mat.ravel()[::len(self.labeled)+1] += 1.0
  70. new_Kinv_lower_left = -SQ_inv @ self.K_ul[unl_Q,:] @ self.K_inv
  71. bottom_new_Kinv = np.hstack((new_Kinv_lower_left, SQ_inv))
  72. self.K_inv = self.K_inv @ Mat
  73. self.K_inv = np.hstack((self.K_inv, new_Kinv_lower_left.T))
  74. self.K_inv = np.vstack((self.K_inv, bottom_new_Kinv))
  75. # self.K_ul
  76. self.K_ul = self.K_ul[unl_notQ,:]
  77. self.K_ul = np.hstack((self.K_ul, K_notQ_Q))
  78. self.labeled.extend(Q)
  79. self.unlabeled = notQ
  80. return
  81. # def look_ahead_db_norm(self, k):
  82. # ak = self.K_ul[self.unlabeled.index(k), :]
  83. # Z = 1./(1. - np.inner(ak, self.K_inv @ ak))
  84. # A12 = self.K_inv @ ak * Z
  85. # akK_inv = self.K_inv @ ak
  86. # A11 = self.K_inv + Z * (np.outer(akK_inv, akK_inv))
  87. # K_inv_new = np.hstack((A11, -A12[:,np.newaxis]))
  88. # K_inv_new = np.vstack((K_inv_new, np.hstack((-A12[np.newaxis,:], np.array([[Z]])))))
  89. # f_k = self.f.copy()
  90. # f_k[k] = np.sign(self.f[k])# f_k [k] = yk = lowest interpolating label
  91. # unl_i, unl_k = zip(*list(filter(lambda x : x[1] != k, enumerate(self.unlabeled))))
  92. # K_ku_k = np.exp(-cdist(self.X[unl_k,:], self.X[k,:][np.newaxis,:])/self.sigma)
  93. # f_k[list(unl_k)] = np.hstack((self.K_ul[unl_i, :], K_ku_k)) @ K_inv_new \
  94. # @ np.array(self.y + [f_k[k]])
  95. # return np.linalg.norm(f_k - self.f)#**2./len(self.unlabeled)
  96. #
  97. # def look_ahead_db_norm2(self, k):
  98. # modelk = RKHSClassifier(self.X, self.sigma)
  99. # modelk.calculate_model(self.labeled[:] + [k], self.y[:] + [np.sign(self.f[k])])
  100. #
  101. # return np.linalg.norm(modelk.f - self.f)
  102. def look_ahead_db_norms(self, Cand):
  103. unl_Cand = [self.unlabeled.index(k) for k in Cand]
  104. KCandl = self.K_ul[unl_Cand,:]
  105. sc = np.exp(-cdist(self.X[self.unlabeled, :], self.X[Cand,:])/self.sigma) \
  106. - self.K_ul @ self.K_inv @ KCandl.T
  107. return (1. - np.absolute(self.f[Cand]))*np.linalg.norm(sc, axis=0)/sc[unl_Cand, range(len(Cand))]
  108. vals = []
  109. for ii, i in enumerate(unl_Cand):
  110. vals.append((1. - np.absolute(self.f[self.unlabeled[i]]))/(1. - B[i,ii]) * np.linalg.norm(B[:,ii]))
  111. return np.array(vals)
  112. # def look_ahead_db_norms2(self, Cand):
  113. # unl_Cand = list(filter(lambda x: self.unlabeled[x] in Cand, range(len(self.unlabeled))))
  114. # sc_submatrix = np.exp(-cdist(self.X[self.unlabeled, :], self.X[Cand,:])/self.sigma) \
  115. # - self.K_ul @ self.K_inv @ self.K_ul[unl_Cand,:].T
  116. # return np.abs(1. - self.f[Cand]) * np.linalg.norm(sc_submatrix, axis=0) / np.abs(sc_submatrix[unl_Cand, range(len(Cand))])
  117. # def look_ahead_db_norms2(self, Cand):
  118. # unl_Cand = [self.unlabeled.index(k) for k in Cand]
  119. # sc = np.exp(-cdist(self.X[self.unlabeled, :], self.X[self.unlabeled,:])/self.sigma) \
  120. # - self.K_ul @ self.K_inv @ self.K_ul.T
  121. # return (1. - np.absolute(self.f[Cand])) * np.linalg.norm(sc[:,unl_Cand], axis=0) / np.abs(sc[unl_Cand, unl_Cand])
  122. class RKHSClassifier(object):
  123. def __init__(self, X, sigma):
  124. self.N = X.shape[0]
  125. self.sigma = sigma
  126. self.f = None
  127. self.K = np.exp(-cdist(X, X)/self.sigma) # calculate full, dense kernel matrix upfront
  128. self.modelname = 'rkhs'
  129. self.nc = 2
  130. def calculate_model(self, labeled, y):
  131. self.labeled = labeled
  132. self.unlabeled = list(filter(lambda x: x not in self.labeled, range(self.N)))
  133. self.y = y
  134. self.K_inv = np.linalg.inv(self.K[np.ix_(self.labeled, self.labeled)])
  135. self.f = np.empty(self.N)
  136. self.f[self.labeled] = self.y
  137. self.K_ul = self.K[np.ix_(self.unlabeled, self.labeled)]
  138. self.f[self.unlabeled] = self.K_ul @ self.K_inv @ np.array(self.y) # Kul @ Kll^{-1} y
  139. def update_model(self, Q, yQ):
  140. unl_Q = [self.unlabeled.index(k) for k in Q]
  141. unl_notQ = list(filter(lambda x: x not in unl_Q, range(len(self.unlabeled))))
  142. notQ = list(filter(lambda x: x not in Q, self.unlabeled))
  143. # update f
  144. SQ_inv = np.linalg.inv(self.K[np.ix_(Q, Q)])
  145. K_notQ_Q = self.K[np.ix_(notQ, Q)]
  146. sc = K_notQ_Q - self.K_ul[unl_notQ,:] @ self.K_inv @ self.K_ul[unl_Q,:].T
  147. self.f[notQ] += sc @ SQ_inv @ (np.array(yQ) - self.f[Q])
  148. self.f[Q] = yQ
  149. # update self.K_inv, self.K_ul for future calculations
  150. # calculate self.K_inv new by block matrix inversion formula
  151. Mat = self.K_ul[unl_Q,:].T @ SQ_inv @ self.K_ul[unl_Q,:] @ self.K_inv
  152. Mat.ravel()[::len(self.labeled)+1] += 1.0
  153. new_Kinv_lower_left = -SQ_inv @ self.K_ul[unl_Q,:] @ self.K_inv
  154. bottom_new_Kinv = np.hstack((new_Kinv_lower_left, SQ_inv))
  155. self.K_inv = self.K_inv @ Mat
  156. self.K_inv = np.hstack((self.K_inv, new_Kinv_lower_left.T))
  157. self.K_inv = np.vstack((self.K_inv, bottom_new_Kinv))
  158. # self.K_ul
  159. self.K_ul = self.K_ul[unl_notQ,:]
  160. self.K_ul = np.hstack((self.K_ul, K_notQ_Q))
  161. self.labeled.extend(Q)
  162. self.unlabeled = notQ
  163. return
  164. def look_ahead_db_norms(self, Cand):
  165. unl_Cand = [self.unlabeled.index(k) for k in Cand]
  166. KCandl = self.K_ul[unl_Cand,:]
  167. sc = self.K[np.ix_(self.unlabeled, Cand)] - self.K_ul @ self.K_inv @ KCandl.T
  168. return (1. - np.absolute(self.f[Cand]))*np.linalg.norm(sc, axis=0)/sc[unl_Cand, range(len(Cand))]
  169. def return_parts(self, Cand):
  170. unl_Cand = [self.unlabeled.index(k) for k in Cand]
  171. KCandl = self.K_ul[unl_Cand,:]
  172. sc = self.K[np.ix_(self.unlabeled, Cand)] - self.K_ul @ self.K_inv @ KCandl.T
  173. return (1. - np.absolute(self.f[Cand])), np.linalg.norm(sc, axis=0), 1./sc[unl_Cand, range(len(Cand))]
  174. if __name__ == "__main__":
  175. import sys
  176. sys.path.append('..')
  177. import os
  178. from runs_util import get_acc
  179. parser = ArgumentParser(description="Read in previous RKHS run and check classifier")
  180. parser.add_argument("--loc", default='../checker2/db-rkhs-2000-0.1-0.1/rand-top-5-100-1.txt', type=str)
  181. parser.add_argument("--Xloc", default='../checker2/X_labels.npz', type=str)
  182. parser.add_argument("--sigma", default='0.1', type=str)
  183. args = parser.parse_args()
  184. print(float(args.sigma))
  185. labeled = []
  186. with open(args.loc, 'r') as f:
  187. for i, line in enumerate(f.readlines()):
  188. # read in init_labeled, and initial accuracy
  189. line = line.split(',')
  190. labeled.extend([int(x) for x in line[:-2]])
  191. if i == 0:
  192. num_init = len(labeled)
  193. lab_set = set(labeled)
  194. print(len(lab_set), len(labeled))
  195. data = np.load(args.Xloc, allow_pickle=True)
  196. X, labels = data['X'], data['labels']
  197. model = RKHSClassifier(X, sigma=float(args.sigma))
  198. model_new = RKHSClassifierNew(X, sigma=float(args.sigma))
  199. model.calculate_model(labeled[:20], list(labels[labeled[:20]]))
  200. model_new.calculate_model(labeled[:20], list(labels[labeled[:20]]))
  201. assert np.allclose(model.f, model_new.f)
  202. Cand = list(np.random.choice(model.unlabeled, 50))
  203. orig_vals = model.look_ahead_db_norms(Cand[:])
  204. new_vals = model_new.look_ahead_db_norms(Cand[:])
  205. assert np.allclose(orig_vals, new_vals)
  206. model.update_model(Cand[:5], list(labels[Cand[:5]]))
  207. model_new.update_model(Cand[:5], list(labels[Cand[:5]]))
  208. assert np.allclose(model.f, model_new.f)
  209. print("passed all tests!")
  210. # for i in np.arange(0,100,10):
  211. # K = num_init+i*5
  212. # model = RKHSClassifier(X, sigma=float(args.sigma))
  213. # model.calculate_model(labeled[:K], list(labels[labeled[:K]]))
  214. # model2 = RKHSClassifier(X, sigma=float(args.sigma))
  215. # model2.calculate_model(labeled[:K], list(labels[labeled[:K]]))
  216. #
  217. #
  218. #
  219. # assert np.allclose(model.f, model2.f)
  220. #
  221. # Q = list(np.random.choice(model.unlabeled, 5))
  222. #
  223. # model3 = RKHSClassifier(X, sigma=float(args.sigma))
  224. # model3.calculate_model(labeled[:K] + Q, list(labels[labeled[:K]]) + list(labels[Q]))
  225. #
  226. # #print(list(labels[Q]))
  227. # model.update_model_old(Q, list(labels[Q]))
  228. # model2.update_model(Q, list(labels[Q]))
  229. #
  230. # if np.allclose(model.f, model2.f):
  231. # print("both models are close in the update")
  232. # else:
  233. # print("true to old - " + str(np.allclose(model.f, model3.f)))
  234. # print("true to new - " + str(np.allclose(model2.f, model3.f)))
  235. # print(np.linalg.norm(model.f - model2.f)/np.linalg.norm(model.f))
  236. # # print(labels[model.labeled], model.f[model.labeled])
  237. # # print(labels[model2.labeled], model.f[model2.labeled])
  238. # plt.scatter(range(model.N), model3.f, label='true', marker='^')
  239. # plt.scatter(range(model.N), model.f, label='old', marker='x')
  240. # plt.scatter(range(model.N), model2.f, label='new', marker='.')
  241. #
  242. # plt.legend()
  243. # plt.savefig('./comp-%d.png' % i)
  244. # plt.show(0)
  245. # plt.close()
  246. # print()
  247. # #assert np.allclose(model.f, model2.f)
  248. #
  249. # # print(os.path.exists('rkhs-model-0.npz'))
  250. # # data = np.load('rkhs-model-%d.npz' % i)
  251. # # print(type(data))
  252. # # print(list(data.keys()))
  253. # # saved_f, saved_lab = data['f'], data['lab']
  254. # # print(model.labeled)
  255. # # print(saved_lab)
  256. # # print(np.allclose(saved_f, model.f))
  257. # # print(K, len(model.labeled), get_acc(model.f, labels, unlabeled=model.unlabeled)[1], get_acc(saved_f, labels, unlabeled=model.unlabeled)[1])
  258. # # fig, (ax1, ax2) = plt.subplots(1,2)
  259. # # ax1.scatter(X[:, 0], X[:, 1], c=labels)
  260. # # ax1.set_title("Ground Truth")
  261. # # ax2.scatter(X[:,0], X[:,1], c=np.sign(model.f))
  262. # # ax2.scatter(X[labeled[:K],0], X[labeled[:K],1], c='k', marker='^')
  263. # # ax2.set_title("Calculated -- acc = {:.4f}".format(get_acc(model.f, labels, unlabeled=model.unlabeled)[1]))
  264. # # plt.savefig('./check2-rkhs-%d.png' % K)
  265. # # plt.show(0)