1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
# standard library
import argparse
import os
import sys

# third-party
import numpy as np
import torch

# PATHING: the current working directory is put on sys.path so that the
# project-local imports below resolve.
root = os.path.abspath(os.curdir)
sys.path.append(root)

# internal imports
from utils.lin.sparse_rs import RSAttack
from utils.lin.models import *
from utils.helpers import *
def _parse_args():
    """Parse the command line: one required path plus four optional positionals."""
    parser = argparse.ArgumentParser(description='Calculates robust accuracy of network')
    parser.add_argument("exp_path", help="pass the RELATIVE path of the PARENT directory of your network (net.pth)")
    # optional arguments (positional, so they must be supplied in this order)
    parser.add_argument("num_queries", nargs="?", type=int, default=300, help="num queries to run the attack for")
    parser.add_argument("beta", nargs="?", type=int, default=100, help="beta value (domain scaling)")
    parser.add_argument("perturb", nargs="?", type=int, default=10, help="adversary budget")
    parser.add_argument("num_batches", nargs="?", type=int, default=39, help="how many batches to evaluate from the MNIST testset (batchsize is 256)")
    return parser.parse_args()


def main():
    """Report the robust accuracy of a saved network under the L0 ``RSAttack``.

    Command-line arguments (all positional; everything after ``exp_path`` is
    optional and falls back to its default):
        exp_path: directory, relative to the current working directory, that
            contains ``net.pth``.
        num_queries: ``n_queries`` handed to the attack (default 300).
        beta: value passed to ``mu_sigma`` (default 100).
        perturb: ``eps`` handed to the attack (default 10).
        num_batches: number of test batches to evaluate (default 39).

    For every evaluated batch the attack is run only on the points the network
    already classifies correctly. The fraction of those points still classified
    correctly after the attack is logged per batch, and the unweighted mean of
    the per-batch percentages is printed at the end. Batches without a single
    correctly classified point are skipped, because their fraction is undefined
    and would otherwise turn the final mean into NaN.
    """
    args = _parse_args()
    exp_dir = os.path.join(root, args.exp_path)

    # Only k=10 was tested, so the directory name ('og' vs. anything else)
    # is enough to tell which architecture the checkpoint belongs to.
    path_parts = os.path.normpath(exp_dir).split(os.sep)
    k = 0 if 'og' in path_parts else 10

    # check cuda
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # load data and network
    Data = prep_MNIST(root, bs=256)
    mu, sigma = mu_sigma(args.beta, CIFAR=False)
    if k == 0:
        eval_net = L_Net_eval(mu, sigma).to(device)
    else:
        eval_net = r_L_Net_eval(k, mu, sigma).to(device)
    net_path = os.path.join(exp_dir, 'net.pth')
    eval_net.load_state_dict(torch.load(net_path, map_location=device))
    eval_net.eval()

    # Set up the attack; the query budget and eps come straight from the CLI.
    adversary = RSAttack(eval_net,
                         norm='L0',
                         eps=args.perturb,
                         n_queries=args.num_queries,
                         n_restarts=1,
                         seed=12345,
                         device=device,
                         log_path=os.path.join(exp_dir, 'log_temp.txt'),
                         )

    # Per-batch robust accuracy (in percent) over the requested batches.
    all_acc = []
    for i in range(args.num_batches):
        x = (Data['x_test'][i].to(device) - mu) / sigma
        y = Data['y_test'][i].to(device)
        with torch.no_grad():
            # find points originally correctly classified
            correct = eval_net(x).max(1)[1] == y
            # squeeze(1) rather than squeeze(): with exactly one correct point a
            # bare squeeze() yields a 0-dim index and x[ind] loses its batch dim.
            ind_to_fool = correct.nonzero(as_tuple=False).squeeze(1)
            if ind_to_fool.numel() == 0:
                # nothing to attack in this batch; its accuracy is undefined
                continue
            # perform the attack on the corresponding indices
            _, adv = adversary.perturb(x[ind_to_fool], y[ind_to_fool])
            # analyze the attack
            output = eval_net(adv.to(device))
            r_acc = (output.max(1)[1] == y[ind_to_fool]).float()
        batch_acc = r_acc.mean()
        adversary.logger.log('robust accuracy {:.2%}'.format(batch_acc))
        all_acc.append(batch_acc.cpu().numpy() * 100)

    if not all_acc:
        print("Robust Accuracy: n/a (no correctly classified points were evaluated)")
        return
    print("Robust Accuracy: ", np.mean(np.asarray(all_acc)), '%')


if __name__ == '__main__':
    main()
|