sparse_rs.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974
  1. # Copyright (c) 2020-present
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. #
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import torch
  12. import time
  13. import math
  14. import torch.nn.functional as F
  15. import numpy as np
  16. import copy
  17. from utils.helpers import Logger
  18. import os
  19. class RSAttack():
  20. """
  21. Sparse-RS attacks
  22. :param predict: forward pass function
  23. :param norm: type of the attack
  24. :param n_restarts: number of random restarts
  25. :param n_queries: max number of queries (each restart)
  26. :param eps: bound on the sparsity of perturbations
  27. :param seed: random seed for the starting point
  28. :param p_init: parameter to control alphai (constructor argument; base value of the schedule computed in p_selection)
  29. :param loss: loss function optimized ('margin', 'ce' supported)
  30. :param resc_schedule adapt schedule of alphai to n_queries
  31. :param device specify device to use
  32. :param log_path path to save logfile.txt
  33. :param constant_schedule use constant alphai
  34. :param targeted perform targeted attacks
  35. :param init_patches initialization for patches
  36. :param resample_loc period in queries of resampling images and
  37. locations for universal attacks
  38. :param data_loader loader to get new images for resampling
  39. :param update_loc_period period in queries of updates of the location
  40. for image-specific patches
  41. """
  def __init__(
          self,
          predict,
          norm='L0',
          n_queries=5000,
          eps=None,
          p_init=.8,
          n_restarts=1,
          seed=0,
          verbose=True,
          targeted=False,
          loss='margin',
          resc_schedule=True,
          device=None,
          log_path=None,
          constant_schedule=False,
          init_patches='random_squares',
          resample_loc=None,
          data_loader=None,
          update_loc_period=None):
      """
      Store the configuration of a Sparse-RS attack and create its logger.

      :param predict: forward pass function
      :param norm: type of the attack
      :param n_queries: max number of queries (each restart)
      :param eps: bound on the sparsity of perturbations
      :param p_init: base value of the schedule computed in p_selection
      :param n_restarts: number of random restarts
      :param seed: random seed for the starting point
      :param verbose: if true, progress is written through self.logger
      :param targeted: perform targeted attacks
      :param loss: loss function optimized ('margin', 'ce' supported)
      :param resc_schedule: adapt the schedule to n_queries
      :param device: device to use
      :param log_path: path handed to Logger
      :param constant_schedule: use a constant schedule value
      :param init_patches: initialization for patches
      :param resample_loc: period in queries of resampling images and
          locations for universal attacks (default: n_queries // 10)
      :param data_loader: loader to get new images for resampling
      :param update_loc_period: period in queries of updates of the location
          for image-specific patches (default: 10 if targeted, else 4)
      """
      # resolve the arguments whose default depends on other arguments
      if resample_loc is None:
          resample_loc = n_queries // 10
      if update_loc_period is None:
          update_loc_period = 10 if targeted else 4

      # model and threat model
      self.predict = predict
      self.norm = norm
      self.eps = eps
      self.targeted = targeted
      self.loss = loss

      # query budget and schedules
      self.n_queries = n_queries
      self.n_restarts = n_restarts
      self.p_init = p_init
      self.rescale_schedule = resc_schedule
      self.constant_schedule = constant_schedule

      # patch-specific settings
      self.init_patches = init_patches
      self.resample_loc = resample_loc
      self.data_loader = data_loader
      self.update_loc_period = update_loc_period

      # runtime settings
      self.seed = seed
      self.verbose = verbose
      self.device = device
      self.logger = Logger(log_path)
  83. def margin_and_loss(self, x, y):
  84. """
  85. :param y: correct labels if untargeted else target labels
  86. """
  87. logits = self.predict(x)
  88. xent = F.cross_entropy(logits, y, reduction='none')
  89. u = torch.arange(x.shape[0])
  90. y_corr = logits[u, y].clone()
  91. logits[u, y] = -float('inf')
  92. y_others = logits.max(dim=-1)[0]
  93. if not self.targeted:
  94. if self.loss == 'ce':
  95. return y_corr - y_others, -1. * xent
  96. elif self.loss == 'margin':
  97. return y_corr - y_others, y_corr - y_others
  98. else:
  99. return y_others - y_corr, xent
  100. def init_hyperparam(self, x):
  101. assert self.norm in ['L0', 'patches', 'frames',
  102. 'patches_universal', 'frames_universal']
  103. assert not self.eps is None
  104. assert self.loss in ['ce', 'margin']
  105. if self.device is None:
  106. self.device = x.device
  107. self.orig_dim = list(x.shape[1:])
  108. self.ndims = len(self.orig_dim)
  109. if self.seed is None:
  110. self.seed = time.time()
  111. if self.targeted:
  112. self.loss = 'ce'
  113. def random_target_classes(self, y_pred, n_classes):
  114. y = torch.zeros_like(y_pred)
  115. for counter in range(y_pred.shape[0]):
  116. l = list(range(n_classes))
  117. l.remove(y_pred[counter])
  118. t = self.random_int(0, len(l))
  119. y[counter] = l[t]
  120. return y.long().to(self.device)
  121. def check_shape(self, x):
  122. return x if len(x.shape) == (self.ndims + 1) else x.unsqueeze(0)
  123. def random_choice(self, shape):
  124. t = 2 * torch.rand(shape).to(self.device) - 1
  125. return torch.sign(t)
  126. def random_int(self, low=0, high=1, shape=[1]):
  127. t = low + (high - low) * torch.rand(shape).to(self.device)
  128. return t.long()
  129. def normalize(self, x):
  130. if self.norm == 'Linf':
  131. t = x.abs().view(x.shape[0], -1).max(1)[0]
  132. return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)
  133. elif self.norm == 'L2':
  134. t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
  135. return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)
  136. def lp_norm(self, x):
  137. if self.norm == 'L2':
  138. t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
  139. return t.view(-1, *([1] * self.ndims))
  140. def p_selection(self, it):
  141. """ schedule to decrease the parameter p """
  142. if self.rescale_schedule:
  143. it = int(it / self.n_queries * 10000)
  144. if 'patches' in self.norm:
  145. if 10 < it <= 50:
  146. p = self.p_init / 2
  147. elif 50 < it <= 200:
  148. p = self.p_init / 4
  149. elif 200 < it <= 500:
  150. p = self.p_init / 8
  151. elif 500 < it <= 1000:
  152. p = self.p_init / 16
  153. elif 1000 < it <= 2000:
  154. p = self.p_init / 32
  155. elif 2000 < it <= 4000:
  156. p = self.p_init / 64
  157. elif 4000 < it <= 6000:
  158. p = self.p_init / 128
  159. elif 6000 < it <= 8000:
  160. p = self.p_init / 256
  161. elif 8000 < it:
  162. p = self.p_init / 512
  163. else:
  164. p = self.p_init
  165. elif 'frames' in self.norm:
  166. if not 'universal' in self.norm :
  167. tot_qr = 10000 if self.rescale_schedule else self.n_queries
  168. p = max((float(tot_qr - it) / tot_qr - .5) * self.p_init * self.eps ** 2, 0.)
  169. return 3. * math.ceil(p)
  170. else:
  171. assert self.rescale_schedule
  172. its = [200, 600, 1200, 1800, 2500, 10000, 100000]
  173. resc_factors = [1., .8, .6, .4, .2, .1, 0.]
  174. c = 0
  175. while it >= its[c]:
  176. c += 1
  177. return resc_factors[c] * self.p_init
  178. elif 'L0' in self.norm:
  179. if 0 < it <= 50:
  180. p = self.p_init / 2
  181. elif 50 < it <= 200:
  182. p = self.p_init / 4
  183. elif 200 < it <= 500:
  184. p = self.p_init / 5
  185. elif 500 < it <= 1000:
  186. p = self.p_init / 6
  187. elif 1000 < it <= 2000:
  188. p = self.p_init / 8
  189. elif 2000 < it <= 4000:
  190. p = self.p_init / 10
  191. elif 4000 < it <= 6000:
  192. p = self.p_init / 12
  193. elif 6000 < it <= 8000:
  194. p = self.p_init / 15
  195. elif 8000 < it:
  196. p = self.p_init / 20
  197. else:
  198. p = self.p_init
  199. if self.constant_schedule:
  200. p = self.p_init / 2
  201. return p
  202. def sh_selection(self, it):
  203. """ schedule to decrease the parameter p """
  204. t = max((float(self.n_queries - it) / self.n_queries - .0) ** 1., 0) * .75
  205. return t
  206. def get_init_patch(self, c, s, n_iter=1000):
  207. if self.init_patches == 'stripes':
  208. patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice(
  209. [1, c, 1, s]).clamp(0., 1.)
  210. elif self.init_patches == 'uniform':
  211. patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice(
  212. [1, c, 1, 1]).clamp(0., 1.)
  213. elif self.init_patches == 'random':
  214. patch_univ = self.random_choice([1, c, s, s]).clamp(0., 1.)
  215. elif self.init_patches == 'random_squares':
  216. patch_univ = torch.zeros([1, c, s, s]).to(self.device)
  217. for _ in range(n_iter):
  218. size_init = torch.randint(low=1, high=math.ceil(s ** .5), size=[1]).item()
  219. loc_init = torch.randint(s - size_init + 1, size=[2])
  220. patch_univ[0, :, loc_init[0]:loc_init[0] + size_init, loc_init[1]:loc_init[1] + size_init] = 0.
  221. patch_univ[0, :, loc_init[0]:loc_init[0] + size_init, loc_init[1]:loc_init[1] + size_init
  222. ] += self.random_choice([c, 1, 1]).clamp(0., 1.)
  223. elif self.init_patches == 'sh':
  224. patch_univ = torch.ones([1, c, s, s]).to(self.device)
  225. return patch_univ.clamp(0., 1.)
  226. def attack_single_run(self, x, y):
  227. ### INITIAL DATA CHECK ###
  228. # print('-----'*8)
  229. # print('INITAL X',min(x[0].squeeze()),max(x[0].squeeze()))
  230. # print('-----'*8)
  231. with torch.no_grad():
  232. adv = x.clone()
  233. c, h, w = x.shape[1:]
  234. n_features = c * h * w
  235. n_ex_total = x.shape[0]
  236. if self.norm == 'L0':
  237. eps = self.eps
  238. x_best = x.clone()
  239. n_pixels = h * w
  240. b_all, be_all = torch.zeros([x.shape[0], eps]).long(), torch.zeros([x.shape[0], n_pixels - eps]).long()
  241. for img in range(x.shape[0]):
  242. # print('randperm')
  243. ind_all = torch.randperm(n_pixels)
  244. ind_p = ind_all[:eps]
  245. ind_np = ind_all[eps:]
  246. x_best[img, :, ind_p // w, ind_p % w] = self.random_choice([c, eps]).clamp(0., 1.)
  247. b_all[img] = ind_p.clone()
  248. be_all[img] = ind_np.clone()
  249. margin_min, loss_min = self.margin_and_loss(x_best, y)
  250. n_queries = torch.ones(x.shape[0]).to(self.device)
  251. for it in range(1, self.n_queries):
  252. # print(it,self.n_queries)
  253. # print(x.shape)
  254. # check points still to fool
  255. idx_to_fool = (margin_min > 0.).nonzero(as_tuple=False).squeeze()
  256. x_curr = self.check_shape(x[idx_to_fool])
  257. x_best_curr = self.check_shape(x_best[idx_to_fool])
  258. y_curr = y[idx_to_fool]
  259. margin_min_curr = margin_min[idx_to_fool]
  260. loss_min_curr = loss_min[idx_to_fool]
  261. b_curr, be_curr = b_all[idx_to_fool], be_all[idx_to_fool]
  262. # print(b_curr.shape,be_curr.shape,'inside 1')
  263. if len(y_curr.shape) == 0:
  264. y_curr.unsqueeze_(0)
  265. margin_min_curr.unsqueeze_(0)
  266. loss_min_curr.unsqueeze_(0)
  267. b_curr.unsqueeze_(0)
  268. be_curr.unsqueeze_(0)
  269. idx_to_fool.unsqueeze_(0)
  270. # build new candidate
  271. x_new = x_best_curr.clone()
  272. eps_it = max(int(self.p_selection(it) * eps), 1)
  273. ind_p = torch.randperm(eps)[:eps_it]
  274. ind_np = torch.randperm(n_pixels - eps)[:eps_it]
  275. for img in range(x_new.shape[0]):
  276. p_set = b_curr[img, ind_p]
  277. np_set = be_curr[img, ind_np]
  278. x_new[img, :, p_set // w, p_set % w] = x_curr[img, :, p_set // w, p_set % w].clone()
  279. if eps_it > 1:
  280. x_new[img, :, np_set // w, np_set % w] = self.random_choice([c, eps_it]).clamp(0., 1.)
  281. else:
  282. # if update is 1x1 make sure the sampled color is different from the current one
  283. old_clr = x_new[img, :, np_set // w, np_set % w].clone()
  284. # changed to color shape (3,1) --> 1,1
  285. # assert old_clr.shape == (1, 1), print(old_clr.shape,old_clr)
  286. assert old_clr.shape == (3, 1), print(old_clr.shape,old_clr)
  287. new_clr = old_clr.clone()
  288. while (new_clr == old_clr).all().item():
  289. new_clr = self.random_choice([1, 1]).clone().clamp(0., 1.)
  290. x_new[img, :, np_set // w, np_set % w] = new_clr.clone()
  291. #############################################################
  292. # Check exactly what is happening here with the image
  293. # print(min(x_new[0].squeeze()),max(x_new[0].squeeze()),it, 'x_new')
  294. # print(min(x_curr[0].squeeze()),max(x_curr[0].squeeze()),it, 'x_curr')
  295. # print(min(x_best_curr[0].squeeze()),max(x_best_curr[0].squeeze()),it, 'x_best_curr')
  296. #############################################################
  297. # compute loss of the new candidates
  298. margin, loss = self.margin_and_loss(x_new, y_curr)
  299. n_queries[idx_to_fool] += 1
  300. # update best solution
  301. idx_improved = (loss < loss_min_curr).float()
  302. idx_to_update = (idx_improved > 0.).nonzero(as_tuple=False).squeeze()
  303. loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]
  304. idx_miscl = (margin < -1e-6).float()
  305. idx_improved = torch.max(idx_improved, idx_miscl)
  306. nimpr = idx_improved.sum().item()
  307. if nimpr > 0.:
  308. idx_improved = (idx_improved.view(-1) > 0).nonzero(as_tuple=False).squeeze()
  309. margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
  310. x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone()
  311. t = b_curr[idx_improved].clone()
  312. te = be_curr[idx_improved].clone()
  313. if nimpr > 1:
  314. t[:, ind_p] = be_curr[idx_improved][:, ind_np] + 0
  315. te[:, ind_np] = b_curr[idx_improved][:, ind_p] + 0
  316. else:
  317. t[ind_p] = be_curr[idx_improved][ind_np] + 0
  318. te[ind_np] = b_curr[idx_improved][ind_p] + 0
  319. b_all[idx_to_fool[idx_improved]] = t.clone()
  320. be_all[idx_to_fool[idx_improved]] = te.clone()
  321. # log results current iteration
  322. ind_succ = (margin_min <= 0.).nonzero(as_tuple=False).squeeze()
  323. if self.verbose and ind_succ.numel() != 0:
  324. self.logger.log(' '.join(['{}'.format(it + 1),
  325. '- success rate={}/{} ({:.2%})'.format(
  326. ind_succ.numel(), n_ex_total,
  327. float(ind_succ.numel()) / n_ex_total),
  328. '- avg # queries={:.1f}'.format(
  329. n_queries[ind_succ].mean().item()),
  330. '- med # queries={:.1f}'.format(
  331. n_queries[ind_succ].median().item()),
  332. '- loss={:.3f}'.format(loss_min.mean()),
  333. '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0
  334. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  335. '- epsit={:.0f}'.format(eps_it),
  336. ]))
  337. if ind_succ.numel() == n_ex_total:
  338. break
  339. elif self.norm == 'patches':
  340. ''' assumes square images and patches '''
  341. s = int(math.ceil(self.eps ** .5))
  342. x_best = x.clone()
  343. x_new = x.clone()
  344. loc = torch.randint(h - s, size=[x.shape[0], 2])
  345. patches_coll = torch.zeros([x.shape[0], c, s, s]).to(self.device)
  346. assert abs(self.update_loc_period) > 1
  347. loc_t = abs(self.update_loc_period)
  348. # set when to start single channel updates
  349. it_start_cu = None
  350. for it in range(0, self.n_queries):
  351. s_it = int(max(self.p_selection(it) ** .5 * s, 1))
  352. if s_it == 1:
  353. break
  354. it_start_cu = it + (self.n_queries - it) // 2
  355. if self.verbose:
  356. self.logger.log('starting single channel updates at query {}'.format(
  357. it_start_cu))
  358. # initialize patches
  359. if self.verbose:
  360. self.logger.log('using {} initialization'.format(self.init_patches))
  361. for counter in range(x.shape[0]):
  362. patches_coll[counter] += self.get_init_patch(c, s).squeeze().clamp(0., 1.)
  363. x_new[counter, :, loc[counter, 0]:loc[counter, 0] + s,
  364. loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone()
  365. margin_min, loss_min = self.margin_and_loss(x_new, y)
  366. n_queries = torch.ones(x.shape[0]).to(self.device)
  367. for it in range(1, self.n_queries):
  368. # check points still to fool
  369. idx_to_fool = (margin_min > -1e-6).nonzero(as_tuple=False).squeeze()
  370. x_curr = self.check_shape(x[idx_to_fool])
  371. patches_curr = self.check_shape(patches_coll[idx_to_fool])
  372. y_curr = y[idx_to_fool]
  373. margin_min_curr = margin_min[idx_to_fool]
  374. loss_min_curr = loss_min[idx_to_fool]
  375. loc_curr = loc[idx_to_fool]
  376. if len(y_curr.shape) == 0:
  377. y_curr.unsqueeze_(0)
  378. margin_min_curr.unsqueeze_(0)
  379. loss_min_curr.unsqueeze_(0)
  380. loc_curr.unsqueeze_(0)
  381. idx_to_fool.unsqueeze_(0)
  382. # sample update
  383. s_it = int(max(self.p_selection(it) ** .5 * s, 1))
  384. p_it = torch.randint(s - s_it + 1, size=[2])
  385. sh_it = int(max(self.sh_selection(it) * h, 0))
  386. patches_new = patches_curr.clone()
  387. x_new = x_curr.clone()
  388. loc_new = loc_curr.clone()
  389. update_loc = int((it % loc_t == 0) and (sh_it > 0))
  390. update_patch = 1. - update_loc
  391. if self.update_loc_period < 0 and sh_it > 0:
  392. update_loc = 1. - update_loc
  393. update_patch = 1. - update_patch
  394. for counter in range(x_curr.shape[0]):
  395. if update_patch == 1.:
  396. # update patch
  397. if it < it_start_cu:
  398. if s_it > 1:
  399. patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] += self.random_choice([c, 1, 1])
  400. else:
  401. # make sure to sample a different color
  402. old_clr = patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
  403. new_clr = old_clr.clone()
  404. while (new_clr == old_clr).all().item():
  405. new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
  406. patches_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
  407. else:
  408. assert s_it == 1
  409. assert it >= it_start_cu
  410. # single channel updates
  411. new_ch = self.random_int(low=0, high=3, shape=[1])
  412. patches_new[counter, new_ch, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = 1. - patches_new[
  413. counter, new_ch, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it]
  414. patches_new[counter].clamp_(0., 1.)
  415. if update_loc == 1:
  416. # update location
  417. loc_new[counter] += (torch.randint(low=-sh_it, high=sh_it + 1, size=[2]))
  418. loc_new[counter].clamp_(0, h - s)
  419. x_new[counter, :, loc_new[counter, 0]:loc_new[counter, 0] + s,
  420. loc_new[counter, 1]:loc_new[counter, 1] + s] = patches_new[counter].clone()
  421. # check loss of new candidate
  422. margin, loss = self.margin_and_loss(x_new, y_curr)
  423. n_queries[idx_to_fool]+= 1
  424. # update best solution
  425. idx_improved = (loss < loss_min_curr).float()
  426. idx_to_update = (idx_improved > 0.).nonzero(as_tuple=False).squeeze()
  427. loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]
  428. idx_miscl = (margin < -1e-6).float()
  429. idx_improved = torch.max(idx_improved, idx_miscl)
  430. nimpr = idx_improved.sum().item()
  431. if nimpr > 0.:
  432. idx_improved = (idx_improved.view(-1) > 0).nonzero(as_tuple=False).squeeze()
  433. margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
  434. patches_coll[idx_to_fool[idx_improved]] = patches_new[idx_improved].clone()
  435. loc[idx_to_fool[idx_improved]] = loc_new[idx_improved].clone()
  436. # log results current iteration
  437. ind_succ = (margin_min <= 0.).nonzero(as_tuple=False).squeeze()
  438. if self.verbose and ind_succ.numel() != 0:
  439. self.logger.log(' '.join(['{}'.format(it + 1),
  440. '- success rate={}/{} ({:.2%})'.format(
  441. ind_succ.numel(), n_ex_total,
  442. float(ind_succ.numel()) / n_ex_total),
  443. '- avg # queries={:.1f}'.format(
  444. n_queries[ind_succ].mean().item()),
  445. '- med # queries={:.1f}'.format(
  446. n_queries[ind_succ].median().item()),
  447. '- loss={:.3f}'.format(loss_min.mean()),
  448. '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0
  449. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  450. #'- sit={:.0f} - sh={:.0f}'.format(s_it, sh_it),
  451. '{}'.format(' - loc' if update_loc == 1. else ''),
  452. ]))
  453. if ind_succ.numel() == n_ex_total:
  454. break
  455. # apply patches
  456. for counter in range(x.shape[0]):
  457. x_best[counter, :, loc[counter, 0]:loc[counter, 0] + s,
  458. loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone()
  459. elif self.norm == 'patches_universal':
  460. ''' assumes square images and patches '''
  461. s = int(math.ceil(self.eps ** .5))
  462. x_best = x.clone()
  463. self.n_imgs = x.shape[0]
  464. x_new = x.clone()
  465. loc = torch.randint(h - s + 1, size=[x.shape[0], 2])
  466. # set when to start single channel updates
  467. it_start_cu = None
  468. for it in range(0, self.n_queries):
  469. s_it = int(max(self.p_selection(it) ** .5 * s, 1))
  470. if s_it == 1:
  471. break
  472. it_start_cu = it + (self.n_queries - it) // 2
  473. if self.verbose:
  474. self.logger.log('starting single channel updates at query {}'.format(
  475. it_start_cu))
  476. # initialize patch
  477. if self.verbose:
  478. self.logger.log('using {} initialization'.format(self.init_patches))
  479. patch_univ = self.get_init_patch(c, s)
  480. it_init = 0
  481. loss_batch = float(1e10)
  482. n_succs = 0
  483. n_iter = self.n_queries
  484. # init update batch
  485. assert not self.data_loader is None
  486. assert not self.resample_loc is None
  487. assert self.targeted
  488. new_train_imgs = []
  489. n_newimgs = self.n_imgs + 0
  490. n_imgsneeded = math.ceil(self.n_queries / self.resample_loc) * n_newimgs
  491. tot_imgs = 0
  492. if self.verbose:
  493. self.logger.log('imgs updated={}, imgs needed={}'.format(
  494. n_newimgs, n_imgsneeded))
  495. while tot_imgs < min(100000, n_imgsneeded):
  496. x_toupdatetrain, _ = next(self.data_loader)
  497. new_train_imgs.append(x_toupdatetrain)
  498. tot_imgs += x_toupdatetrain.shape[0]
  499. newimgstoadd = torch.cat(new_train_imgs, axis=0)
  500. counter_resamplingimgs = 0
  501. for it in range(it_init, n_iter):
  502. # sample size and location of the update
  503. s_it = int(max(self.p_selection(it) ** .5 * s, 1))
  504. p_it = torch.randint(s - s_it + 1, size=[2])
  505. patch_new = patch_univ.clone()
  506. if s_it > 1:
  507. patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] += self.random_choice([c, 1, 1])
  508. else:
  509. old_clr = patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
  510. new_clr = old_clr.clone()
  511. if it < it_start_cu:
  512. while (new_clr == old_clr).all().item():
  513. new_clr = self.random_choice(new_clr).clone().clamp(0., 1.)
  514. else:
  515. # single channel update
  516. new_ch = self.random_int(low=0, high=3, shape=[1])
  517. new_clr[new_ch] = 1. - new_clr[new_ch]
  518. patch_new[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
  519. patch_new.clamp_(0., 1.)
  520. # compute loss for new candidate
  521. x_new = x.clone()
  522. for counter in range(x.shape[0]):
  523. loc_new = loc[counter]
  524. x_new[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = 0.
  525. x_new[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] += patch_new[0]
  526. margin_run, loss_run = self.margin_and_loss(x_new, y)
  527. if self.loss == 'ce':
  528. loss_run += x_new.shape[0]
  529. loss_new = loss_run.sum()
  530. n_succs_new = (margin_run < -1e-6).sum().item()
  531. # accept candidate if loss improves
  532. if loss_new < loss_batch:
  533. is_accepted = True
  534. loss_batch = loss_new + 0.
  535. patch_univ = patch_new.clone()
  536. n_succs = n_succs_new + 0
  537. else:
  538. is_accepted = False
  539. # sample new locations and images
  540. if (it + 1) % self.resample_loc == 0:
  541. newimgstoadd_it = newimgstoadd[counter_resamplingimgs * n_newimgs:(
  542. counter_resamplingimgs + 1) * n_newimgs].clone().cuda()
  543. new_batch = [x[n_newimgs:].clone(), newimgstoadd_it.clone()]
  544. x = torch.cat(new_batch, dim=0)
  545. assert x.shape[0] == self.n_imgs
  546. loc = torch.randint(h - s + 1, size=[self.n_imgs, 2])
  547. assert loc.shape == (self.n_imgs, 2)
  548. loss_batch = loss_batch * 0. + 1e6
  549. counter_resamplingimgs += 1
  550. # logging current iteration
  551. if self.verbose:
  552. self.logger.log(' '.join(['{}'.format(it + 1),
  553. '- success rate={}/{} ({:.2%})'.format(
  554. n_succs, n_ex_total,
  555. float(n_succs) / n_ex_total),
  556. '- loss={:.3f}'.format(loss_batch),
  557. '- max pert={:.0f}'.format(((x_new - x).abs() > 0
  558. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  559. ]))
  560. # apply patches on the initial images
  561. for counter in range(x_best.shape[0]):
  562. loc_new = loc[counter]
  563. x_best[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = 0.
  564. x_best[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] += patch_univ[0]
  565. elif self.norm == 'frames':
  566. # set width and indices of frames
  567. mask = torch.zeros(x.shape[-2:])
  568. s = self.eps + 0
  569. mask[:s] = 1.
  570. mask[-s:] = 1.
  571. mask[:, :s] = 1.
  572. mask[:, -s:] = 1.
  573. ind = (mask == 1.).nonzero(as_tuple=False).squeeze()
  574. eps = ind.shape[0]
  575. x_best = x.clone()
  576. x_new = x.clone()
  577. mask = mask.view(1, 1, h, w).to(self.device)
  578. mask_frame = torch.ones([1, c, h, w], device=x.device) * mask
  579. #
  580. # set when starting single channel updates
  581. it_start_cu = None
  582. for it in range(0, self.n_queries):
  583. s_it = int(max(self.p_selection(it), 1))
  584. if s_it == 1:
  585. break
  586. it_start_cu = it + (self.n_queries - it) // 2
  587. #it_start_cu = 10000
  588. if self.verbose:
  589. self.logger.log('starting single channel updates at query {}'.format(
  590. it_start_cu))
  591. # initialize frames
  592. x_best[:, :, ind[:, 0], ind[:, 1]] = self.random_choice(
  593. [x.shape[0], c, eps]).clamp(0., 1.)
  594. margin_min, loss_min = self.margin_and_loss(x_best, y)
  595. n_queries = torch.ones(x.shape[0]).to(self.device)
  596. for it in range(1, self.n_queries):
  597. # check points still to fool
  598. idx_to_fool = (margin_min > -1e-6).nonzero(as_tuple=False).squeeze()
  599. x_curr = self.check_shape(x[idx_to_fool])
  600. x_best_curr = self.check_shape(x_best[idx_to_fool])
  601. y_curr = y[idx_to_fool]
  602. margin_min_curr = margin_min[idx_to_fool]
  603. loss_min_curr = loss_min[idx_to_fool]
  604. if len(y_curr.shape) == 0:
  605. y_curr.unsqueeze_(0)
  606. margin_min_curr.unsqueeze_(0)
  607. loss_min_curr.unsqueeze_(0)
  608. idx_to_fool.unsqueeze_(0)
  609. # sample update
  610. s_it = max(int(self.p_selection(it)), 1)
  611. ind_it = torch.randperm(eps)[0]
  612. x_new = x_best_curr.clone()
  613. if s_it > 1:
  614. dir_h = self.random_choice([1]).long().cpu()
  615. dir_w = self.random_choice([1]).long().cpu()
  616. new_clr = self.random_choice([c, 1]).clamp(0., 1.)
  617. for counter in range(x_curr.shape[0]):
  618. if s_it > 1:
  619. for counter_h in range(s_it):
  620. for counter_w in range(s_it):
  621. x_new[counter, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1),
  622. (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone()
  623. else:
  624. p_it = ind[ind_it].clone()
  625. old_clr = x_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
  626. new_clr = old_clr.clone()
  627. if it < it_start_cu:
  628. while (new_clr == old_clr).all().item():
  629. new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
  630. else:
  631. # single channel update
  632. new_ch = self.random_int(low=0, high=3, shape=[1])
  633. new_clr[new_ch] = 1. - new_clr[new_ch]
  634. x_new[counter, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
  635. x_new.clamp_(0., 1.)
  636. x_new = (x_new - x_curr) * mask_frame + x_curr
  637. # check loss of new candidate
  638. margin, loss = self.margin_and_loss(x_new, y_curr)
  639. n_queries[idx_to_fool]+= 1
  640. # update best solution
  641. idx_improved = (loss < loss_min_curr).float()
  642. idx_to_update = (idx_improved > 0.).nonzero(as_tuple=False).squeeze()
  643. loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]
  644. idx_miscl = (margin < -1e-6).float()
  645. idx_improved = torch.max(idx_improved, idx_miscl)
  646. nimpr = idx_improved.sum().item()
  647. if nimpr > 0.:
  648. idx_improved = (idx_improved.view(-1) > 0).nonzero(as_tuple=False).squeeze()
  649. margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
  650. x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone()
  651. # log results current iteration
  652. ind_succ = (margin_min <= 0.).nonzero(as_tuple=False).squeeze()
  653. if self.verbose and ind_succ.numel() != 0:
  654. self.logger.log(' '.join(['{}'.format(it + 1),
  655. '- success rate={}/{} ({:.2%})'.format(
  656. ind_succ.numel(), n_ex_total,
  657. float(ind_succ.numel()) / n_ex_total),
  658. '- avg # queries={:.1f}'.format(
  659. n_queries[ind_succ].mean().item()),
  660. '- med # queries={:.1f}'.format(
  661. n_queries[ind_succ].median().item()),
  662. '- loss={:.3f}'.format(loss_min.mean()),
  663. '- max pert={:.0f}'.format(((x_new - x_curr).abs() > 0
  664. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  665. #'- min pert={:.0f}'.format(((x_new - x_curr).abs() > 0
  666. #).max(1)[0].view(x_new.shape[0], -1).sum(-1).min()),
  667. #'- sit={:.0f} - indit={}'.format(s_it, ind_it.item()),
  668. ]))
  669. if ind_succ.numel() == n_ex_total:
  670. break
  671. elif self.norm == 'frames_universal':
  672. # set width and indices of frames
  673. mask = torch.zeros(x.shape[-2:])
  674. s = self.eps + 0
  675. mask[:s] = 1.
  676. mask[-s:] = 1.
  677. mask[:, :s] = 1.
  678. mask[:, -s:] = 1.
  679. ind = (mask == 1.).nonzero(as_tuple=False).squeeze()
  680. eps = ind.shape[0]
  681. x_best = x.clone()
  682. x_new = x.clone()
  683. mask = mask.view(1, 1, h, w).to(self.device)
  684. mask_frame = torch.ones([1, c, h, w], device=x.device) * mask
  685. frame_univ = self.random_choice([1, c, eps]).clamp(0., 1.)
  686. # set when to start single channel updates
  687. it_start_cu = None
  688. for it in range(0, self.n_queries):
  689. s_it = int(max(self.p_selection(it) * s, 1))
  690. if s_it == 1:
  691. break
  692. it_start_cu = it + (self.n_queries - it) // 2
  693. if self.verbose:
  694. self.logger.log('starting single channel updates at query {}'.format(
  695. it_start_cu))
  696. self.n_imgs = x.shape[0]
  697. loss_batch = float(1e10)
  698. n_queries = torch.ones_like(y).float()
  699. # init update batch
  700. assert not self.data_loader is None
  701. assert not self.resample_loc is None
  702. assert self.targeted
  703. new_train_imgs = []
  704. n_newimgs = self.n_imgs + 0
  705. n_imgsneeded = math.ceil(self.n_queries / self.resample_loc) * n_newimgs
  706. tot_imgs = 0
  707. if self.verbose:
  708. self.logger.log('imgs updated={}, imgs needed={}'.format(
  709. n_newimgs, n_imgsneeded))
  710. while tot_imgs < min(100000, n_imgsneeded):
  711. x_toupdatetrain, _ = next(self.data_loader)
  712. new_train_imgs.append(x_toupdatetrain)
  713. tot_imgs += x_toupdatetrain.shape[0]
  714. newimgstoadd = torch.cat(new_train_imgs, axis=0)
  715. counter_resamplingimgs = 0
  716. for it in range(self.n_queries):
  717. # sample update
  718. s_it = max(int(self.p_selection(it) * self.eps), 1)
  719. ind_it = torch.randperm(eps)[0]
  720. mask_frame[:, :, ind[:, 0], ind[:, 1]] = 0
  721. mask_frame[:, :, ind[:, 0], ind[:, 1]] += frame_univ
  722. if s_it > 1:
  723. dir_h = self.random_choice([1]).long().cpu()
  724. dir_w = self.random_choice([1]).long().cpu()
  725. new_clr = self.random_choice([c, 1]).clamp(0., 1.)
  726. for counter_h in range(s_it):
  727. for counter_w in range(s_it):
  728. mask_frame[0, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1),
  729. (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone()
  730. else:
  731. p_it = ind[ind_it]
  732. old_clr = mask_frame[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it].clone()
  733. new_clr = old_clr.clone()
  734. if it < it_start_cu:
  735. while (new_clr == old_clr).all().item():
  736. new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
  737. else:
  738. # single channel update
  739. new_ch = self.random_int(low=0, high=3, shape=[1])
  740. new_clr[new_ch] = 1. - new_clr[new_ch]
  741. mask_frame[0, :, p_it[0]:p_it[0] + s_it, p_it[1]:p_it[1] + s_it] = new_clr.clone()
  742. frame_new = mask_frame[:, :, ind[:, 0], ind[:, 1]].clone()
  743. frame_new.clamp_(0., 1.)
  744. if len(frame_new.shape) == 2:
  745. frame_new.unsqueeze_(0)
  746. x_new[:, :, ind[:, 0], ind[:, 1]] = 0.
  747. x_new[:, :, ind[:, 0], ind[:, 1]] += frame_new
  748. margin_run, loss_run = self.margin_and_loss(x_new, y)
  749. if self.loss == 'ce':
  750. loss_run += x_new.shape[0]
  751. loss_new = loss_run.sum()
  752. n_succs_new = (margin_run < -1e-6).sum().item()
  753. # accept candidate if loss improves
  754. if loss_new < loss_batch:
  755. #is_accepted = True
  756. loss_batch = loss_new + 0.
  757. frame_univ = frame_new.clone()
  758. n_succs = n_succs_new + 0
  759. # sample new images
  760. if (it + 1) % self.resample_loc == 0:
  761. newimgstoadd_it = newimgstoadd[counter_resamplingimgs * n_newimgs:(
  762. counter_resamplingimgs + 1) * n_newimgs].clone().cuda()
  763. new_batch = [x[n_newimgs:].clone(), newimgstoadd_it.clone()]
  764. x = torch.cat(new_batch, dim=0)
  765. assert x.shape[0] == self.n_imgs
  766. loss_batch = loss_batch * 0. + 1e6
  767. x_new = x.clone()
  768. counter_resamplingimgs += 1
  769. # loggin current iteration
  770. if self.verbose:
  771. self.logger.log(' '.join(['{}'.format(it + 1),
  772. '- success rate={}/{} ({:.2%})'.format(
  773. n_succs, n_ex_total,
  774. float(n_succs) / n_ex_total),
  775. '- loss={:.3f}'.format(loss_batch),
  776. '- max pert={:.0f}'.format(((x_new - x).abs() > 0
  777. ).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()),
  778. ]))
  779. # apply frame on initial images
  780. x_best[:, :, ind[:, 0], ind[:, 1]] = 0.
  781. x_best[:, :, ind[:, 0], ind[:, 1]] += frame_univ
  782. return n_queries, x_best
  def perturb(self, x, y=None):
      """
      Attack a batch of images, restarting the attack on the points that
      have not been fooled yet, and collect the successful results.

      :param x:           clean images
      :param y:           untargeted attack -> clean labels,
                          if None we use the predicted labels
                          targeted attack -> target labels, if None random classes,
                          different from the predicted ones, are sampled
      :return:            tuple (qr, adv). adv is a copy of x where every point
                          that was successfully attacked is replaced by the
                          output of attack_single_run; qr holds, for those same
                          points, the query counts returned by
                          attack_single_run and 0 everywhere else.

      Side effect: the global torch, torch.cuda and numpy random generators
      are re-seeded with self.seed before the first restart.
      """
      self.init_hyperparam(x)

      adv = x.clone()
      qr = torch.zeros([x.shape[0]]).to(self.device)

      if y is None:
          # no labels given: derive them from the model predictions
          with torch.no_grad():
              output = self.predict(x)
              y_pred = output.max(1)[1]
              if self.targeted:
                  n_classes = output.shape[-1]
                  y = self.random_target_classes(y_pred, n_classes)
              else:
                  y = y_pred.detach().clone().long().to(self.device)
      else:
          y = y.detach().clone().long().to(self.device)

      # acc marks the points that still have to be fooled:
      # untargeted -> currently classified as y,
      # targeted   -> not yet classified as the target y.
      # Only the arg-max is needed, so no autograd graph is built.
      with torch.no_grad():
          pred = self.predict(x).max(1)[1]
      acc = pred != y if self.targeted else pred == y

      torch.random.manual_seed(self.seed)
      torch.cuda.random.manual_seed(self.seed)
      np.random.seed(self.seed)

      for _ in range(self.n_restarts):
          ind_to_fool = acc.nonzero(as_tuple=False).squeeze()
          if len(ind_to_fool.shape) == 0:
              # a single remaining point squeezes to a 0-dim tensor
              ind_to_fool = ind_to_fool.unsqueeze(0)
          if ind_to_fool.numel() == 0:
              # acc entries are only ever cleared, so later restarts
              # would have nothing to do either
              break

          x_to_fool = x[ind_to_fool].clone()
          y_to_fool = y[ind_to_fool].clone()

          qr_curr, adv_curr = self.attack_single_run(x_to_fool, y_to_fool)

          # re-classify the candidates to see which ones really succeed
          with torch.no_grad():
              pred_curr = self.predict(adv_curr).max(1)[1]
          if self.targeted:
              acc_curr = pred_curr != y_to_fool
          else:
              acc_curr = pred_curr == y_to_fool
          ind_curr = (acc_curr == 0).nonzero(as_tuple=False).squeeze()

          # store results only for the points fooled in this restart
          acc[ind_to_fool[ind_curr]] = 0
          adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()
          qr[ind_to_fool[ind_curr]] = qr_curr[ind_curr].clone()

      return qr, adv