# torch imports
import torch
import torch.nn as nn
import torch.optim as optim
# other modules
import numpy as np
import time
from tqdm import trange
import pickle
# internal imports
from utils.lin.sparse_rs import RSAttack
from utils.lin.models import *
from utils.helpers import *


class adv_trainer():
    def __init__(self,
                 root,
                 k,
                 perturb,
                 beta,
                 seed,
                 save_dir,
                 bs,
                 num_iters,
                 num_queries,
                 num_epochs,
                 device):
        '''
        Class responsible for retraining with adversarial examples using the
        sparse_rs framework.

        Inputs:
            root        - location of the parent directory for the library
            k           - truncation parameter
            perturb     - l_0 budget for sparse_rs
            beta        - l_inf norm parameter (scales the image domain)
            seed        - used within sparse_rs
            save_dir    - location to output log files and models
            bs          - batch size
            num_iters   - times to repeat the training and attacking cycle
            num_queries - query budget for each attack
            num_epochs  - how long to train during each iteration
            device      - where to train the network

        Outputs:
            Saves these files in save_path:
            net.pth       - model state_dict (saved on cpu)
            results.p     - results as a list of strings (see self.run())
            f_results.p   - final acc and r_acc when using the full query budget
            log.txt       - log of all attacks while training
            log_final.txt - log of the final attack (used to make figures/eval)
        '''
        super(adv_trainer, self).__init__()
        # init params
        self.root = root
        self.k = k
        self.perturb = perturb
        self.beta = beta
        self.seed = seed
        self.save_path = root + save_dir
        self.bs = bs
        self.device = torch.device(device)
        self.num_queries = num_queries
        self.num_epochs = num_epochs
        self.num_iters = num_iters
        self.iter = 0
        # extra params derived
        self.mu, self.sigma = mu_sigma(self.beta)
        self.net_path = self.save_path + 'net.pth'
        self.results_str = []
        pickle.dump(self.results_str, open(self.save_path + 'results.p', 'wb'))
        # prep the network (k == 0 uses the standard model, otherwise the truncated one)
        if k == 0:
            self.net = L_Net().to(self.device)
            self.eval_net = L_Net_eval(self.mu, self.sigma).to(self.device)
        else:
            self.net = r_L_Net(self.k).to(self.device)
            self.eval_net = r_L_Net_eval(self.k, self.mu, self.sigma).to(self.device)
        torch.save(self.net.state_dict(), self.net_path)
        # prep the training utils
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.net.parameters(), lr=0.001, momentum=0.9)
        # prep the data: a dict of pre-batched tensors keyed by 'x_train'/'y_train',
        # 'x_test'/'y_test', 'x_og'/'y_og', plus the batch size under 'bs'
        self.Data = prep_MNIST(self.root, bs)

    def train(self):
        '''
        Trains self.net with self.criterion using self.optimizer.
        Performs self.num_epochs passes over the data, saving the weights after.
        '''
        self.net.train()
        for _ in trange(self.num_epochs):
            running_loss = 0.0
            for inputs, labels in zip(self.Data['x_train'], self.Data['y_train']):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.net(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
        torch.save(self.net.state_dict(), self.net_path)

    def test(self):
        '''
        Performs a test of self.net using the MNIST test dataset.
        Returns the clean accuracy.
        '''
        self.net.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in zip(self.Data['x_test'], self.Data['y_test']):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                outputs = self.net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        acc = 100 * (correct / total)
        return acc

    def r_test(self, train=False, test=False):
        '''
        Performs an attack on the data using sparse_rs as the adversary.
        By default runs the attack on only one batch of the test set and
        returns the robust accuracy for mid-training statistics.

        Inputs:
            train - if True, attacks the ENTIRE train set and returns adversarial examples
            test  - if True, attacks the ENTIRE test set (with a larger query budget)
                    and only returns the robust accuracy
        '''
        # load net
        state_dict = torch.load(self.net_path, map_location=self.device)
        self.eval_net.load_state_dict(state_dict)
        self.eval_net.eval()
        # setup params depending on input
        keys = ['x_test', 'y_test']
        batches = 1
        num_queries = self.num_queries
        log_path = self.save_path + 'log.txt'
        if train:
            adv_xs = []
            adv_ys = []
            keys = ['x_og', 'y_og']
            batches = len(self.Data['x_og'])
        elif test:
            all_acc = []
            batches = len(self.Data['x_test'])
            num_queries = 5000
            log_path = self.save_path + 'log_final.txt'
        # load adversary
        adversary = RSAttack(self.eval_net,
                             norm='L0',
                             eps=self.perturb,
                             n_queries=num_queries,
                             n_restarts=1,
                             seed=self.seed,
                             device=self.device,
                             log_path=log_path
                             )
        # attack over the defined batches (1 or all)
        for i in trange(batches):
            # rescale the batch with mu and sigma before attacking
            x = (self.Data[keys[0]][i].to(self.device) - self.mu) / self.sigma
            y = self.Data[keys[1]][i].to(self.device)
            with torch.no_grad():
                # find points originally correctly classified
                output = self.eval_net(x)
                pred = (output.max(1)[1] == y).float().to(self.device)
                ind_to_fool = (pred == 1).nonzero(as_tuple=False).squeeze()
                # perform the attack on the corresponding indices and save
                _, adv = adversary.perturb(x[ind_to_fool], y[ind_to_fool])
                # analyze the attack
                output = self.eval_net(adv.to(self.device))
                r_acc = (output.max(1)[1] == y[ind_to_fool]).float().to(self.device)
                adversary.logger.log('robust accuracy {:.2%}'.format(r_acc.float().mean()))
                # save if training
                if train:
                    idx_fooled = (output.max(1)[1] != y[ind_to_fool])
                    adv_xs.append(torch.clone(adv[idx_fooled]))
                    adv_ys.append(torch.clone(y[ind_to_fool][idx_fooled]))
                # eval if testing
                elif test:
                    all_acc.append(r_acc.float().mean() * 100)
        if train:
            return adv_xs, adv_ys
        elif test:
            # average the robust accuracy over all but the last (possibly partial) batch
            return sum(all_acc[0:-1]).item() / (len(all_acc) - 1)
        else:
            return r_acc.float().mean() * 100

    def attack_save(self):
        '''
        Goes through the original train set and attacks with sparse_rs for
        num_queries. Saves the corresponding new examples to the train set
        in Data and returns the original and new number of batches.
        '''
        # re-initialize the dataset
        self.Data['x_train'] = []
        self.Data['y_train'] = []
        # get adversarial examples
        adv_xs, adv_ys = self.r_test(train=True)
        # now we update the Data (on the cpu since it's large)
        rem_x = torch.tensor([]).float().to('cpu')
        rem_y = torch.tensor([]).long().to('cpu')
        for x, y, adv_x, adv_y in zip(self.Data['x_og'], self.Data['y_og'], adv_xs, adv_ys):
            # undo the rescaling applied in r_test
            adv_x = (adv_x.to('cpu') * self.sigma) + self.mu
            # concatenate and shuffle, storing the remainder
            new_x, new_y = torch.cat((x.to('cpu'), adv_x)), torch.cat((y.to('cpu'), adv_y.to('cpu')))
            shuffle = torch.randperm(new_x.size()[0])
            new_x, new_y = new_x[shuffle], new_y[shuffle]
            rem_x = torch.cat((rem_x, torch.clone(new_x[self.Data['bs']:])))
            rem_y = torch.cat((rem_y, torch.clone(new_y[self.Data['bs']:])))
            # now store the data with the proper batch size
            self.Data['x_train'].append(torch.clone(new_x[0:self.Data['bs']]))
            self.Data['y_train'].append(torch.clone(new_y[0:self.Data['bs']]))
        # when done we want to add the remainder as well
        for i in range(rem_x.shape[0] // self.Data['bs']):
            self.Data['x_train'].append(rem_x[self.Data['bs'] * i:self.Data['bs'] * (i + 1)])
            self.Data['y_train'].append(rem_y[self.Data['bs'] * i:self.Data['bs'] * (i + 1)])
        return len(self.Data['y_og']), len(self.Data['y_train'])

    def run(self):
        '''
        Runs the retraining loop for num_iters.
        '''
        while self.iter < self.num_iters:
            res_str = "Running iter%d" % (self.iter)
            print(res_str)
            self.results_str.append(res_str)
            self.train()
            acc = self.test()
            res_str = "Accuracy on testset: %.3f" % acc
            print(res_str)
            self.results_str.append(res_str)
            r_acc = self.r_test()
            res_str = "Robust Accuracy: %.3f" % r_acc
            print(res_str)
            self.results_str.append(res_str)
            if self.num_iters != 1:
                old, new = self.attack_save()
                res_str = "Went from %d batches to %d batches after attack" % (old, new)
                print(res_str)
                self.results_str.append(res_str)
            self.iter += 1
            pickle.dump(self.results_str, open(self.save_path + 'results.p', 'wb'))
            torch.save(self.net.state_dict(), self.net_path)
        # now compute the final accuracy and save
        acc = self.test()
        r_acc = self.r_test(test=True)
        print(acc, r_acc)
        pickle.dump((acc, r_acc), open(self.save_path + 'f_results.p', 'wb'))
        self.net.to('cpu')
        torch.save(self.net.state_dict(), self.net_path)
        print("FINISHED TRAINING AND EVALUATING")