adv.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. # torch imports
  2. from torch._C import TensorType
  3. import torch.nn as nn
  4. from torch.serialization import save
  5. import torch.optim as optim
  6. # other modules
  7. import numpy as np
  8. import time
  9. from tqdm import trange
  10. import pickle
  11. # internal imports
  12. from utils.conv.sparse_rs import RSAttack
  13. from utils.conv.models import *
  14. from utils.helpers import *
  15. class adv_trainer():
  16. def __init__(self,
  17. root,
  18. k,
  19. perturb,
  20. beta,
  21. seed,
  22. save_dir,
  23. bs,
  24. num_iters,
  25. num_queries,
  26. num_epochs,
  27. device):
  28. '''
  29. Class responsible for retraining with adverarial examples using the
  30. sparse_rs framework.
  31. Inputs:
  32. root - location of parent directory for the library
  33. k - truncation param.
  34. perturb - l_0 budget for sparse_rs
  35. beta - l_inf norm parameter (scales image domain)
  36. seed - used within sparse_rs
  37. save_dir - location to output log files and models
  38. bs - batch size
  39. num_iters - times to repeat the training and attacking cycle
  40. num_queries - time budget for each attack
  41. num_epochs - how long to train during each iteration
  42. device - where to train network
  43. Outputs:
  44. Saves these files in save_path
  45. net.pth - model state_dict (saved on cpu)
  46. results.p - results as a list of strings (see self.run())
  47. f_results.p - final acc and r_acc when using the full time budget
  48. log.txt - log of all attacks while training
  49. log_final.txt - log of the final attack (used to make figures/eval)
  50. '''
  51. super(adv_trainer, self).__init__()
  52. # init params
  53. self.root = root
  54. self.k = k
  55. self.perturb = perturb
  56. self.beta = beta
  57. self.seed = seed
  58. self.save_path = root+save_dir
  59. self.bs = bs
  60. self.device = torch.device(device)
  61. self.num_queries = num_queries
  62. self.num_epochs = num_epochs
  63. self.num_iters = num_iters
  64. self.iter = 0
  65. # extra params derived
  66. self.mu, self.sigma = mu_sigma(self.beta, CIFAR=True)
  67. self.net_path = self.save_path+'net.pth'
  68. self.results_str = []
  69. pickle.dump(self.results_str, open(self.save_path+'results.p','wb'))
  70. # prep the network
  71. if k == 0:
  72. self.net = VGG().to(self.device)
  73. self.eval_net = VGG_eval(self.mu,self.sigma).to(self.device)
  74. else:
  75. self.net = rob_VGG(self.k).to(self.device)
  76. self.eval_net = rob_VGG_eval(self.k, self.mu, self.sigma).to(self.device)
  77. torch.save(self.net.state_dict(), self.net_path)
  78. # We will create optimizers later based on iteration and learning rate
  79. self.criterion = nn.CrossEntropyLoss()
  80. self.lrs = [0.2,0.1,0.05,0.025,0.01,0.005,0.0025,0.001,0.0005,0.00025]
  81. # prep the Data
  82. self.Data = prep_CIFAR(self.root, self.bs)
  83. def train(self):
  84. '''
  85. Trains self.net with self.criterion using self.optimizer.
  86. Performs self.num_epochs passes of the data, saving the weights after.
  87. '''
  88. optimizer = optim.SGD(self.net.parameters(), self.lrs[self.iter], momentum=0.9,weight_decay=5e-4)
  89. self.net.train()
  90. for _ in trange(self.num_epochs):
  91. running_loss = 0.0
  92. for inputs, labels in zip(self.Data['x_train'],self.Data['y_train']):
  93. inputs = inputs.to(self.device)
  94. labels = labels.to(self.device)
  95. optimizer.zero_grad()
  96. outputs = self.net(inputs)
  97. loss = self.criterion(outputs, labels)
  98. loss.backward()
  99. optimizer.step()
  100. running_loss += loss.item()
  101. torch.save(self.net.state_dict(), self.net_path)
  102. def test(self):
  103. '''
  104. Preforms a test on self.net using the CIFAR test dataset.
  105. Returns the clean accuracy
  106. '''
  107. self.net.eval()
  108. correct = 0
  109. total = 0
  110. with torch.no_grad():
  111. for inputs, labels in zip(self.Data['x_test'],self.Data['y_test']):
  112. inputs = inputs.to(self.device)
  113. labels = labels.to(self.device)
  114. outputs = self.net(inputs)
  115. _, predicted = torch.max(outputs.data, 1)
  116. total += labels.size(0)
  117. correct += (predicted == labels).sum().item()
  118. acc = 100 * (correct / total)
  119. return acc
  120. def r_test(self, train=False, test=False):
  121. '''
  122. Preforms an attack on the data using sparse_rs as the adversary.
  123. By default will run attack on only one batch of testset and
  124. return rob. acc. for mid training statistics.
  125. Inputs:
  126. train - If TRUE, attacks ENTIRE TRAINset, and returns adversarial examples
  127. test - If TRUE, attacks ENTIRE TESTset (for longer), only returns rob. acc.
  128. '''
  129. # load net
  130. state_dict = torch.load(self.net_path, map_location=self.device)
  131. self.eval_net.load_state_dict(state_dict)
  132. self.eval_net.eval()
  133. # setup params depending on input
  134. keys = ['x_test','y_test']
  135. batches = 1
  136. num_queries = self.num_queries
  137. log_path=self.save_path+'log.txt'
  138. if train:
  139. adv_xs = []
  140. adv_ys = []
  141. keys = ['x_og','y_og']
  142. batches = len(self.Data['x_og'])
  143. elif test:
  144. all_acc = []
  145. batches = len(self.Data['x_test'])
  146. num_queries = 5000
  147. log_path=self.save_path+'log_final.txt'
  148. # load adversary
  149. adversary = RSAttack(self.eval_net,
  150. norm='L0',
  151. eps=self.perturb,
  152. n_queries=num_queries,
  153. n_restarts=1,
  154. seed=self.seed,
  155. device=self.device,
  156. log_path=log_path
  157. )
  158. # attack over defined bathces (1 or all)
  159. for i in trange(batches):
  160. x = (self.Data[keys[0]][i].to(self.device)-self.mu)/self.sigma
  161. y = self.Data[keys[1]][i].to(self.device)
  162. with torch.no_grad():
  163. # find points originally correctly classified
  164. output = self.eval_net(x)
  165. pred = (output.max(1)[1] == y).float().to(self.device)
  166. ind_to_fool = (pred == 1).nonzero(as_tuple=False).squeeze()
  167. # preform the attack on corresponding indeces and save
  168. _, adv = adversary.perturb(x[ind_to_fool], y[ind_to_fool])
  169. # analyze the attack
  170. output = self.eval_net(adv.to(self.device))
  171. r_acc = (output.max(1)[1] == y[ind_to_fool]).float().to(self.device)
  172. adversary.logger.log('robust accuracy {:.2%}'.format(r_acc.float().mean()))
  173. # save if training
  174. if train:
  175. idx_fooled = (output.max(1)[1] != y[ind_to_fool])
  176. adv_xs.append(torch.clone(adv[idx_fooled]))
  177. adv_ys.append(torch.clone(y[ind_to_fool][idx_fooled]))
  178. # eval if testing
  179. elif test:
  180. all_acc.append(r_acc.float().mean()*100)
  181. if train:
  182. return adv_xs, adv_ys
  183. elif test:
  184. return sum(all_acc[0:-1]).item()/(len(all_acc)-1)
  185. else:
  186. return r_acc.float().mean()*100
  187. def attack_save(self):
  188. '''
  189. Goes through original train set and attacks with sprase_rs for
  190. num_queries. Saves the corresponding new examples
  191. to the trainset in Data and returns how many new examples were created
  192. '''
  193. # re initialize the dataset
  194. self.Data['x_train'] = []
  195. self.Data['y_train'] = []
  196. # get adversarial examples
  197. adv_xs,adv_ys = self.r_test(train=True)
  198. # now we update the Data (on the cpu since its large)
  199. rem_x = torch.tensor([]).float().to('cpu')
  200. rem_y = torch.tensor([]).long().to('cpu')
  201. for x,y,adv_x,adv_y in zip(self.Data['x_og'],self.Data['y_og'],adv_xs,adv_ys):
  202. adv_x=(adv_x.to('cpu')*self.sigma)+self.mu
  203. # concatenate and shuffle, storing the remainder
  204. new_x, new_y = torch.cat((x.to('cpu'),adv_x)), torch.cat((y.to('cpu'),adv_y.to('cpu')))
  205. shuffle = torch.randperm(new_x.size()[0])
  206. new_x, new_y = new_x[shuffle], new_y[shuffle]
  207. rem_x, rem_y = torch.cat((rem_x,torch.clone(new_x[self.Data['bs']:-1]))),torch.cat((rem_y,torch.clone(new_y[self.Data['bs']:-1])))
  208. # Now store the data with proper batch size
  209. self.Data['x_train'].append(torch.clone(new_x[0:self.Data['bs']]))
  210. self.Data['y_train'].append(torch.clone(new_y[0:self.Data['bs']]))
  211. # when done we want to add the remainder as well
  212. for i in range(rem_x.shape[0]//self.Data['bs']):
  213. self.Data['x_train'].append(rem_x[self.Data['bs']*i : self.Data['bs']*(i+1)])
  214. self.Data['y_train'].append(rem_y[self.Data['bs']*i : self.Data['bs']*(i+1)])
  215. return len(self.Data['y_og']), len(self.Data['y_train'])
  216. def run(self):
  217. '''
  218. Runs the retraining loop for num_iters
  219. '''
  220. while self.iter<self.num_iters:
  221. res_str = "Running iter%d, with lr: %.3f"%(self.iter,self.lrs[self.iter])
  222. print(res_str)
  223. self.results_str.append(res_str)
  224. self.train()
  225. acc = self.test()
  226. res_str = "Accuracy on testset: %.3f"%acc
  227. print(res_str)
  228. self.results_str.append(res_str)
  229. r_acc = self.r_test()
  230. res_str = "Robust Accuracy: %.3f"%r_acc
  231. print(res_str)
  232. self.results_str.append(res_str)
  233. if self.num_iters != 1:
  234. old, new = self.attack_save()
  235. res_str = "Went from %d batches to %d batches after attack"%(old, new)
  236. print(res_str)
  237. self.results_str.append(res_str)
  238. self.iter += 1
  239. pickle.dump(self.results_str, open(self.save_path+'results.p','wb'))
  240. torch.save(self.net.state_dict(), self.net_path)
  241. # Now compute final accuracy and save
  242. acc = self.test()
  243. r_acc = self.r_test(test=True)
  244. print(acc, r_acc)
  245. pickle.dump((acc,r_acc),open(self.save_path+'f_results.p','wb'))
  246. self.net.to('cpu')
  247. torch.save(self.net.state_dict(), self.net_path)
  248. print("FINISHED TRAINING AND EVALUATING")