# train.py
  1. import torch
  2. # PATHING
  3. import os
  4. import sys
  5. root = os.path.abspath(os.curdir)
  6. sys.path.append(root)
  7. # local imports
  8. from utils.conv.adv import adv_trainer
  9. def main():
  10. # check cuda
  11. if torch.cuda.is_available():
  12. device = torch.device('cuda:0')
  13. else:
  14. device = torch.device('cpu')
  15. k = 10 # if k=0 the network will use the regular VGG net
  16. perturb = 10
  17. save_dir = '/new_trained/conv/test/'
  18. os.mkdir(root+save_dir)
  19. trainer = adv_trainer(
  20. root = root,
  21. k = k,
  22. perturb = perturb,
  23. beta = 100,
  24. seed = 104,
  25. save_dir = save_dir,
  26. bs=128,
  27. num_iters=10,
  28. num_queries=300,
  29. num_epochs=25,
  30. device=device)
  31. trainer.run()
  32. if __name__ == '__main__':
  33. main()