config_utils.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import warnings
  2. import yaml
  3. from os.path import join as pjoin
  4. import textworld
  5. import shutil
  6. import time, os
  7. def change_config(args, method='drqn', wait_time=2, kind='normal', tf=False, ensemble=False, test=False): # kind = noisy | normal
  8. student = args.student
  9. exp_act_list = args.exp_act
  10. challenge_type = 'coin_collector'
  11. if args.very_verbose:
  12. args.verbose = args.very_verbose
  13. warnings.simplefilter("default", textworld.TextworldGenerationWarning)
  14. # Read config from yaml file.
  15. if challenge_type == 'custom_tw' or challenge_type == 'treasure_hunter':
  16. config_file = pjoin(args.config_dir, 'config_{}_{}.yaml'.format(args.type, args.num_games))
  17. elif challenge_type == 'coin_collector': # args.type is not None:
  18. if args.num_games is not None:
  19. config_file = pjoin(args.config_dir, 'config_{}_{}.yaml'.format(args.type, args.num_games))
  20. else:
  21. config_file = pjoin(args.config_dir, 'config_{}.yaml'.format(args.type, ))
  22. else:
  23. config_file = pjoin(args.config_dir, 'config.yaml')
  24. with open(config_file) as reader:
  25. config = yaml.safe_load(reader)
  26. prefixed_method_name = method + ('_att' if args.use_attention else '')
  27. prefixed_method_name = 'coin_collector_' + prefixed_method_name
  28. config['bootstrap']['threshold'] = args.threshold
  29. config['bootstrap']['prune'] = args.prune if hasattr(args, 'prune') else False
  30. config['bootstrap']['embed'] = args.embed if hasattr(args, 'embed') else 'cnet'
  31. use_dropout = args.dropout if hasattr(args, 'dropout') else None
  32. if exp_act_list:
  33. prefixed_method_name += '_exp_act'
  34. if use_dropout is not None and not student: # student model does not use drop-out
  35. prefixed_method_name += '_drop_{}'.format(use_dropout)
  36. ######## Base model path #################
  37. teacher_model_path = config['training']['scheduling']['model_checkpoint_path']. \
  38. replace('dqrn', prefixed_method_name). \
  39. replace('summary_', '').replace('.pt', '_train.pt')
  40. ######## Bootstrapped model path #################
  41. config['training']['scheduling']['teacher_model_checkpoint_path'] = \
  42. config['training']['scheduling']['model_checkpoint_path']. \
  43. replace('dqrn', prefixed_method_name). \
  44. replace('summary_', '')
  45. config['general']['student'] = student
  46. if student:
  47. print('##')
  48. prefixed_method_name += '_student'
  49. prefixed_method_name += '_thres_{}'.format(config['bootstrap']['threshold'])
  50. if config['bootstrap']['prune']:
  51. prefixed_method_name += '_prune'
  52. if config['bootstrap']['embed'] is not 'cnet':
  53. prefixed_method_name += '_embed_{}'.format(config['bootstrap']['embed'])
  54. ## Change the method specific info here
  55. config['general']['experiment_tag'] = config['general']['experiment_tag'].replace('drqn', method)
  56. config['general']['experiments_dir'] = config['general']['experiments_dir'].replace('summary', prefixed_method_name)
  57. config['training']['scheduling']['model_checkpoint_path'] = \
  58. config['training']['scheduling']['model_checkpoint_path']. \
  59. replace('dqrn', prefixed_method_name). \
  60. replace('summary_', '')
  61. config['general']['teacher_model_path'] = teacher_model_path
  62. config['general']['use_attention'] = args.use_attention
  63. config['general']['student'] = student
  64. config['general']['exp_act'] = exp_act_list
  65. print(config['general']['experiment_tag'])
  66. print(config['general']['experiments_dir'])
  67. print(config['training']['scheduling']['model_checkpoint_path'])
  68. print('Train env name : ', config['general']['env_id'])
  69. print('Valid env name : ', config['general']['valid_env_id'])
  70. time.sleep(wait_time)
  71. if os.path.exists(config['general']['experiments_dir']) and not test:
  72. if not args.force_remove:
  73. prompt = input('Are you sure you want to delete {} (yes/no):'.
  74. format(config['general']['experiments_dir']))
  75. # if prompt == 'yes' or do_not_prompt:
  76. if prompt == 'yes':
  77. print('##' * 30)
  78. print('Removing directory ', config['general']['experiments_dir'])
  79. print('##' * 30)
  80. shutil.rmtree(config['general']['experiments_dir'])
  81. else:
  82. print('##' * 30)
  83. print('Removing directory ', config['general']['experiments_dir'])
  84. print('##' * 30)
  85. shutil.rmtree(config['general']['experiments_dir'])
  86. else:
  87. if os.path.exists(config['general']['experiments_dir']):
  88. print('{} already exists. If you want to delete and '
  89. 'start fresh use \'-fr\' option.'.
  90. format(config['general']['experiments_dir']))
  91. return config
  92. def get_prefix(args, method='drqn'):
  93. prefixed_method_name = 'coin_collector_'
  94. prefixed_method_name += (method + ('_att' if args.use_attention else ''))
  95. if args.exp_act:
  96. prefixed_method_name += '_exp_act'
  97. prefixed_method_name += '_ng_{}'.format(args.num_games)
  98. prefixed_method_name += '_type_{}'.format(args.type)
  99. return prefixed_method_name