# pointwise.py -- runs the foolbox Pointwise (L0) attack against a trained network.
  1. # torch imports
  2. import torch
  3. # other modules
  4. import argparse
  5. import numpy as np
  6. import foolbox
  7. from tqdm import trange
  8. # PATHING
  9. import os
  10. import sys
  11. root = os.path.abspath(os.curdir)
  12. sys.path.append(root)
  13. # internal imports
  14. from utils.conv.sparse_rs import RSAttack
  15. from utils.conv.models import *
  16. from utils.helpers import *
  17. def main():
  18. #---------------------------------------------------------------------------#
  19. # Parse the only input, path
  20. parser = argparse.ArgumentParser(description='Runs the pointwise attack on network for both beta=100 and beta=1')
  21. parser.add_argument("exp_path", help="pass the RELATIVE path of the PARENT directory of your network (net.pth)")
  22. # optinal arguements
  23. parser.add_argument("bs", nargs="?", type=int, default=64, help="batch size fed into foolbox")
  24. parser.add_argument("num_batches", nargs="?", type=int, default=16, help="batch size fed into foolbox")
  25. parser.add_argument("num_iters", nargs="?", type=int, default=10, help="times to repeat the attack")
  26. args = parser.parse_args()
  27. exp_path = root + '/' + args.exp_path + '/'
  28. bs = args.bs
  29. num_batches = args.num_batches
  30. num_iters = args.num_iters
  31. # since we only tested with k=10, we can simply check the name of dir (rob or og)
  32. if 'og' in exp_path.split('/'):
  33. k = 0
  34. else:
  35. k = 10
  36. #---------------------------------------------------------------------------#
  37. # check cuda
  38. if torch.cuda.is_available():
  39. device = torch.device('cuda:0')
  40. else:
  41. device = torch.device('cpu')
  42. # get data loaders
  43. transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])
  44. testset = datasets.CIFAR10(root=root+'/datasets/CIFAR/',train = False,download = False, transform=transform)
  45. test_loader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False, num_workers = 2)
  46. dataiter = iter(test_loader)
  47. images, labels = dataiter.next()
  48. images = images.numpy()
  49. # prep the network
  50. net_path = exp_path + 'net.pth'
  51. if k == 0:
  52. net = VGG().to(device)
  53. else:
  54. net = rob_VGG(k).to(device)
  55. net.load_state_dict(torch.load(net_path, map_location=device))
  56. net.eval()
  57. # RUN the attack
  58. for j in range(2):
  59. if j%2 == 1:
  60. bounds = (-260,260) # aprox beta=100
  61. else:
  62. bounds = (-2.429065704345703,2.7537312507629395)
  63. # SETUP #
  64. fmodel = foolbox.models.PyTorchModel(net, bounds=bounds, num_classes=10, channel_axis=1, device=device)
  65. attack = foolbox.attacks.PointwiseAttack(model=fmodel,distance=foolbox.distances.L0)
  66. Data = {}
  67. Data['final_l0'] = []
  68. # ITERATE #
  69. print("-----"*8)
  70. print("running attack with bounds: ",bounds)
  71. for batch in trange(num_batches):
  72. images, labels = dataiter.next()
  73. images = images.numpy()
  74. # final l0 distances of entire batch
  75. final_l0 = np.ones(bs)*10000
  76. # best attacked images
  77. final_images = images.copy()
  78. for _ in range(num_iters):
  79. # run the attack saving the best adversarial images
  80. adv_images = attack(images.copy(), labels.numpy().copy())
  81. out = net(torch.tensor(adv_images).to(device))
  82. _, pred_adv = torch.max(out.data, 1)
  83. # calculate L_0 distances for these new images
  84. perturbed = np.zeros((bs,32,32),dtype=bool)
  85. for z in range(bs):
  86. for ch in range(3):
  87. # adds True if pixel at this channel was perturbed
  88. perturbed[z] += abs(images[z,ch,:]-adv_images[z,ch,:]>0.0001)
  89. # this gives list of total perturbed pixels for each image
  90. total_l0 = perturbed.sum(axis=1).sum(axis=1)
  91. # we only save on three conditions
  92. cond_1 = total_l0>0 # there was an actual attack
  93. cond_2 = total_l0 < final_l0 # the attack was better than the best one
  94. cond_3 = (pred_adv != labels.to(device)).to('cpu').numpy() # the attack was succesful
  95. # save both the images and the best L_0 distances for this iteration
  96. improved = cond_1*cond_2*cond_3
  97. final_images[improved] = adv_images[improved]
  98. final_l0[improved] = np.minimum(total_l0[improved],final_l0[improved])
  99. # save the best l_0 distances for this batch
  100. Data['final_l0'].append(final_l0)
  101. Data['final_l0'] = np.asarray(Data['final_l0'])
  102. print('Final median l0 distance: ',np.median(Data['final_l0']))
  103. if __name__ == '__main__':
  104. main()