# rs.py
  1. # torch imports
  2. import torch
  3. # other modules
  4. import argparse
  5. # PATHING
  6. import os
  7. import sys
  8. root = os.path.abspath(os.curdir)
  9. sys.path.append(root)
  10. # internal imports
  11. from utils.conv.sparse_rs import RSAttack
  12. from utils.conv.models import *
  13. from utils.helpers import *
  14. def main():
  15. # Parse the only input path
  16. parser = argparse.ArgumentParser(description='Calculates robust accuracy of network')
  17. parser.add_argument("exp_path", help="pass the RELATIVE path of the PARENT directory of your network (net.pth)")
  18. # optional arguement
  19. parser.add_argument("num_queries", nargs="?", type=int, default=300, help="num queries to run the attack for")
  20. parser.add_argument("beta", nargs="?", type=int, default=100, help="beta value (domain scaling)")
  21. parser.add_argument("perturb", nargs="?", type=int, default=10, help="adversary budget")
  22. parser.add_argument("num_batches", nargs="?", type=int, default=39, help="how many batches to evaluate from the MNIST testset (batchsize is 256)")
  23. args = parser.parse_args()
  24. exp_path = root + '/' + args.exp_path + '/'
  25. num_queries = args.num_queries
  26. num_batches = args.num_batches
  27. perturb = args.perturb
  28. beta = args.beta
  29. # since we only tested with k=10, we can simply check the name of dir (rob or og)
  30. if 'og' in exp_path.split('/'):
  31. k = 0
  32. else:
  33. k = 10
  34. # check cuda
  35. if torch.cuda.is_available():
  36. device = torch.device('cuda:0')
  37. else:
  38. device = torch.device('cpu')
  39. # load data and network
  40. Data = prep_CIFAR(root, bs=256)
  41. mu, sigma = mu_sigma(beta, CIFAR=True)
  42. net_path = exp_path + 'net.pth'
  43. if k == 0:
  44. eval_net = VGG_eval(mu,sigma).to(device)
  45. else:
  46. eval_net = rob_VGG_eval(k,mu,sigma).to(device)
  47. eval_net.load_state_dict(torch.load(net_path, map_location=device))
  48. eval_net.eval()
  49. # RUN THE ATTACK
  50. # We use the double the queries as we wanta more accurate r_test value (one batch)
  51. adversary = RSAttack(eval_net,
  52. norm='L0',
  53. eps=perturb,
  54. n_queries=num_queries,
  55. n_restarts=1,
  56. seed=12345,
  57. device=device,
  58. log_path=exp_path+'log_temp.txt'
  59. )
  60. # First compute the % robust accuracy on test set for only one batch
  61. all_acc = []
  62. # for i in trange(len(Data['x_test'])):
  63. # 39 full batches in the dataset
  64. for i in range(num_batches):
  65. x = (Data['x_test'][i].to(device)-mu)/sigma
  66. y = Data['y_test'][i].to(device)
  67. with torch.no_grad():
  68. # find points originally correctly classified
  69. output = eval_net(x)
  70. pred = (output.max(1)[1] == y).float().to(device)
  71. ind_to_fool = (pred == 1).nonzero(as_tuple=False).squeeze()
  72. # preform the attack on corresponding indeces and save
  73. _, adv = adversary.perturb(x[ind_to_fool], y[ind_to_fool])
  74. # analyze the attack
  75. output = eval_net(adv.to(device))
  76. r_acc = (output.max(1)[1] == y[ind_to_fool]).float().to(device)
  77. adversary.logger.log('robust accuracy {:.2%}'.format(r_acc.float().mean()))
  78. all_acc.append(r_acc.float().mean().cpu().numpy()*100)
  79. all_acc = np.asarray(all_acc)
  80. print("Robust Accuracy: ",np.mean(all_acc),'%')
  81. if __name__ == '__main__':
  82. main()