12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988 |
- # Copyright (c) 2020-present
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- #
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from __future__ import unicode_literals
- import torch
- import time
- import math
- import torch.nn.functional as F
- import numpy as np
- import copy
- from utils.helpers import Logger
- import os
class RSAttack():
    """
    Sparse-RS attacks (random-search based sparse adversarial attacks).

    :param predict: forward pass function mapping a batch of images to logits
    :param norm: type of the attack ('L0', 'patches', 'frames',
        'patches_universal', 'frames_universal')
    :param n_queries: max number of queries (each restart)
    :param eps: bound on the sparsity of perturbations (number of pixels for
        'L0', patch area for patches, frame width for frames)
    :param p_init: initial value of the schedule parameter p controlling the
        size of each update
    :param n_restarts: number of random restarts
    :param seed: random seed for the starting point (current time if None)
    :param verbose: log progress through the logger
    :param targeted: perform targeted attacks
    :param loss: loss function optimized ('margin', 'ce' supported)
    :param resc_schedule: rescale the schedule of p to n_queries
    :param device: device to use (defaults to the device of the input)
    :param log_path: path passed to Logger for the log file
    :param constant_schedule: use a constant p
    :param init_patches: initialization for patches ('stripes', 'uniform',
        'random', 'random_squares', 'sh')
    :param resample_loc: period in queries of resampling images and
        locations for universal attacks (default n_queries // 10)
    :param data_loader: iterator yielding (images, labels) batches, used to
        get new images for resampling in universal attacks
    :param update_loc_period: period in queries of updates of the location
        for image-specific patches (default 4, or 10 if targeted); a negative
        value swaps which queries update the location and which the patch
    """

    def __init__(
            self,
            predict,
            norm='L0',
            n_queries=5000,
            eps=None,
            p_init=.8,
            n_restarts=1,
            seed=0,
            verbose=True,
            targeted=False,
            loss='margin',
            resc_schedule=True,
            device=None,
            log_path=None,
            constant_schedule=False,
            init_patches='random_squares',
            resample_loc=None,
            data_loader=None,
            update_loc_period=None):
        """Store the attack configuration (see the class docstring)."""
        self.predict = predict
        self.norm = norm
        self.n_queries = n_queries
        self.eps = eps
        self.p_init = p_init
        self.n_restarts = n_restarts
        self.seed = seed
        self.verbose = verbose
        self.targeted = targeted
        self.loss = loss
        self.rescale_schedule = resc_schedule
        self.device = device
        self.logger = Logger(log_path)
        self.constant_schedule = constant_schedule
        self.init_patches = init_patches
        self.resample_loc = n_queries // 10 if resample_loc is None else resample_loc
        self.data_loader = data_loader
        if update_loc_period is None:
            update_loc_period = 10 if targeted else 4
        self.update_loc_period = update_loc_period

    def margin_and_loss(self, x, y):
        """
        Query the model on ``x`` and return per-example ``(margin, loss)``.

        The margin is positive while the attack has not succeeded on an
        example: logit of ``y`` minus the largest other logit (untargeted),
        or largest other logit minus the logit of ``y`` (targeted). The loss
        is the quantity the attack minimizes.

        :param x: batch of images
        :param y: correct labels if untargeted else target labels
        """
        logits = self.predict(x)
        if logits.dim() == 1:
            # the model returned logits without a batch dimension
            # (presumably a single-example batch)
            logits = logits.unsqueeze(0)
        xent = F.cross_entropy(logits, y, reduction='none')
        u = torch.arange(x.shape[0])
        y_corr = logits[u, y].clone()
        # hide the label logit so that max() returns the best competing class
        logits[u, y] = -float('inf')
        y_others = logits.max(dim=-1)[0]
        if not self.targeted:
            if self.loss == 'ce':
                return y_corr - y_others, -1. * xent
            elif self.loss == 'margin':
                return y_corr - y_others, y_corr - y_others
        else:
            return y_others - y_corr, xent

    def init_hyperparam(self, x):
        """Validate the configuration and derive input-dependent settings from ``x``."""
        assert self.norm in ['L0', 'patches', 'frames',
                             'patches_universal', 'frames_universal']
        assert self.eps is not None
        assert self.loss in ['ce', 'margin']
        if self.device is None:
            self.device = x.device
        self.orig_dim = list(x.shape[1:])
        self.ndims = len(self.orig_dim)
        if self.seed is None:
            # torch/numpy seeding in perturb() requires an integer seed
            self.seed = int(time.time())
        if self.targeted:
            self.loss = 'ce'

    def random_target_classes(self, y_pred, n_classes):
        """For each example sample a random class different from ``y_pred``."""
        y = torch.zeros_like(y_pred)
        for counter in range(y_pred.shape[0]):
            candidates = list(range(n_classes))
            candidates.remove(int(y_pred[counter]))
            t = self.random_int(0, len(candidates))
            y[counter] = candidates[int(t)]
        return y.long().to(self.device)

    def check_shape(self, x):
        """Add a leading batch dimension if ``x`` is a single example."""
        return x if len(x.shape) == (self.ndims + 1) else x.unsqueeze(0)

    def random_choice(self, shape):
        """Random tensor of the given shape with entries in {-1, 1}."""
        t = 2 * torch.rand(shape).to(self.device) - 1
        return torch.sign(t)

    def random_int(self, low=0, high=1, shape=[1]):
        """Random integers uniformly drawn from [low, high)."""
        t = low + (high - low) * torch.rand(shape).to(self.device)
        return t.long()

    def normalize(self, x):
        """Normalize each example to unit norm (only 'Linf' and 'L2'; None otherwise)."""
        if self.norm == 'Linf':
            t = x.abs().view(x.shape[0], -1).max(1)[0]
            return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)
        elif self.norm == 'L2':
            t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
            return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)

    def lp_norm(self, x):
        """Per-example L2 norm, broadcastable to ``x`` (only 'L2'; None otherwise)."""
        if self.norm == 'L2':
            t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
            return t.view(-1, *([1] * self.ndims))

    def p_selection(self, it):
        """
        Schedule to decrease the parameter p (size of each update).

        For patches and L0, p_init is divided by a factor that grows
        step-wise with the query index ``it``; for frames the value is
        computed directly (image-specific frames return the side of the
        square update).
        """
        if self.rescale_schedule:
            # map the query index onto a 10000-query budget
            it = int(it / self.n_queries * 10000)

        if 'patches' in self.norm:
            thresholds = (10, 50, 200, 500, 1000, 2000, 4000, 6000, 8000)
            divisors = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512)
        elif 'frames' in self.norm:
            if 'universal' not in self.norm:
                tot_qr = 10000 if self.rescale_schedule else self.n_queries
                p = max((float(tot_qr - it) / tot_qr - .5) * self.p_init * self.eps ** 2, 0.)
                return 3. * math.ceil(p)
            assert self.rescale_schedule
            its = [200, 600, 1200, 1800, 2500, 10000, 100000]
            resc_factors = [1., .8, .6, .4, .2, .1, 0.]
            k = 0
            while it >= its[k]:
                k += 1
            return resc_factors[k] * self.p_init
        elif 'L0' in self.norm:
            thresholds = (0, 50, 200, 500, 1000, 2000, 4000, 6000, 8000)
            divisors = (1, 2, 4, 5, 6, 8, 10, 12, 15, 20)
        else:
            raise ValueError('no schedule for norm {}'.format(self.norm))

        # the number of thresholds strictly below `it` selects the step
        p = self.p_init / divisors[sum(it > t for t in thresholds)]

        if self.constant_schedule:
            p = self.p_init / 2

        return p

    def sh_selection(self, it):
        """Schedule of the maximal location shift: decays linearly from .75 to 0."""
        return max(float(self.n_queries - it) / self.n_queries, 0) * .75

    def get_init_patch(self, c, s, n_iter=1000):
        """
        Return an initial ``[1, c, s, s]`` patch with values in [0, 1].

        :param c: number of channels
        :param s: side of the square patch
        :param n_iter: number of random squares drawn for 'random_squares'
        :raises ValueError: if ``self.init_patches`` is not a known scheme
        """
        if self.init_patches == 'stripes':
            patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice(
                [1, c, 1, s]).clamp(0., 1.)
        elif self.init_patches == 'uniform':
            patch_univ = torch.zeros([1, c, s, s]).to(self.device) + self.random_choice(
                [1, c, 1, 1]).clamp(0., 1.)
        elif self.init_patches == 'random':
            patch_univ = self.random_choice([1, c, s, s]).clamp(0., 1.)
        elif self.init_patches == 'random_squares':
            patch_univ = torch.zeros([1, c, s, s]).to(self.device)
            # square sides are drawn from [1, ceil(sqrt(s))); keep the range
            # non-empty when s == 1
            max_size = max(math.ceil(s ** .5), 2)
            for _ in range(n_iter):
                size_init = torch.randint(low=1, high=max_size, size=[1]).item()
                loc_init = torch.randint(s - size_init + 1, size=[2])
                rows = slice(loc_init[0], loc_init[0] + size_init)
                cols = slice(loc_init[1], loc_init[1] + size_init)
                patch_univ[0, :, rows, cols] = 0.
                patch_univ[0, :, rows, cols] += self.random_choice([c, 1, 1]).clamp(0., 1.)
        elif self.init_patches == 'sh':
            patch_univ = torch.ones([1, c, s, s]).to(self.device)
        else:
            raise ValueError('unknown patch initialization: {}'.format(self.init_patches))

        return patch_univ.clamp(0., 1.)

    def attack_single_run(self, x, y):
        """
        Run one restart of the attack selected by ``self.norm``.

        :param x: clean images, unpacked as ``(n, c, h, w)``
        :param y: correct labels if untargeted else target labels
        :return: ``(n_queries, x_best)``, the queries used per example and
            the perturbed images
        """
        attacks = {
            'L0': self._attack_l0,
            'patches': self._attack_patches,
            'patches_universal': self._attack_patches_universal,
            'frames': self._attack_frames,
            'frames_universal': self._attack_frames_universal,
        }
        if self.norm not in attacks:
            raise ValueError('unsupported norm: {}'.format(self.norm))
        with torch.no_grad():
            return attacks[self.norm](x, y)

    # ------------------------------------------------------------------ helpers

    def _single_channel_start(self, size_at):
        """
        Return the query index from which single-channel updates are used.

        ``size_at(it)`` is the update size at query ``it``. The switch happens
        halfway between the first query whose size is 1 and the query budget;
        if the size never reaches 1 the switch never happens.
        """
        for it in range(self.n_queries):
            if size_at(it) == 1:
                it_start_cu = it + (self.n_queries - it) // 2
                break
        else:
            it_start_cu = self.n_queries
        if self.verbose:
            self.logger.log('starting single channel updates at query {}'.format(
                it_start_cu))
        return it_start_cu

    def _update_best(self, margin, loss, idx_to_fool, margin_min, loss_min, loss_min_curr):
        """
        Update ``loss_min`` and ``margin_min`` in place with accepted candidates.

        A candidate is accepted if it lowers the loss or is misclassified;
        ``loss_min`` only records actual loss decreases.
        Returns ``(idx_improved, nimpr)``: indices into the current subset of
        the accepted candidates (None if there are none) and their number.
        """
        idx_improved = (loss < loss_min_curr).float()
        idx_to_update = (idx_improved > 0.).nonzero(as_tuple=False).squeeze()
        loss_min[idx_to_fool[idx_to_update]] = loss[idx_to_update]

        idx_miscl = (margin < -1e-6).float()
        idx_improved = torch.max(idx_improved, idx_miscl)
        nimpr = idx_improved.sum().item()
        if nimpr > 0.:
            idx_improved = (idx_improved.view(-1) > 0).nonzero(as_tuple=False).squeeze()
            margin_min[idx_to_fool[idx_improved]] = margin[idx_improved].clone()
            return idx_improved, nimpr
        return None, nimpr

    def _log_progress(self, it, ind_succ, n_ex_total, n_queries, loss_min,
                      x_new, x_curr, extra=()):
        """Log success rate, query statistics and perturbation size at query ``it``."""
        n_succ = ind_succ.numel()
        max_pert = ((x_new - x_curr).abs() > 0).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()
        self.logger.log(' '.join([
            '{}'.format(it + 1),
            '- success rate={}/{} ({:.2%})'.format(
                n_succ, n_ex_total, float(n_succ) / n_ex_total),
            '- avg # queries={:.1f}'.format(n_queries[ind_succ].mean().item()),
            '- med # queries={:.1f}'.format(n_queries[ind_succ].median().item()),
            '- loss={:.3f}'.format(loss_min.mean()),
            '- max pert={:.0f}'.format(max_pert),
        ] + list(extra)))

    def _log_universal(self, it, n_succs, n_ex_total, loss_batch, x_new, x):
        """Log the progress of a universal attack at query ``it``."""
        max_pert = ((x_new - x).abs() > 0).max(1)[0].view(x_new.shape[0], -1).sum(-1).max()
        self.logger.log(' '.join([
            '{}'.format(it + 1),
            '- success rate={}/{} ({:.2%})'.format(
                n_succs, n_ex_total, float(n_succs) / n_ex_total),
            '- loss={:.3f}'.format(loss_batch),
            '- max pert={:.0f}'.format(max_pert),
        ]))

    def _load_resampling_images(self, n_newimgs):
        """Pre-fetch from ``self.data_loader`` the images used for resampling
        in universal attacks (at most about 100000 images)."""
        assert self.data_loader is not None
        assert self.resample_loc is not None
        assert self.targeted
        n_imgsneeded = math.ceil(self.n_queries / self.resample_loc) * n_newimgs
        if self.verbose:
            self.logger.log('imgs updated={}, imgs needed={}'.format(
                n_newimgs, n_imgsneeded))
        new_train_imgs = []
        tot_imgs = 0
        while tot_imgs < min(100000, n_imgsneeded):
            x_toupdatetrain, _ = next(self.data_loader)
            new_train_imgs.append(x_toupdatetrain)
            tot_imgs += x_toupdatetrain.shape[0]
        return torch.cat(new_train_imgs, dim=0)

    def _next_image_batch(self, x, newimgstoadd, counter_resamplingimgs, n_newimgs):
        """Drop the first ``n_newimgs`` images of ``x`` and append the next chunk
        of pre-fetched images, moved to ``self.device``."""
        start = counter_resamplingimgs * n_newimgs
        newimgstoadd_it = newimgstoadd[start:start + n_newimgs].clone().to(self.device)
        x = torch.cat([x[n_newimgs:].clone(), newimgstoadd_it], dim=0)
        assert x.shape[0] == self.n_imgs
        return x

    def _frame_mask(self, x):
        """Return ``(mask, ind)``: an ``[h, w]`` mask equal to 1 on the border
        of width ``eps`` and the ``(row, col)`` indices of the border pixels."""
        mask = torch.zeros(x.shape[-2:])
        s = self.eps
        mask[:s] = 1.
        mask[-s:] = 1.
        mask[:, :s] = 1.
        mask[:, -s:] = 1.
        ind = (mask == 1.).nonzero(as_tuple=False).squeeze()
        return mask, ind

    # ------------------------------------------------------------ attack modes

    def _attack_l0(self, x, y):
        """L0 attack: perturb at most ``eps`` pixels (all channels) of each image."""
        c, h, w = x.shape[1:]
        n_ex_total = x.shape[0]
        eps = self.eps
        n_pixels = h * w

        # random initial perturbation; b_all holds the flat indices of the
        # perturbed pixels, be_all those of the remaining ones
        x_best = x.clone()
        b_all = torch.zeros([n_ex_total, eps]).long()
        be_all = torch.zeros([n_ex_total, n_pixels - eps]).long()
        for img in range(n_ex_total):
            ind_all = torch.randperm(n_pixels)
            ind_p = ind_all[:eps]
            ind_np = ind_all[eps:]
            x_best[img, :, ind_p // w, ind_p % w] = self.random_choice([c, eps]).clamp(0., 1.)
            b_all[img] = ind_p.clone()
            be_all[img] = ind_np.clone()

        margin_min, loss_min = self.margin_and_loss(x_best, y)
        n_queries = torch.ones(n_ex_total).to(self.device)

        for it in range(1, self.n_queries):
            # check points still to fool
            idx_to_fool = (margin_min > 0.).nonzero(as_tuple=False).squeeze()
            x_curr = self.check_shape(x[idx_to_fool])
            x_best_curr = self.check_shape(x_best[idx_to_fool])
            y_curr = y[idx_to_fool]
            loss_min_curr = loss_min[idx_to_fool]
            b_curr, be_curr = b_all[idx_to_fool], be_all[idx_to_fool]
            if len(y_curr.shape) == 0:
                # a single example is left: restore the batch dimension
                y_curr.unsqueeze_(0)
                loss_min_curr.unsqueeze_(0)
                b_curr.unsqueeze_(0)
                be_curr.unsqueeze_(0)
                idx_to_fool.unsqueeze_(0)

            # new candidate: swap eps_it perturbed pixels with unperturbed ones
            x_new = x_best_curr.clone()
            eps_it = max(int(self.p_selection(it) * eps), 1)
            ind_p = torch.randperm(eps)[:eps_it]
            ind_np = torch.randperm(n_pixels - eps)[:eps_it]

            for img in range(x_new.shape[0]):
                p_set = b_curr[img, ind_p]
                np_set = be_curr[img, ind_np]
                # restore the clean values of the dropped pixels
                x_new[img, :, p_set // w, p_set % w] = x_curr[img, :, p_set // w, p_set % w].clone()
                if eps_it > 1:
                    x_new[img, :, np_set // w, np_set % w] = self.random_choice([c, eps_it]).clamp(0., 1.)
                else:
                    # if update is 1x1 make sure the sampled color is different from the current one
                    old_clr = x_new[img, :, np_set // w, np_set % w].clone()
                    assert old_clr.shape == (c, 1), old_clr.shape
                    new_clr = old_clr.clone()
                    while (new_clr == old_clr).all().item():
                        new_clr = self.random_choice([c, 1]).clone().clamp(0., 1.)
                    x_new[img, :, np_set // w, np_set % w] = new_clr.clone()

            # compute loss of the new candidates
            margin, loss = self.margin_and_loss(x_new, y_curr)
            n_queries[idx_to_fool] += 1

            # update best solution
            idx_improved, nimpr = self._update_best(
                margin, loss, idx_to_fool, margin_min, loss_min, loss_min_curr)
            if nimpr > 0.:
                x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone()
                # exchange the swapped pixel indices between the two sets
                t = b_curr[idx_improved].clone()
                te = be_curr[idx_improved].clone()
                if nimpr > 1:
                    t[:, ind_p] = be_curr[idx_improved][:, ind_np] + 0
                    te[:, ind_np] = b_curr[idx_improved][:, ind_p] + 0
                else:
                    t[ind_p] = be_curr[idx_improved][ind_np] + 0
                    te[ind_np] = b_curr[idx_improved][ind_p] + 0
                b_all[idx_to_fool[idx_improved]] = t.clone()
                be_all[idx_to_fool[idx_improved]] = te.clone()

            # log results current iteration
            ind_succ = (margin_min <= 0.).nonzero(as_tuple=False).squeeze()
            if self.verbose and ind_succ.numel() != 0:
                self._log_progress(it, ind_succ, n_ex_total, n_queries, loss_min,
                                   x_new, x_curr, ['- epsit={:.0f}'.format(eps_it)])
            if ind_succ.numel() == n_ex_total:
                break

        return n_queries, x_best

    def _attack_patches(self, x, y):
        """Image-specific square patches of area about ``eps`` (assumes square images)."""
        c, h, w = x.shape[1:]
        n_ex_total = x.shape[0]
        s = int(math.ceil(self.eps ** .5))
        x_best = x.clone()
        x_new = x.clone()
        # top-left corners; valid range is [0, h - s], as in the clamp below
        loc = torch.randint(h - s + 1, size=[n_ex_total, 2])
        patches_coll = torch.zeros([n_ex_total, c, s, s]).to(self.device)
        assert abs(self.update_loc_period) > 1
        loc_t = abs(self.update_loc_period)

        # set when to start single channel updates
        it_start_cu = self._single_channel_start(
            lambda it: int(max(self.p_selection(it) ** .5 * s, 1)))

        # initialize patches
        if self.verbose:
            self.logger.log('using {} initialization'.format(self.init_patches))
        for counter in range(n_ex_total):
            patches_coll[counter] += self.get_init_patch(c, s)[0].clamp(0., 1.)
            x_new[counter, :, loc[counter, 0]:loc[counter, 0] + s,
                  loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone()

        margin_min, loss_min = self.margin_and_loss(x_new, y)
        n_queries = torch.ones(n_ex_total).to(self.device)

        for it in range(1, self.n_queries):
            # check points still to fool
            idx_to_fool = (margin_min > -1e-6).nonzero(as_tuple=False).squeeze()
            x_curr = self.check_shape(x[idx_to_fool])
            patches_curr = self.check_shape(patches_coll[idx_to_fool])
            y_curr = y[idx_to_fool]
            loss_min_curr = loss_min[idx_to_fool]
            loc_curr = loc[idx_to_fool]
            if len(y_curr.shape) == 0:
                y_curr.unsqueeze_(0)
                loss_min_curr.unsqueeze_(0)
                loc_curr.unsqueeze_(0)
                idx_to_fool.unsqueeze_(0)

            # sample update
            s_it = int(max(self.p_selection(it) ** .5 * s, 1))
            p_it = torch.randint(s - s_it + 1, size=[2])
            rows = slice(p_it[0], p_it[0] + s_it)
            cols = slice(p_it[1], p_it[1] + s_it)
            sh_it = int(max(self.sh_selection(it) * h, 0))
            patches_new = patches_curr.clone()
            x_new = x_curr.clone()
            loc_new = loc_curr.clone()
            # every loc_t-th query moves the patches instead of changing them
            update_loc = int((it % loc_t == 0) and (sh_it > 0))
            update_patch = 1. - update_loc
            if self.update_loc_period < 0 and sh_it > 0:
                update_loc = 1. - update_loc
                update_patch = 1. - update_patch

            for counter in range(x_curr.shape[0]):
                if update_patch == 1.:
                    if it < it_start_cu:
                        if s_it > 1:
                            patches_new[counter, :, rows, cols] += self.random_choice([c, 1, 1])
                        else:
                            # make sure to sample a different color
                            old_clr = patches_new[counter, :, rows, cols].clone()
                            new_clr = old_clr.clone()
                            while (new_clr == old_clr).all().item():
                                new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
                            patches_new[counter, :, rows, cols] = new_clr.clone()
                    else:
                        assert s_it == 1
                        # single channel update: flip one random channel
                        new_ch = self.random_int(low=0, high=c, shape=[1])
                        patches_new[counter, new_ch, rows, cols] = 1. - patches_new[
                            counter, new_ch, rows, cols]

                    patches_new[counter].clamp_(0., 1.)
                if update_loc == 1:
                    # update location
                    loc_new[counter] += torch.randint(low=-sh_it, high=sh_it + 1, size=[2])
                    loc_new[counter].clamp_(0, h - s)

                x_new[counter, :, loc_new[counter, 0]:loc_new[counter, 0] + s,
                      loc_new[counter, 1]:loc_new[counter, 1] + s] = patches_new[counter].clone()

            # check loss of new candidate
            margin, loss = self.margin_and_loss(x_new, y_curr)
            n_queries[idx_to_fool] += 1

            # update best solution
            idx_improved, nimpr = self._update_best(
                margin, loss, idx_to_fool, margin_min, loss_min, loss_min_curr)
            if nimpr > 0.:
                patches_coll[idx_to_fool[idx_improved]] = patches_new[idx_improved].clone()
                loc[idx_to_fool[idx_improved]] = loc_new[idx_improved].clone()

            # log results current iteration
            ind_succ = (margin_min <= 0.).nonzero(as_tuple=False).squeeze()
            if self.verbose and ind_succ.numel() != 0:
                self._log_progress(it, ind_succ, n_ex_total, n_queries, loss_min,
                                   x_new, x_curr,
                                   ['{}'.format(' - loc' if update_loc == 1. else '')])
            if ind_succ.numel() == n_ex_total:
                break

        # apply patches
        for counter in range(n_ex_total):
            x_best[counter, :, loc[counter, 0]:loc[counter, 0] + s,
                   loc[counter, 1]:loc[counter, 1] + s] = patches_coll[counter].clone()

        return n_queries, x_best

    def _attack_patches_universal(self, x, y):
        """Targeted universal patch shared by all images; images and locations
        are resampled every ``resample_loc`` queries (assumes square images)."""
        c, h, w = x.shape[1:]
        n_ex_total = x.shape[0]
        s = int(math.ceil(self.eps ** .5))
        x_best = x.clone()
        self.n_imgs = n_ex_total
        loc = torch.randint(h - s + 1, size=[n_ex_total, 2])
        # per-example queries are not tracked in universal mode
        n_queries = torch.ones_like(y).float()

        # set when to start single channel updates
        it_start_cu = self._single_channel_start(
            lambda it: int(max(self.p_selection(it) ** .5 * s, 1)))

        # initialize patch
        if self.verbose:
            self.logger.log('using {} initialization'.format(self.init_patches))
        patch_univ = self.get_init_patch(c, s)

        loss_batch = float(1e10)
        n_succs = 0

        # init update batch
        n_newimgs = self.n_imgs
        newimgstoadd = self._load_resampling_images(n_newimgs)
        counter_resamplingimgs = 0

        for it in range(self.n_queries):
            # sample size and location of the update
            s_it = int(max(self.p_selection(it) ** .5 * s, 1))
            p_it = torch.randint(s - s_it + 1, size=[2])
            rows = slice(p_it[0], p_it[0] + s_it)
            cols = slice(p_it[1], p_it[1] + s_it)

            patch_new = patch_univ.clone()
            if s_it > 1:
                patch_new[0, :, rows, cols] += self.random_choice([c, 1, 1])
            else:
                old_clr = patch_new[0, :, rows, cols].clone()
                new_clr = old_clr.clone()
                if it < it_start_cu:
                    # make sure to sample a different color
                    while (new_clr == old_clr).all().item():
                        new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
                else:
                    # single channel update
                    new_ch = self.random_int(low=0, high=c, shape=[1])
                    new_clr[new_ch] = 1. - new_clr[new_ch]
                patch_new[0, :, rows, cols] = new_clr.clone()
            patch_new.clamp_(0., 1.)

            # compute loss for new candidate
            x_new = x.clone()
            for counter in range(x.shape[0]):
                loc_new = loc[counter]
                x_new[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = patch_new[0]

            margin_run, loss_run = self.margin_and_loss(x_new, y)
            if self.loss == 'ce':
                loss_run += x_new.shape[0]
            loss_new = loss_run.sum()
            n_succs_new = (margin_run < -1e-6).sum().item()

            # accept candidate if loss improves
            if loss_new < loss_batch:
                loss_batch = loss_new + 0.
                patch_univ = patch_new.clone()
                n_succs = n_succs_new

            # sample new locations and images
            if (it + 1) % self.resample_loc == 0:
                x = self._next_image_batch(x, newimgstoadd, counter_resamplingimgs, n_newimgs)
                loc = torch.randint(h - s + 1, size=[self.n_imgs, 2])
                # reset the reference loss for the new batch
                loss_batch = loss_batch * 0. + 1e6
                counter_resamplingimgs += 1

            # logging current iteration
            if self.verbose:
                self._log_universal(it, n_succs, n_ex_total, loss_batch, x_new, x)

        # apply patches on the initial images
        for counter in range(n_ex_total):
            loc_new = loc[counter]
            x_best[counter, :, loc_new[0]:loc_new[0] + s, loc_new[1]:loc_new[1] + s] = patch_univ[0]

        return n_queries, x_best

    def _attack_frames(self, x, y):
        """Image-specific adversarial frames of width ``eps``."""
        c, h, w = x.shape[1:]
        n_ex_total = x.shape[0]
        mask, ind = self._frame_mask(x)
        eps = ind.shape[0]  # number of pixels in the frame
        x_best = x.clone()
        mask = mask.view(1, 1, h, w).to(self.device)
        mask_frame = torch.ones([1, c, h, w], device=x.device) * mask

        # set when starting single channel updates
        it_start_cu = self._single_channel_start(
            lambda it: int(max(self.p_selection(it), 1)))

        # initialize frames
        x_best[:, :, ind[:, 0], ind[:, 1]] = self.random_choice(
            [n_ex_total, c, eps]).clamp(0., 1.)

        margin_min, loss_min = self.margin_and_loss(x_best, y)
        n_queries = torch.ones(n_ex_total).to(self.device)

        for it in range(1, self.n_queries):
            # check points still to fool
            idx_to_fool = (margin_min > -1e-6).nonzero(as_tuple=False).squeeze()
            x_curr = self.check_shape(x[idx_to_fool])
            x_best_curr = self.check_shape(x_best[idx_to_fool])
            y_curr = y[idx_to_fool]
            loss_min_curr = loss_min[idx_to_fool]
            if len(y_curr.shape) == 0:
                y_curr.unsqueeze_(0)
                loss_min_curr.unsqueeze_(0)
                idx_to_fool.unsqueeze_(0)

            # sample update: s_it x s_it square grown from a random frame pixel
            s_it = max(int(self.p_selection(it)), 1)
            ind_it = torch.randperm(eps)[0]

            x_new = x_best_curr.clone()
            if s_it > 1:
                dir_h = self.random_choice([1]).long().cpu()
                dir_w = self.random_choice([1]).long().cpu()
                new_clr = self.random_choice([c, 1]).clamp(0., 1.)

            for counter in range(x_curr.shape[0]):
                if s_it > 1:
                    for counter_h in range(s_it):
                        for counter_w in range(s_it):
                            x_new[counter, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1),
                                  (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone()
                else:
                    p_it = ind[ind_it].clone()
                    rows = slice(p_it[0], p_it[0] + s_it)
                    cols = slice(p_it[1], p_it[1] + s_it)
                    old_clr = x_new[counter, :, rows, cols].clone()
                    new_clr = old_clr.clone()
                    if it < it_start_cu:
                        while (new_clr == old_clr).all().item():
                            new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
                    else:
                        # single channel update
                        new_ch = self.random_int(low=0, high=c, shape=[1])
                        new_clr[new_ch] = 1. - new_clr[new_ch]
                    x_new[counter, :, rows, cols] = new_clr.clone()

            x_new.clamp_(0., 1.)
            # keep changes only inside the frame
            x_new = (x_new - x_curr) * mask_frame + x_curr

            # check loss of new candidate
            margin, loss = self.margin_and_loss(x_new, y_curr)
            n_queries[idx_to_fool] += 1

            # update best solution
            idx_improved, nimpr = self._update_best(
                margin, loss, idx_to_fool, margin_min, loss_min, loss_min_curr)
            if nimpr > 0.:
                x_best[idx_to_fool[idx_improved]] = x_new[idx_improved].clone()

            # log results current iteration
            ind_succ = (margin_min <= 0.).nonzero(as_tuple=False).squeeze()
            if self.verbose and ind_succ.numel() != 0:
                self._log_progress(it, ind_succ, n_ex_total, n_queries, loss_min,
                                   x_new, x_curr)
            if ind_succ.numel() == n_ex_total:
                break

        return n_queries, x_best

    def _attack_frames_universal(self, x, y):
        """Targeted universal frame of width ``eps`` shared by all images;
        images are resampled every ``resample_loc`` queries."""
        c, h, w = x.shape[1:]
        n_ex_total = x.shape[0]
        mask, ind = self._frame_mask(x)
        eps = ind.shape[0]  # number of pixels in the frame
        x_best = x.clone()
        x_new = x.clone()
        mask = mask.view(1, 1, h, w).to(self.device)
        # mask_frame doubles as a canvas holding the current frame colors
        mask_frame = torch.ones([1, c, h, w], device=x.device) * mask
        frame_univ = self.random_choice([1, c, eps]).clamp(0., 1.)

        # set when to start single channel updates
        it_start_cu = self._single_channel_start(
            lambda it: int(max(self.p_selection(it) * self.eps, 1)))

        self.n_imgs = n_ex_total
        loss_batch = float(1e10)
        n_succs = 0
        # per-example queries are not tracked in universal mode
        n_queries = torch.ones_like(y).float()

        # init update batch
        n_newimgs = self.n_imgs
        newimgstoadd = self._load_resampling_images(n_newimgs)
        counter_resamplingimgs = 0

        for it in range(self.n_queries):
            # sample update
            s_it = max(int(self.p_selection(it) * self.eps), 1)
            ind_it = torch.randperm(eps)[0]

            mask_frame[:, :, ind[:, 0], ind[:, 1]] = frame_univ

            if s_it > 1:
                dir_h = self.random_choice([1]).long().cpu()
                dir_w = self.random_choice([1]).long().cpu()
                new_clr = self.random_choice([c, 1]).clamp(0., 1.)
                for counter_h in range(s_it):
                    for counter_w in range(s_it):
                        mask_frame[0, :, (ind[ind_it, 0] + dir_h * counter_h).clamp(0, h - 1),
                                   (ind[ind_it, 1] + dir_w * counter_w).clamp(0, w - 1)] = new_clr.clone()
            else:
                p_it = ind[ind_it]
                rows = slice(p_it[0], p_it[0] + s_it)
                cols = slice(p_it[1], p_it[1] + s_it)
                old_clr = mask_frame[0, :, rows, cols].clone()
                new_clr = old_clr.clone()
                if it < it_start_cu:
                    while (new_clr == old_clr).all().item():
                        new_clr = self.random_choice([c, 1, 1]).clone().clamp(0., 1.)
                else:
                    # single channel update
                    new_ch = self.random_int(low=0, high=c, shape=[1])
                    new_clr[new_ch] = 1. - new_clr[new_ch]
                mask_frame[0, :, rows, cols] = new_clr.clone()

            frame_new = mask_frame[:, :, ind[:, 0], ind[:, 1]].clone()
            frame_new.clamp_(0., 1.)
            if len(frame_new.shape) == 2:
                frame_new.unsqueeze_(0)

            x_new[:, :, ind[:, 0], ind[:, 1]] = frame_new

            margin_run, loss_run = self.margin_and_loss(x_new, y)
            if self.loss == 'ce':
                loss_run += x_new.shape[0]
            loss_new = loss_run.sum()
            n_succs_new = (margin_run < -1e-6).sum().item()

            # accept candidate if loss improves
            if loss_new < loss_batch:
                loss_batch = loss_new + 0.
                frame_univ = frame_new.clone()
                n_succs = n_succs_new

            # sample new images
            if (it + 1) % self.resample_loc == 0:
                x = self._next_image_batch(x, newimgstoadd, counter_resamplingimgs, n_newimgs)
                # reset the reference loss for the new batch
                loss_batch = loss_batch * 0. + 1e6
                x_new = x.clone()
                counter_resamplingimgs += 1

            # logging current iteration
            if self.verbose:
                self._log_universal(it, n_succs, n_ex_total, loss_batch, x_new, x)

        # apply frame on initial images
        x_best[:, :, ind[:, 0], ind[:, 1]] = frame_univ

        return n_queries, x_best

    def perturb(self, x, y=None):
        """
        Run the attack with ``n_restarts`` restarts on the points not yet fooled.

        :param x: clean images
        :param y: untargeted attack -> clean labels,
                  if None we use the predicted labels
                  targeted attack -> target labels, if None random classes,
                  different from the predicted ones, are sampled
        :return: ``(qr, adv)``, the queries used per example (0 where the
                 attack did not succeed) and the adversarial images
        """
        self.init_hyperparam(x)
        adv = x.clone()
        qr = torch.zeros([x.shape[0]]).to(self.device)
        with torch.no_grad():
            if y is None:
                output = self.predict(x)
                y_pred = output.max(1)[1]
                if not self.targeted:
                    y = y_pred.detach().clone().long().to(self.device)
                else:
                    y = self.random_target_classes(y_pred, output.shape[-1])
            else:
                y = y.detach().clone().long().to(self.device)

            # acc marks the points on which the attack has not succeeded yet
            if not self.targeted:
                acc = self.predict(x).max(1)[1] == y
            else:
                acc = self.predict(x).max(1)[1] != y

        torch.random.manual_seed(self.seed)
        torch.cuda.random.manual_seed(self.seed)
        np.random.seed(self.seed)

        for _ in range(self.n_restarts):
            ind_to_fool = acc.nonzero(as_tuple=False).squeeze()
            if len(ind_to_fool.shape) == 0:
                ind_to_fool = ind_to_fool.unsqueeze(0)
            if ind_to_fool.numel() != 0:
                x_to_fool = x[ind_to_fool].clone()
                y_to_fool = y[ind_to_fool].clone()

                qr_curr, adv_curr = self.attack_single_run(x_to_fool, y_to_fool)

                with torch.no_grad():
                    output_curr = self.predict(adv_curr)
                if not self.targeted:
                    acc_curr = output_curr.max(1)[1] == y_to_fool
                else:
                    acc_curr = output_curr.max(1)[1] != y_to_fool
                ind_curr = (acc_curr == 0).nonzero(as_tuple=False).squeeze()
                acc[ind_to_fool[ind_curr]] = 0
                adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()
                qr[ind_to_fool[ind_curr]] = qr_curr[ind_curr].clone()

        return qr, adv
|