sparse_rs.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988
  1. # Copyright (c) 2020-present
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. #
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import torch
  12. import time
  13. import math
  14. import torch.nn.functional as F
  15. import numpy as np
  16. import copy
  17. from utils.helpers import Logger
  18. import os
  19. class RSAttack():
  20. """
  21. Sparse-RS attacks
  22. :param predict: forward pass function
  23. :param norm: type of the attack
  24. :param n_restarts: number of random restarts
  25. :param n_queries: max number of queries (each restart)
  26. :param eps: bound on the sparsity of perturbations
  27. :param seed: random seed for the starting point
  28. :param alpha_init: parameter to control alphai
  29. :param loss: loss function optimized ('margin', 'ce' supported)
  30. :param resc_schedule adapt schedule of alphai to n_queries
  31. :param device specify device to use
  32. :param log_path path to save logfile.txt
  33. :param constant_schedule use constant alphai
  34. :param targeted perform targeted attacks
  35. :param init_patches initialization for patches
  36. :param resample_loc period in queries of resampling images and
  37. locations for universal attacks
  38. :param data_loader loader to get new images for resampling
  39. :param update_loc_period period in queries of updates of the location
  40. for image-specific patches
  41. """
  def __init__(
          self,
          predict,
          norm='L0',
          n_queries=5000,
          eps=None,
          p_init=.8,
          n_restarts=1,
          seed=0,
          verbose=True,
          targeted=False,
          loss='margin',
          resc_schedule=True,
          device=None,
          log_path=None,
          constant_schedule=False,
          init_patches='random_squares',
          resample_loc=None,
          data_loader=None,
          update_loc_period=None):
      """
      Sparse-RS implementation in PyTorch: store the attack configuration.

      :param predict: forward pass function; its output is treated as logits
      :param norm: type of the attack ('L0', 'patches', 'frames',
          'patches_universal' or 'frames_universal')
      :param n_queries: max number of queries (each restart)
      :param eps: bound on the sparsity of perturbations
      :param p_init: initial value of the update-size schedule (see p_selection)
      :param n_restarts: number of random restarts
      :param seed: random seed for the starting point
      :param verbose: log progress through the logger
      :param targeted: perform targeted attacks
      :param loss: loss function optimized ('margin', 'ce' supported)
      :param resc_schedule: adapt the schedule of p to n_queries
      :param device: device to use (None: taken from the input later)
      :param log_path: path passed to the Logger
      :param constant_schedule: use a constant p
      :param init_patches: initialization scheme for patches
      :param resample_loc: period in queries of resampling images and
          locations for universal attacks; defaults to ``n_queries // 10``
          but never less than 1
      :param data_loader: iterator on which ``next`` is called to fetch new
          ``(images, labels)`` batches for the universal attacks
      :param update_loc_period: period in queries of location updates for
          image-specific patches; defaults to 4 (untargeted) or 10 (targeted)
      """
      self.predict = predict
      self.norm = norm
      self.n_queries = n_queries
      self.eps = eps
      self.p_init = p_init
      self.n_restarts = n_restarts
      self.seed = seed
      self.verbose = verbose
      self.targeted = targeted
      self.loss = loss
      self.rescale_schedule = resc_schedule
      self.device = device
      self.logger = Logger(log_path)
      self.constant_schedule = constant_schedule
      self.init_patches = init_patches
      if resample_loc is None:
          # The universal attacks divide by / take the modulo of this period,
          # so a default of 0 (n_queries < 10) would raise ZeroDivisionError.
          resample_loc = max(n_queries // 10, 1)
      self.resample_loc = resample_loc
      self.data_loader = data_loader
      if update_loc_period is None:
          update_loc_period = 10 if targeted else 4
      self.update_loc_period = update_loc_period
  83. def margin_and_loss(self, x, y):
  84. """
  85. :param y: correct labels if untargeted else target labels
  86. """
  87. logits = self.predict(x)
  88. # breakpoint()
  89. # I CAHNGED THIS #
  90. # print('shapes: ',logits.shape, y.shape)
  91. # print(F.cross_entropy(logits, y, reduction='none'))
  92. # print('shape bef: ',logits.shape)
  93. if logits.dim() == 1:
  94. logits.unsqueeze_(dim=0)
  95. ## CHANGE ENDS HERE
  96. # print('shape aft: ',logits.shape)
  97. xent = F.cross_entropy(logits, y, reduction='none')
  98. u = torch.arange(x.shape[0])
  99. y_corr = logits[u, y].clone()
  100. logits[u, y] = -float('inf')
  101. # print('val:',logits)
  102. # print('right before',logits.shape)
  103. # print('max value',logits.max(dim=-1)[0])
  104. y_others = logits.max(dim=-1)[0]
  105. if not self.targeted:
  106. if self.loss == 'ce':
  107. return y_corr - y_others, -1. * xent
  108. elif self.loss == 'margin':
  109. return y_corr - y_others, y_corr - y_others
  110. else:
  111. return y_others - y_corr, xent
  112. def init_hyperparam(self, x):
  113. assert self.norm in ['L0', 'patches', 'frames',
  114. 'patches_universal', 'frames_universal']
  115. assert not self.eps is None
  116. assert self.loss in ['ce', 'margin']
  117. if self.device is None:
  118. self.device = x.device
  119. self.orig_dim = list(x.shape[1:])
  120. self.ndims = len(self.orig_dim)
  121. if self.seed is None:
  122. self.seed = time.time()
  123. if self.targeted:
  124. self.loss = 'ce'
  125. def random_target_classes(self, y_pred, n_classes):
  126. y = torch.zeros_like(y_pred)
  127. for counter in range(y_pred.shape[0]):
  128. l = list(range(n_classes))
  129. l.remove(y_pred[counter])
  130. t = self.random_int(0, len(l))
  131. y[counter] = l[t]
  132. return y.long().to(self.device)
  133. def check_shape(self, x):
  134. return x if len(x.shape) == (self.ndims + 1) else x.unsqueeze(0)
  135. def random_choice(self, shape):
  136. t = 2 * torch.rand(shape).to(self.device) - 1
  137. return torch.sign(t)
  138. def random_int(self, low=0, high=1, shape=[1]):
  139. t = low + (high - low) * torch.rand(shape).to(self.device)
  140. return t.long()
  141. def normalize(self, x):
  142. if self.norm == 'Linf':
  143. t = x.abs().view(x.shape[0], -1).max(1)[0]
  144. return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)
  145. elif self.norm == 'L2':
  146. t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
  147. return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)
  148. def lp_norm(self, x):
  149. if self.norm == 'L2':
  150. t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
  151. return t.view(-1, *([1] * self.ndims))
  152. def p_selection(self, it):
  153. """ schedule to decrease the parameter p """
  154. if self.rescale_schedule:
  155. it = int(it / self.n_queries * 10000)
  156. if 'patches' in self.norm:
  157. if 10 < it <= 50:
  158. p = self.p_init / 2
  159. elif 50 < it <= 200:
  160. p = self.p_init / 4
  161. elif 200 < it <= 500:
  162. p = self.p_init / 8
  163. elif 500 < it <= 1000:
  164. p = self.p_init / 16
  165. elif 1000 < it <= 2000:
  166. p = self.p_init / 32
  167. elif 2000 < it <= 4000:
  168. p = self.p_init / 64
  169. elif 4000 < it <= 6000:
  170. p = self.p_init / 128
  171. elif 6000 < it <= 8000:
  172. p = self.p_init / 256
  173. elif 8000 < it:
  174. p = self.p_init / 512
  175. else:
  176. p = self.p_init
  177. elif 'frames' in self.norm:
  178. if not 'universal' in self.norm :
  179. tot_qr = 10000 if self.rescale_schedule else self.n_queries
  180. p = max((float(tot_qr - it) / tot_qr - .5) * self.p_init * self.eps ** 2, 0.)
  181. return 3. * math.ceil(p)
  182. else:
  183. assert self.rescale_schedule
  184. its = [200, 600, 1200, 1800, 2500, 10000, 100000]
  185. resc_factors = [1., .8, .6, .4, .2, .1, 0.]
  186. c = 0
  187. while it >= its[c]:
  188. c += 1
  189. return resc_factors[c] * self.p_init
  190. elif 'L0' in self.norm:
  191. if 0 < it <= 50:
  192. p = self.p_init / 2
  193. elif 50 < it <= 200:
  194. p = self.p_init / 4
  195. elif 200 < it <= 500:
  196. p = self.p_init / 5
  197. elif 500 < it <= 1000:
  198. p = self.p_init / 6
  199. elif 1000 < it <= 2000:
  200. p = self.p_init / 8
  201. elif 2000 < it <= 4000:
  202. p = self.p_init / 10
  203. elif 4000 < it <= 6000:
  204. p = self.p_init / 12
  205. elif 6000 < it <= 8000:
  206. p = self.p_init / 15
  207. elif 8000 < it:
  208. p = self.p_init / 20
  209. else:
  210. p = self.p_init
  211. if self.constant_schedule:
  212. p = self.p_init / 2
  213. return p
  214. def sh_selection(self, it):
  215. """ schedule to decrease the parameter p """
  216. t = max((float(self.n_queries - it) / self.n_queries - .0) ** 1., 0) * .75
  217. return t
  218. def get_init_patch(self, c, s, n_iter=1000):
  219. if self.init_patches == 'stripes':
  220. patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice(
  221. [1, c, 1, s]).clamp(0., 1.)
  222. elif self.init_patches == 'uniform':
  223. patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice(
  224. [1, c, 1, 1]).clamp(0., 1.)
  225. elif self.init_patches == 'random':
  226. patch_univ = self.random_choice([1, c, s, s]).clamp(0., 1.)
  227. elif self.init_patches == 'random_squares':
  228. patch_univ = torch.zeros([1, c, s, s]).to(self.device)
  229. for _ in range(n_iter):
  230. size_init = torch.randint(low=1, high=math.ceil(s ** .5), size=[1]).item()
  231. loc_init = torch.randint(s - size_init + 1, size=[2])
  232. patch_univ[0, :, loc_init[0]:loc_init[0] + size_init, loc_init[1]:loc_init[1] + size_init] = 0.
  233. patch_univ[0, :, loc_init[0]:loc_init[0] + size_init, loc_init[1]:loc_init[1] + size_init
  234. ] += self.random_choice([c, 1, 1]).clamp(0., 1.)
  235. elif self.init_patches == 'sh':
  236. patch_univ = torch.ones([1, c, s, s]).to(self.device)
  237. return patch_univ.clamp(0., 1.)
  238. def attack_single_run(self, x, y):
  239. ### INITIAL DATA CHECK ###
  240. # print('-----'*8)
  241. # print('INITAL X',min(x[0].squeeze()),max(x[0].squeeze()))
  242. # print('-----'*8)
  243. with torch.no_grad():
  244. adv = x.clone()
  245. c, h, w = x.shape[1:]
  246. n_features = c * h * w
  247. n_ex_total = x.shape[0]
  248. if self.norm == 'L0':
  249. eps = self.eps
  250. x_best = x.clone()
  251. n_pixels = h * w
  252. b_all, be_all = torch.zeros([x.shape[0], eps]).long(), torch.zeros([x.shape[0], n_pixels - eps]).long()
  253. for img in range(x.shape[0]):
  254. # print('randperm')
  255. ind_all = torch.randperm(n_pixels)
  256. ind_p = ind_all[:eps]
  257. ind_np = ind_all[eps:]
  258. x_best[img, :, ind_p // w, ind_p % w] = self.random_choice([c, eps]).clamp(0., 1.)
  259. b_all[img] = ind_p.clone()
  260. be_all[img] = ind_np.clone()
  261. margin_min, loss_min = self.margin_and_loss(x_best, y)
  262. n_queries = torch.ones(x.shape[0]).to(self.device)
  263. for it in range(1, self.n_queries):
  264. # print(it,self.n_queries)
  265. # print(x.shape)
  266. # check points still to fool
  267. idx_to_fool = (margin_min > 0.).nonzero(as_tuple=False).squeeze()
  268. x_curr = self.check_shape(x[idx_to_fool])
  269. x_best_curr = self.check_shape(x_best[idx_to_fool])
  270. y_curr = y[idx_to_fool]
  271. margin_min_curr = margin_min[idx_to_fool]
  272. loss_min_curr = loss_min[idx_to_fool]
  273. b_curr, be_curr = b_all[idx_to_fool], be_all[idx_to_fool]
  274. # print(b_curr.shape,be_curr.shape,'inside 1')
  275. if len(y_curr.shape) == 0:
  276. y_curr.unsqueeze_(0)
  277. margin_min_curr.unsqueeze_(0)
  278. loss_min_curr.unsqueeze_(0)
  279. b_curr.unsqueeze_(0)
  280. be_curr.unsqueeze_(0)
  281. idx_to_fool.unsqueeze_(0)
  282. # build new candidate
  283. x_new = x_best_curr.clone()
  284. eps_it = max(int(self.p_selection(it) * eps), 1)
  285. ind_p = torch.randperm(eps)[:eps_it]
  286. ind_np = torch.randperm(n_pixels - eps)[:eps_it]
  287. for img in range(x_new.shape[0]):
  288. p_set = b_curr[img, ind_p]
  289. np_set = be_curr[img, ind_np]
  290. x_new[img, :, p_set // w, p_set % w] = x_curr[img, :, p_set // w, p_set % w].clone()
  291. if eps_it > 1:
  292. x_new[img, :, np_set // w, np_set % w] = self.random_choice([c, eps_it]).clamp(0., 1.)
  293. else:
  294. # if update is 1x1 make sure the sampled color is different from the current one
  295. old_clr = x_new[img, :, np_set // w, np_set % w].clone()
  296. # changed to color shape (3,1) --> 1,1
  297. assert old_clr.shape == (1, 1), print(old_clr.shape,old_clr)
  298. # assert old_clr.shape == (3, 1), print(old_clr.shape,old_clr)
  299. new_clr = old_clr.clone()
  300. while (new_clr == old_clr).all().item():
  301. new_clr = self.random_choice([1, 1]).clone().clamp(0., 1.)
  302. x_new[img, :, np_set // w, np_set % w] = new_clr.clone()
  303. #############################################################
  304. # Check exactly what is happening here with the image
  305. # print(min(x_new[0].squeeze()),max(x_new[0].squeeze()),it, 'x_new')
  306. # print(min(x_curr[0].squeeze()),max(x_curr[0].squeeze()),it, 'x_curr')
  307. # print(min(x_best_curr[0].squeeze()),max(x_best_curr[0].squeeze()),it, 'x_best_curr')
  308. #############################################################
  309. # compute loss of the new candidates
  310. margin, loss = self.margin_and_loss(x_new, y_curr)
  311. n_queries[idx_to_fool] += 1
  312. # update best solution
  313. idx_improved = (loss < loss_min_curr).float()
  314. idx_to_update = (idx_improved > 0.).nonzero(as_tuple=False).squeeze()
  315. loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]
  316. idx_miscl = (margin < -1e-6).float()
  317. idx_improved = torch.max(idx_improved, idx_miscl)
  318. nimpr = idx_improved.sum().item()
  319. if nimpr > 0.:
  320. idx_improved = (idx_improved.view(-1) > 0).nonzero(as_tuple=False).squeeze()
  321. margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
  322. x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone()
  323. t = b_curr[idx_improved].clone()
  324. te = be_curr[idx_improved].clone()
  325. if nimpr > 1:
  326. t[:, ind_p] = be_curr[idx_improved][:, ind_np] + 0
  327. te[:, ind_np] = b_curr[idx_improved][:, ind_p] + 0
  328. else:
  329. t[ind_p] = be_curr[idx_improved][ind_np] + 0
  330. te[ind_np] = b_curr[idx_improved][ind_p] + 0
  331. b_all[idx_to_fool[idx_improved]] = t.clone()
  332. be_all[idx_to_fool[idx_improved]] = te.clone()
  333. # log results current iteration
  334. ind_succ = (margin_min <= 0.).nonzero(as_tuple=False).squeeze()
  335. if self.verbose and ind_succ.numel() != 0:
  336. self.logger.log(' '.join(['{}'.format(it + 1),
  337. '- success rate={}/{} ({:.2%})'.format(
  338. ind_succ.numel(), n_ex_total,
  339. float(ind_succ.numel()) / n_ex_total),
  340. '- avg # queries={:.1f}'.format(
  341. n_queries[ind_succ].mean().item()),
  342. '- med # queries={:.1f}'.format(
  343. n_queries[ind_succ].median().item()),
  344. '- loss={:.3f}'.format(loss_min.mean()),
  345. '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0
  346. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  347. '- epsit={:.0f}'.format(eps_it),
  348. ]))
  349. if ind_succ.numel() == n_ex_total:
  350. break
  351. elif self.norm == 'patches':
  352. ''' assumes square images and patches '''
  353. s = int(math.ceil(self.eps ** .5))
  354. x_best = x.clone()
  355. x_new = x.clone()
  356. loc = torch.randint(h - s, size=[x.shape[0], 2])
  357. patches_coll = torch.zeros([x.shape[0], c, s, s]).to(self.device)
  358. assert abs(self.update_loc_period) > 1
  359. loc_t = abs(self.update_loc_period)
  360. # set when to start single channel updates
  361. it_start_cu = None
  362. for it in range(0, self.n_queries):
  363. s_it = int(max(self.p_selection(it) ** .5 * s, 1))
  364. if s_it == 1:
  365. break
  366. it_start_cu = it + (self.n_queries - it) // 2
  367. if self.verbose:
  368. self.logger.log('starting single channel updates at query {}'.format(
  369. it_start_cu))
  370. # initialize patches
  371. if self.verbose:
  372. self.logger.log('using {} initialization'.format(self.init_patches))
  373. for counter in range(x.shape[0]):
  374. patches_coll[counter] += self.get_init_patch(c, s).squeeze().clamp(0., 1.)
  375. x_new[counter, :, loc[counter, 0]:loc[counter, 0] + s,
  376. loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone()
  377. margin_min, loss_min = self.margin_and_loss(x_new, y)
  378. n_queries = torch.ones(x.shape[0]).to(self.device)
  379. for it in range(1, self.n_queries):
  380. # check points still to fool
  381. idx_to_fool = (margin_min > -1e-6).nonzero(as_tuple=False).squeeze()
  382. x_curr = self.check_shape(x[idx_to_fool])
  383. patches_curr = self.check_shape(patches_coll[idx_to_fool])
  384. y_curr = y[idx_to_fool]
  385. margin_min_curr = margin_min[idx_to_fool]
  386. loss_min_curr = loss_min[idx_to_fool]
  387. loc_curr = loc[idx_to_fool]
  388. if len(y_curr.shape) == 0:
  389. y_curr.unsqueeze_(0)
  390. margin_min_curr.unsqueeze_(0)
  391. loss_min_curr.unsqueeze_(0)
  392. loc_curr.unsqueeze_(0)
  393. idx_to_fool.unsqueeze_(0)
  394. # sample update
  395. s_it = int(max(self.p_selection(it) ** .5 * s, 1))
  396. p_it = torch.randint(s - s_it + 1, size=[2])
  397. sh_it = int(max(self.sh_selection(it) * h, 0))
  398. patches_new = patches_curr.clone()
  399. x_new = x_curr.clone()
  400. loc_new = loc_curr.clone()
  401. update_loc = int((it % loc_t == 0) and (sh_it > 0))
  402. update_patch = 1. - update_loc
  403. if self.update_loc_period < 0 and sh_it > 0:
  404. update_loc = 1. - update_loc
  405. update_patch = 1. - update_patch
  406. for counter in range(x_curr.shape[0]):
  407. if update_patch == 1.:
  408. # update patch
  409. if it < it_start_cu:
  410. if s_it > 1:
  411. patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] += self.random_choice([c, 1, 1])
  412. else:
  413. # make sure to sample a different color
  414. old_clr = patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
  415. new_clr = old_clr.clone()
  416. while (new_clr == old_clr).all().item():
  417. new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
  418. patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
  419. else:
  420. assert s_it == 1
  421. assert it >= it_start_cu
  422. # single channel updates
  423. new_ch = self.random_int(low=0, high=3, shape=[1])
  424. patches_new[counter, new_ch, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = 1. - patches_new[
  425. counter, new_ch, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it]
  426. patches_new[counter].clamp_(0., 1.)
  427. if update_loc == 1:
  428. # update location
  429. loc_new[counter] += (torch.randint(low=-sh_it, high=sh_it + 1, size=[2]))
  430. loc_new[counter].clamp_(0, h - s)
  431. x_new[counter, :, loc_new[counter, 0]:loc_new[counter, 0] + s,
  432. loc_new[counter, 1]:loc_new[counter, 1] + s] = patches_new[counter].clone()
  433. # check loss of new candidate
  434. margin, loss = self.margin_and_loss(x_new, y_curr)
  435. n_queries[idx_to_fool]+= 1
  436. # update best solution
  437. idx_improved = (loss < loss_min_curr).float()
  438. idx_to_update = (idx_improved > 0.).nonzero(as_tuple=False).squeeze()
  439. loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]
  440. idx_miscl = (margin < -1e-6).float()
  441. idx_improved = torch.max(idx_improved, idx_miscl)
  442. nimpr = idx_improved.sum().item()
  443. if nimpr > 0.:
  444. idx_improved = (idx_improved.view(-1) > 0).nonzero(as_tuple=False).squeeze()
  445. margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
  446. patches_coll[idx_to_fool[idx_improved]] = patches_new[idx_improved].clone()
  447. loc[idx_to_fool[idx_improved]] = loc_new[idx_improved].clone()
  448. # log results current iteration
  449. ind_succ = (margin_min <= 0.).nonzero(as_tuple=False).squeeze()
  450. if self.verbose and ind_succ.numel() != 0:
  451. self.logger.log(' '.join(['{}'.format(it + 1),
  452. '- success rate={}/{} ({:.2%})'.format(
  453. ind_succ.numel(), n_ex_total,
  454. float(ind_succ.numel()) / n_ex_total),
  455. '- avg # queries={:.1f}'.format(
  456. n_queries[ind_succ].mean().item()),
  457. '- med # queries={:.1f}'.format(
  458. n_queries[ind_succ].median().item()),
  459. '- loss={:.3f}'.format(loss_min.mean()),
  460. '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0
  461. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  462. #'- sit={:.0f} - sh={:.0f}'.format(s_it, sh_it),
  463. '{}'.format(' - loc' if update_loc == 1. else ''),
  464. ]))
  465. if ind_succ.numel() == n_ex_total:
  466. break
  467. # apply patches
  468. for counter in range(x.shape[0]):
  469. x_best[counter, :, loc[counter, 0]:loc[counter, 0] + s,
  470. loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone()
  471. elif self.norm == 'patches_universal':
  472. ''' assumes square images and patches '''
  473. s = int(math.ceil(self.eps ** .5))
  474. x_best = x.clone()
  475. self.n_imgs = x.shape[0]
  476. x_new = x.clone()
  477. loc = torch.randint(h - s + 1, size=[x.shape[0], 2])
  478. # set when to start single channel updates
  479. it_start_cu = None
  480. for it in range(0, self.n_queries):
  481. s_it = int(max(self.p_selection(it) ** .5 * s, 1))
  482. if s_it == 1:
  483. break
  484. it_start_cu = it + (self.n_queries - it) // 2
  485. if self.verbose:
  486. self.logger.log('starting single channel updates at query {}'.format(
  487. it_start_cu))
  488. # initialize patch
  489. if self.verbose:
  490. self.logger.log('using {} initialization'.format(self.init_patches))
  491. patch_univ = self.get_init_patch(c, s)
  492. it_init = 0
  493. loss_batch = float(1e10)
  494. n_succs = 0
  495. n_iter = self.n_queries
  496. # init update batch
  497. assert not self.data_loader is None
  498. assert not self.resample_loc is None
  499. assert self.targeted
  500. new_train_imgs = []
  501. n_newimgs = self.n_imgs + 0
  502. n_imgsneeded = math.ceil(self.n_queries / self.resample_loc) * n_newimgs
  503. tot_imgs = 0
  504. if self.verbose:
  505. self.logger.log('imgs updated={}, imgs needed={}'.format(
  506. n_newimgs, n_imgsneeded))
  507. while tot_imgs < min(100000, n_imgsneeded):
  508. x_toupdatetrain, _ = next(self.data_loader)
  509. new_train_imgs.append(x_toupdatetrain)
  510. tot_imgs += x_toupdatetrain.shape[0]
  511. newimgstoadd = torch.cat(new_train_imgs, axis=0)
  512. counter_resamplingimgs = 0
  513. for it in range(it_init, n_iter):
  514. # sample size and location of the update
  515. s_it = int(max(self.p_selection(it) ** .5 * s, 1))
  516. p_it = torch.randint(s - s_it + 1, size=[2])
  517. patch_new = patch_univ.clone()
  518. if s_it > 1:
  519. patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] += self.random_choice([c, 1, 1])
  520. else:
  521. old_clr = patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
  522. new_clr = old_clr.clone()
  523. if it < it_start_cu:
  524. while (new_clr == old_clr).all().item():
  525. new_clr = self.random_choice(new_clr).clone().clamp(0., 1.)
  526. else:
  527. # single channel update
  528. new_ch = self.random_int(low=0, high=3, shape=[1])
  529. new_clr[new_ch] = 1. - new_clr[new_ch]
  530. patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
  531. patch_new.clamp_(0., 1.)
  532. # compute loss for new candidate
  533. x_new = x.clone()
  534. for counter in range(x.shape[0]):
  535. loc_new = loc[counter]
  536. x_new[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = 0.
  537. x_new[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] += patch_new[0]
  538. margin_run, loss_run = self.margin_and_loss(x_new, y)
  539. if self.loss == 'ce':
  540. loss_run += x_new.shape[0]
  541. loss_new = loss_run.sum()
  542. n_succs_new = (margin_run < -1e-6).sum().item()
  543. # accept candidate if loss improves
  544. if loss_new < loss_batch:
  545. is_accepted = True
  546. loss_batch = loss_new + 0.
  547. patch_univ = patch_new.clone()
  548. n_succs = n_succs_new + 0
  549. else:
  550. is_accepted = False
  551. # sample new locations and images
  552. if (it + 1) % self.resample_loc == 0:
  553. newimgstoadd_it = newimgstoadd[counter_resamplingimgs * n_newimgs:(
  554. counter_resamplingimgs + 1) * n_newimgs].clone().cuda()
  555. new_batch = [x[n_newimgs:].clone(), newimgstoadd_it.clone()]
  556. x = torch.cat(new_batch, dim=0)
  557. assert x.shape[0] == self.n_imgs
  558. loc = torch.randint(h - s + 1, size=[self.n_imgs, 2])
  559. assert loc.shape == (self.n_imgs, 2)
  560. loss_batch = loss_batch * 0. + 1e6
  561. counter_resamplingimgs += 1
  562. # logging current iteration
  563. if self.verbose:
  564. self.logger.log(' '.join(['{}'.format(it + 1),
  565. '- success rate={}/{} ({:.2%})'.format(
  566. n_succs, n_ex_total,
  567. float(n_succs) / n_ex_total),
  568. '- loss={:.3f}'.format(loss_batch),
  569. '- max pert={:.0f}'.format(((x_new - x).abs() > 0
  570. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  571. ]))
  572. # apply patches on the initial images
  573. for counter in range(x_best.shape[0]):
  574. loc_new = loc[counter]
  575. x_best[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = 0.
  576. x_best[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] += patch_univ[0]
  577. elif self.norm == 'frames':
  578. # set width and indices of frames
  579. mask = torch.zeros(x.shape[-2:])
  580. s = self.eps + 0
  581. mask[:s] = 1.
  582. mask[-s:] = 1.
  583. mask[:, :s] = 1.
  584. mask[:, -s:] = 1.
  585. ind = (mask == 1.).nonzero(as_tuple=False).squeeze()
  586. eps = ind.shape[0]
  587. x_best = x.clone()
  588. x_new = x.clone()
  589. mask = mask.view(1, 1, h, w).to(self.device)
  590. mask_frame = torch.ones([1, c, h, w], device=x.device) * mask
  591. #
  592. # set when starting single channel updates
  593. it_start_cu = None
  594. for it in range(0, self.n_queries):
  595. s_it = int(max(self.p_selection(it), 1))
  596. if s_it == 1:
  597. break
  598. it_start_cu = it + (self.n_queries - it) // 2
  599. #it_start_cu = 10000
  600. if self.verbose:
  601. self.logger.log('starting single channel updates at query {}'.format(
  602. it_start_cu))
  603. # initialize frames
  604. x_best[:, :, ind[:, 0], ind[:, 1]] = self.random_choice(
  605. [x.shape[0], c, eps]).clamp(0., 1.)
  606. margin_min, loss_min = self.margin_and_loss(x_best, y)
  607. n_queries = torch.ones(x.shape[0]).to(self.device)
  608. for it in range(1, self.n_queries):
  609. # check points still to fool
  610. idx_to_fool = (margin_min > -1e-6).nonzero(as_tuple=False).squeeze()
  611. x_curr = self.check_shape(x[idx_to_fool])
  612. x_best_curr = self.check_shape(x_best[idx_to_fool])
  613. y_curr = y[idx_to_fool]
  614. margin_min_curr = margin_min[idx_to_fool]
  615. loss_min_curr = loss_min[idx_to_fool]
  616. if len(y_curr.shape) == 0:
  617. y_curr.unsqueeze_(0)
  618. margin_min_curr.unsqueeze_(0)
  619. loss_min_curr.unsqueeze_(0)
  620. idx_to_fool.unsqueeze_(0)
  621. # sample update
  622. s_it = max(int(self.p_selection(it)), 1)
  623. ind_it = torch.randperm(eps)[0]
  624. x_new = x_best_curr.clone()
  625. if s_it > 1:
  626. dir_h = self.random_choice([1]).long().cpu()
  627. dir_w = self.random_choice([1]).long().cpu()
  628. new_clr = self.random_choice([c, 1]).clamp(0., 1.)
  629. for counter in range(x_curr.shape[0]):
  630. if s_it > 1:
  631. for counter_h in range(s_it):
  632. for counter_w in range(s_it):
  633. x_new[counter, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1),
  634. (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone()
  635. else:
  636. p_it = ind[ind_it].clone()
  637. old_clr = x_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
  638. new_clr = old_clr.clone()
  639. if it < it_start_cu:
  640. while (new_clr == old_clr).all().item():
  641. new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
  642. else:
  643. # single channel update
  644. new_ch = self.random_int(low=0, high=3, shape=[1])
  645. new_clr[new_ch] = 1. - new_clr[new_ch]
  646. x_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
  647. x_new.clamp_(0., 1.)
  648. x_new = (x_new - x_curr) * mask_frame + x_curr
  649. # check loss of new candidate
  650. margin, loss = self.margin_and_loss(x_new, y_curr)
  651. n_queries[idx_to_fool]+= 1
  652. # update best solution
  653. idx_improved = (loss < loss_min_curr).float()
  654. idx_to_update = (idx_improved > 0.).nonzero(as_tuple=False).squeeze()
  655. loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]
  656. idx_miscl = (margin < -1e-6).float()
  657. idx_improved = torch.max(idx_improved, idx_miscl)
  658. nimpr = idx_improved.sum().item()
  659. if nimpr > 0.:
  660. idx_improved = (idx_improved.view(-1) > 0).nonzero(as_tuple=False).squeeze()
  661. margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
  662. x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone()
  663. # log results current iteration
  664. ind_succ = (margin_min <= 0.).nonzero(as_tuple=False).squeeze()
  665. if self.verbose and ind_succ.numel() != 0:
  666. self.logger.log(' '.join(['{}'.format(it + 1),
  667. '- success rate={}/{} ({:.2%})'.format(
  668. ind_succ.numel(), n_ex_total,
  669. float(ind_succ.numel()) / n_ex_total),
  670. '- avg # queries={:.1f}'.format(
  671. n_queries[ind_succ].mean().item()),
  672. '- med # queries={:.1f}'.format(
  673. n_queries[ind_succ].median().item()),
  674. '- loss={:.3f}'.format(loss_min.mean()),
  675. '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0
  676. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  677. #'- min pert={:.0f}'.format(((x_new - x_curr).abs() > 0
  678. #).max(1)[0].view(x_new.shape[0], -1).sum(-1).min()),
  679. #'- sit={:.0f} - indit={}'.format(s_it, ind_it.item()),
  680. ]))
  681. if ind_succ.numel() == n_ex_total:
  682. break
  683. elif self.norm == 'frames_universal':
  684. # set width and indices of frames
  685. mask = torch.zeros(x.shape[-2:])
  686. s = self.eps + 0
  687. mask[:s] = 1.
  688. mask[-s:] = 1.
  689. mask[:, :s] = 1.
  690. mask[:, -s:] = 1.
  691. ind = (mask == 1.).nonzero(as_tuple=False).squeeze()
  692. eps = ind.shape[0]
  693. x_best = x.clone()
  694. x_new = x.clone()
  695. mask = mask.view(1, 1, h, w).to(self.device)
  696. mask_frame = torch.ones([1, c, h, w], device=x.device) * mask
  697. frame_univ = self.random_choice([1, c, eps]).clamp(0., 1.)
  698. # set when to start single channel updates
  699. it_start_cu = None
  700. for it in range(0, self.n_queries):
  701. s_it = int(max(self.p_selection(it) * s, 1))
  702. if s_it == 1:
  703. break
  704. it_start_cu = it + (self.n_queries - it) // 2
  705. if self.verbose:
  706. self.logger.log('starting single channel updates at query {}'.format(
  707. it_start_cu))
  708. self.n_imgs = x.shape[0]
  709. loss_batch = float(1e10)
  710. n_queries = torch.ones_like(y).float()
  711. # init update batch
  712. assert not self.data_loader is None
  713. assert not self.resample_loc is None
  714. assert self.targeted
  715. new_train_imgs = []
  716. n_newimgs = self.n_imgs + 0
  717. n_imgsneeded = math.ceil(self.n_queries / self.resample_loc) * n_newimgs
  718. tot_imgs = 0
  719. if self.verbose:
  720. self.logger.log('imgs updated={}, imgs needed={}'.format(
  721. n_newimgs, n_imgsneeded))
  722. while tot_imgs < min(100000, n_imgsneeded):
  723. x_toupdatetrain, _ = next(self.data_loader)
  724. new_train_imgs.append(x_toupdatetrain)
  725. tot_imgs += x_toupdatetrain.shape[0]
  726. newimgstoadd = torch.cat(new_train_imgs, axis=0)
  727. counter_resamplingimgs = 0
  728. for it in range(self.n_queries):
  729. # sample update
  730. s_it = max(int(self.p_selection(it) * self.eps), 1)
  731. ind_it = torch.randperm(eps)[0]
  732. mask_frame[:, :, ind[:, 0], ind[:, 1]] = 0
  733. mask_frame[:, :, ind[:, 0], ind[:, 1]] += frame_univ
  734. if s_it > 1:
  735. dir_h = self.random_choice([1]).long().cpu()
  736. dir_w = self.random_choice([1]).long().cpu()
  737. new_clr = self.random_choice([c, 1]).clamp(0., 1.)
  738. for counter_h in range(s_it):
  739. for counter_w in range(s_it):
  740. mask_frame[0, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1),
  741. (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone()
  742. else:
  743. p_it = ind[ind_it]
  744. old_clr = mask_frame[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
  745. new_clr = old_clr.clone()
  746. if it < it_start_cu:
  747. while (new_clr == old_clr).all().item():
  748. new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
  749. else:
  750. # single channel update
  751. new_ch = self.random_int(low=0, high=3, shape=[1])
  752. new_clr[new_ch] = 1. - new_clr[new_ch]
  753. mask_frame[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
  754. frame_new = mask_frame[:, :, ind[:, 0], ind[:, 1]].clone()
  755. frame_new.clamp_(0., 1.)
  756. if len(frame_new.shape) == 2:
  757. frame_new.unsqueeze_(0)
  758. x_new[:, :, ind[:, 0], ind[:, 1]] = 0.
  759. x_new[:, :, ind[:, 0], ind[:, 1]] += frame_new
  760. margin_run, loss_run = self.margin_and_loss(x_new, y)
  761. if self.loss == 'ce':
  762. loss_run += x_new.shape[0]
  763. loss_new = loss_run.sum()
  764. n_succs_new = (margin_run < -1e-6).sum().item()
  765. # accept candidate if loss improves
  766. if loss_new < loss_batch:
  767. #is_accepted = True
  768. loss_batch = loss_new + 0.
  769. frame_univ = frame_new.clone()
  770. n_succs = n_succs_new + 0
  771. # sample new images
  772. if (it + 1) % self.resample_loc == 0:
  773. newimgstoadd_it = newimgstoadd[counter_resamplingimgs * n_newimgs:(
  774. counter_resamplingimgs + 1) * n_newimgs].clone().cuda()
  775. new_batch = [x[n_newimgs:].clone(), newimgstoadd_it.clone()]
  776. x = torch.cat(new_batch, dim=0)
  777. assert x.shape[0] == self.n_imgs
  778. loss_batch = loss_batch * 0. + 1e6
  779. x_new = x.clone()
  780. counter_resamplingimgs += 1
  781. # loggin current iteration
  782. if self.verbose:
  783. self.logger.log(' '.join(['{}'.format(it + 1),
  784. '- success rate={}/{} ({:.2%})'.format(
  785. n_succs, n_ex_total,
  786. float(n_succs) / n_ex_total),
  787. '- loss={:.3f}'.format(loss_batch),
  788. '- max pert={:.0f}'.format(((x_new - x).abs() > 0
  789. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  790. ]))
  791. # apply frame on initial images
  792. x_best[:, :, ind[:, 0], ind[:, 1]] = 0.
  793. x_best[:, :, ind[:, 0], ind[:, 1]] += frame_univ
  794. return n_queries, x_best
  795. def perturb(self, x, y=None):
  796. """
  797. :param x: clean images
  798. :param y: untargeted attack -> clean labels,
  799. if None we use the predicted labels
  800. targeted attack -> target labels, if None random classes,
  801. different from the predicted ones, are sampled
  802. """
  803. #SUPER INIT CHECK
  804. # print('-----'*8)
  805. # print("SUPER CHECK",min(x[4].squeeze()),max(x[4].squeeze()))
  806. # print('-----'*8)
  807. self.init_hyperparam(x)
  808. adv = x.clone()
  809. qr = torch.zeros([x.shape[0]]).to(self.device)
  810. if y is None:
  811. if not self.targeted:
  812. with torch.no_grad():
  813. output = self.predict(x)
  814. y_pred = output.max(1)[1]
  815. y = y_pred.detach().clone().long().to(self.device)
  816. else:
  817. with torch.no_grad():
  818. output = self.predict(x)
  819. n_classes = output.shape[-1]
  820. y_pred = output.max(1)[1]
  821. y = self.random_target_classes(y_pred, n_classes)
  822. else:
  823. y = y.detach().clone().long().to(self.device)
  824. if not self.targeted:
  825. acc = self.predict(x).max(1)[1] == y
  826. else:
  827. acc = self.predict(x).max(1)[1] != y
  828. startt = time.time()
  829. torch.random.manual_seed(self.seed)
  830. torch.cuda.random.manual_seed(self.seed)
  831. np.random.seed(self.seed)
  832. # print('-----'*8)
  833. # print("SUPER 1.5 CHECK",min(x[4].squeeze()),max(x[4].squeeze()))
  834. # print('-----'*8)
  835. for counter in range(self.n_restarts):
  836. ind_to_fool = acc.nonzero(as_tuple=False).squeeze()
  837. if len(ind_to_fool.shape) == 0:
  838. ind_to_fool = ind_to_fool.unsqueeze(0)
  839. if ind_to_fool.numel() != 0:
  840. x_to_fool = x[ind_to_fool].clone()
  841. y_to_fool = y[ind_to_fool].clone()
  842. # print('-----'*8)
  843. # print("SUPER 2 CHECK",min(x_to_fool[0].squeeze()),max(x_to_fool[0].squeeze()))
  844. # print('-----'*8)
  845. qr_curr, adv_curr = self.attack_single_run(x_to_fool, y_to_fool)
  846. output_curr = self.predict(adv_curr)
  847. if not self.targeted:
  848. acc_curr = output_curr.max(1)[1] == y_to_fool
  849. else:
  850. acc_curr = output_curr.max(1)[1] != y_to_fool
  851. ind_curr = (acc_curr == 0).nonzero(as_tuple=False).squeeze()
  852. acc[ind_to_fool[ind_curr]] = 0
  853. adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()
  854. qr[ind_to_fool[ind_curr]] = qr_curr[ind_curr].clone()
  855. # if self.verbose:
  856. # print('restart {} - robust accuracy: {:.2%}'.format(
  857. # counter, acc.float().mean()),
  858. # '- cum. time: {:.1f} s'.format(
  859. # time.time() - startt))
  860. return qr, adv