main.py 2.9 KB

import sys
# make the local trainer modules importable
sys.path.append('./trainer')
import argparse
import nutszebra_cifar10
import binary_tree_wide_resnet
import nutszebra_data_augmentation
import nutszebra_optimizer

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='cifar10')
    parser.add_argument('--load_model', '-m',
                        default=None,
                        help='trained model')
    parser.add_argument('--load_optimizer', '-o',
                        default=None,
                        help='optimizer for trained model')
    parser.add_argument('--load_log', '-l',
                        default=None,
                        help='log of trained model')
    parser.add_argument('--save_path', '-p',
                        default='./',
                        help='model and optimizer will be saved every epoch')
    parser.add_argument('--epoch', '-e', type=int,
                        default=200,
                        help='maximum epoch')
    parser.add_argument('--batch', '-b', type=int,
                        default=128,
                        help='mini batch size')
    parser.add_argument('--gpu', '-g', type=int,
                        default=-1,
                        help='-1 means cpu mode, put gpu id here')
    parser.add_argument('--start_epoch', '-s', type=int,
                        default=1,
                        help='start from this epoch')
    parser.add_argument('--train_batch_divide', '-trb', type=int,
                        default=1,
                        help='divide batch number by this')
    parser.add_argument('--test_batch_divide', '-teb', type=int,
                        default=1,
                        help='divide batch number by this')
    parser.add_argument('--lr', '-lr', type=float,
                        default=0.1,
                        help='learning rate')
    parser.add_argument('--d', '-d', type=int,
                        default=4,
                        help='d in https://arxiv.org/abs/1704.00509')
    parser.add_argument('--k', '-k', type=int,
                        default=6,
                        help='k in https://arxiv.org/abs/1704.00509')
    parser.add_argument('--n', '-n', type=int,
                        default=2,
                        help='n in https://arxiv.org/abs/1704.00509')
    args = parser.parse_args().__dict__
    print(args)
    # pop the hyperparameters that are not keyword arguments of TrainCifar10
    lr = args.pop('lr')
    d = args.pop('d')
    k = args.pop('k')
    n = args.pop('n')
    print('generating model')
    # 10-class model: d scales the channel widths of the three stages,
    # while n and k are passed per stage as N and K of the binary tree wide ResNet
    model = binary_tree_wide_resnet.BitNet(10, out_channels=(16 * d, 32 * d, 64 * d), N=(n, ) * 3, K=(k, ) * 3, strides=(1, 2, 2))
    print('Done')
    print('Number of parameters: {}'.format(model.count_parameters()))
    optimizer = nutszebra_optimizer.OptimizerWideResBinaryTree(model, lr=lr)
    # the remaining command-line arguments are forwarded to the trainer
    args['model'] = model
    args['optimizer'] = optimizer
    args['da'] = nutszebra_data_augmentation.DataAugmentationCifar10NormalizeSmall
    main = nutszebra_cifar10.TrainCifar10(**args)
    main.run()
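
A typical invocation, assuming the trainer modules are present under ./trainer and a GPU is available (the GPU id and save path below are illustrative; the remaining values restate the defaults defined above):

    python main.py -g 0 -e 200 -b 128 -lr 0.1 -d 4 -k 6 -n 2 -p ./result/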