utils.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import numpy as np
  2. import torch
  3. from model import DeepGL
  4. from logger import Logger
  5. import os
  6. def save_parameters(args, run_name):
  7. with open(os.path.join(args.log_path, run_name)+'/parameters.txt', 'w') as f:
  8. f.write('num_blocks {}, lr {}, beta1 {} beta2 {}, batch_size {} gamma {} scheduler_step {}'.format(
  9. args.num_blocks, args.lr, args.beta1, args.beta2, args.batch_size,
  10. args.gamma, args.scheduler_step
  11. ))
  12. def prepare_directories(args, run_name):
  13. if not os.path.isdir(args.data_path):
  14. raise Exception("Invalid data path. No such directory")
  15. if not os.path.isdir(args.log_path):
  16. os.makedirs(args.log_path)
  17. if args.pretrained_path:
  18. if not os.path.isdir(args.pretrained_path) or \
  19. not os.path.isdir(os.path.join(args.pretrained_path, 'states')):
  20. raise Exception("Invalid path. No such directory with pretrained model")
  21. else:
  22. exp_path = os.path.join(args.log_path, run_name)
  23. os.makedirs(exp_path)
  24. os.makedirs(os.path.join(exp_path, 'samples'))
  25. os.makedirs(os.path.join(exp_path, 'states'))
  26. os.makedirs(os.path.join(exp_path, 'tensorboard_logs'))
  27. def build_model(args):
  28. model = DeepGL(args.num_blocks)
  29. if args.pretrained_path:
  30. model.load_state_dict(torch.load(
  31. os.path.join(args.pretrained_path, 'samples') + '/' + str(args.load_step) + '.pt'))
  32. return model
  33. def prepare_logger(path):
  34. if not os.path.isdir(path):
  35. os.makedirs(path)
  36. logger = Logger(path)
  37. return logger