# pointwise.py (4.5 KB)
  1. # torch imports
  2. import torch
  3. # other modules
  4. import argparse
  5. import numpy as np
  6. import foolbox
  7. from tqdm import trange
  8. # PATHING
  9. import os
  10. import sys
  11. root = os.path.abspath(os.curdir)
  12. sys.path.append(root)
  13. # internal imports
  14. from utils.lin.sparse_rs import RSAttack
  15. from utils.lin.models import *
  16. from utils.helpers import *
  17. def main():
  18. #---------------------------------------------------------------------------#
  19. # Parse the only input, path
  20. parser = argparse.ArgumentParser(description='Runs the pointwise attack on network for both beta=100 and beta=1')
  21. parser.add_argument("exp_path", help="pass the RELATIVE path of the PARENT directory of your network (net.pth)")
  22. # optinal arguements
  23. parser.add_argument("bs", nargs="?", type=int, default=64, help="batch size fed into foolbox")
  24. parser.add_argument("num_batches", nargs="?", type=int, default=16, help="batch size fed into foolbox")
  25. parser.add_argument("num_iters", nargs="?", type=int, default=10, help="times to repeat the attack")
  26. args = parser.parse_args()
  27. exp_path = root + '/' + args.exp_path + '/'
  28. bs = args.bs
  29. num_batches = args.num_batches
  30. num_iters = args.num_iters
  31. # since we only tested with k=10, we can simply check the name of dir (rob or og)
  32. if 'og' in exp_path.split('/'):
  33. k = 0
  34. else:
  35. k = 10
  36. #---------------------------------------------------------------------------#
  37. # check cuda
  38. if torch.cuda.is_available():
  39. device = torch.device('cuda:0')
  40. else:
  41. device = torch.device('cpu')
  42. # get data loaders
  43. transform = transforms.Compose(
  44. [transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)),flatten()])
  45. testset = datasets.MNIST(root=root+'/datasets/',train = False,download = False, transform=transform)
  46. test_loader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=False, num_workers = 2)
  47. dataiter = iter(test_loader)
  48. images, labels = dataiter.next()
  49. images = images.numpy()
  50. # prep the network
  51. net_path = exp_path + 'net.pth'
  52. if k == 0:
  53. net = L_Net().to(device)
  54. else:
  55. net = r_L_Net(k).to(device)
  56. net.load_state_dict(torch.load(net_path, map_location=device))
  57. net.eval()
  58. # RUN the attack
  59. for j in range(2):
  60. if j%2 == 1:
  61. bounds = (-160,160) # aprox beta=100
  62. else:
  63. bounds = (images.min(),images.max())
  64. # SETUP #
  65. fmodel = foolbox.models.PyTorchModel(net, bounds=bounds, num_classes=10, channel_axis=1, device=device)
  66. attack = foolbox.attacks.PointwiseAttack(model=fmodel,distance=foolbox.distances.L0)
  67. Data = {}
  68. Data['final_l0'] = []
  69. # ITERATE #
  70. print("-----"*8)
  71. print("running attack with bounds: ",bounds)
  72. for batch in trange(num_batches):
  73. images, labels = dataiter.next()
  74. images = images.numpy()
  75. # final l0 distances of entire batch
  76. final_l0 = np.ones(bs)*10000
  77. # best attacked images
  78. final_images = images.copy()
  79. for _ in range(num_iters):
  80. # run the attack saving the best adversarial images
  81. adv_images = attack(images.copy(), labels.numpy())
  82. out = net(torch.tensor(adv_images).to(device))
  83. _, pred_adv = torch.max(out.data, 1)
  84. # calculate L_0 distances for these new images
  85. perturbed = np.zeros((bs,28*28),dtype=bool)
  86. for z in range(bs):
  87. perturbed[z] += abs(images[z,:]-adv_images[z,:]>0.0001)
  88. # this gives list of total perturbed pixels for each image
  89. total_l0 = perturbed.sum(axis=1)
  90. # we only save on three conditions
  91. cond_1 = total_l0>0 # there was an actual attack
  92. cond_2 = total_l0 < final_l0 # the attack was better than the best one
  93. cond_3 = (pred_adv != labels.to(device)).to('cpu').numpy() # the attack was succesful
  94. # save both the images and the best L_0 distances for this iteration
  95. improved = cond_1*cond_2*cond_3
  96. final_images[improved] = adv_images[improved]
  97. final_l0[improved] = np.minimum(total_l0[improved],final_l0[improved])
  98. # save the best l_0 distances for this batch
  99. Data['final_l0'].append(final_l0)
  100. Data['final_l0'] = np.asarray(Data['final_l0'])
  101. print('Final median l0 distance: ',np.median(Data['final_l0']))
  102. if __name__ == '__main__':
  103. main()