utils_3.py 246 B

123456789
  1. def build_model(args):
  2. model = DeepGL(args.num_blocks)
  3. if args.pretrained_path:
  4. model.load_state_dict(torch.load(
  5. os.path.join(args.pretrained_path, 'samples') + '/' + str(args.load_step) + '.pt'))
  6. return model