acc.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # torch imports
  2. import torch
  3. # other modules
  4. import argparse
  5. # PATHING
  6. import os
  7. import sys
  8. root = os.path.abspath(os.curdir)
  9. sys.path.append(root)
  10. # internal imports
  11. from utils.lin.sparse_rs import RSAttack
  12. from utils.lin.models import *
  13. from utils.helpers import *
  14. def main():
  15. # Parse the only input, path
  16. parser = argparse.ArgumentParser(description='Calculates the clean accuracy of network using sparse rs')
  17. parser.add_argument("exp_path", help="pass the RELATIVE path of the PARENT directory of your network (net.pth)")
  18. args = parser.parse_args()
  19. exp_path = root + '/' + args.exp_path + '/'
  20. # since we only tested with k=10, we can simply check the name of dir (rob or og)
  21. if 'og' in exp_path.split('/'):
  22. k = 0
  23. else:
  24. k = 10
  25. # check cuda
  26. if torch.cuda.is_available():
  27. device = torch.device('cuda:0')
  28. else:
  29. device = torch.device('cpu')
  30. # load data and network
  31. Data = prep_MNIST(root, bs=256)
  32. net_path = exp_path + 'net.pth'
  33. if k == 0:
  34. net = L_Net().to(device)
  35. else:
  36. net = r_L_Net(k).to(device)
  37. net.load_state_dict(torch.load(net_path, map_location=device))
  38. net.eval()
  39. # test the accuracy
  40. correct = 0
  41. total = 0
  42. with torch.no_grad():
  43. for inputs, labels in zip(Data['x_test'],Data['y_test']):
  44. inputs = inputs.to(device)
  45. labels = labels.to(device)
  46. outputs = net(inputs)
  47. _, predicted = torch.max(outputs.data, 1)
  48. total += labels.size(0)
  49. correct += (predicted == labels).sum().item()
  50. acc = 100 * (correct / total)
  51. print("Clean Accuracy: ",acc,'%')
  52. if __name__ == '__main__':
  53. main()