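"""Evaluate trained LSTM-DRQN agents on TextWorld games.

Loads the best saved checkpoint, rolls out greedy episodes on seeded test
environments (optionally pruning observations against a teacher's action
distribution), and logs mean reward (R), intermediate reward (IR), and
steps used (S) per seed and in aggregate.
"""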
import logging
import random
from collections import namedtuple

import numpy as np
# from matplotlib import pyplot as plt
import torch
import torch.nn.functional as F

from crest.helper.config_utils import change_config, get_prefix
from crest.helper.utils import read_file, save2file
from crest.helper.bootstrap_utils import CREST
from crest.helper.nlp_utils import compact_text
from crest.helper.generic import dict2list

import gym
import gym_textworld  # Registers all TextWorld environments.

from crest.agents.lstm_drqn.agent import RLAgent
from crest.agents.lstm_drqn.test_agent import test

logger = logging.getLogger(__name__)
def get_agent(config, env):
    # Build the vocabulary and word -> id mapping from the environment.
    word_vocab = dict2list(env.observation_space.id2w)
    word2id = {}
    for i, w in enumerate(word_vocab):
        word2id[w] = i

    if config['general']['exp_act']:
        # Expanded action vocabulary loaded from disk.
        verb_list = read_file("data/vocabs/trial_run_custom_tw/verb_vocab.txt")
        object_name_list = read_file("data/vocabs/common_nouns.txt")
    else:
        verb_list = ["go", "take", "unlock", "lock", "drop", "look",
                     "insert", "open", "inventory", "close"]
        object_name_list = ["east", "west", "north", "south", "coin", "apple",
                            "carrot", "textbook", "passkey", "keycard"]

    # Add any action words missing from word2id.
    for w in verb_list:
        if w not in word2id:
            word2id[w] = len(word2id)
            word_vocab.append(w)
    for w in object_name_list:
        if w not in word2id:
            word2id[w] = len(word2id)
            word_vocab.append(w)

    verb_map = [word2id[w] for w in verb_list if w in word2id]
    noun_map = [word2id[w] for w in object_name_list if w in word2id]
    print('Loaded {} verbs'.format(len(verb_map)))
    print('Loaded {} nouns'.format(len(noun_map)))
    print('##' * 30)
    print('Missing verbs and objects:')
    print([w for w in verb_list if w not in word2id])
    print([w for w in object_name_list if w not in word2id])

    print('Loading DRQN agent')
    if config['general']['student']:
        agent = RLAgent(config, word_vocab, verb_map, noun_map,
                        att=config['general']['use_attention'],
                        bootstrap=config['general']['student'],
                        embed=config['bootstrap']['embed'])
    else:
        agent = RLAgent(config, word_vocab, verb_map, noun_map,
                        att=config['general']['use_attention'],
                        bootstrap=config['general']['student'])
    return agent

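# Evaluator wires a batched TextWorld environment to a trained agent
# checkpoint, rolls out test games, and accumulates R / IR / S statistics.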
class Evaluator():
    def __init__(self, config, args, threshold=0.3):
        self.config = config
        self.args = args
        self.threshold = threshold
        self.teacher_path = config['general']['teacher_model_path']
        print('Setting up TextWorld environment...')
    def load_valid_env(self, valid_env_name):
        test_batch_size = 1
        valid_env_id = gym_textworld.make_batch(env_id=valid_env_name,
                                                batch_size=test_batch_size,
                                                parallel=True)
        self.valid_env = gym.make(valid_env_id)
        self.valid_env.seed(self.config['general']['random_seed'])
        print('Loaded env name: ', valid_env_name)

    def load_agent(self):
        self.agent = get_agent(self.config, self.valid_env)
        model_checkpoint_path = self.config['training']['scheduling']['model_checkpoint_path']
        # Evaluate the best (early-stopped) checkpoint, not the last one.
        load_path = model_checkpoint_path.replace('.pt', '_best.pt')
        print('Loading model from: ', load_path)
        self.agent.model.load_state_dict(torch.load(load_path))
        self.hidden_size = self.config['model']['lstm_dqn']['action_scorer_hidden_dim']
        self.hash_features = {}

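    # inference() rolls out one episode greedily (epsilon=0). With prune=True,
    # observation descriptions are scored against the teacher's action set and
    # both original and pruned descriptions are logged for inspection.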
    def inference(self, agent, env, prune=False, action_dist=None):
        batch_size = 1
        assert batch_size == 1, "Batch size should be 1 during inference"
        agent.model.eval()
        obs, infos = env.reset()
        agent.reset(infos)
        # The initial observation string identifies the game, so repeats can be skipped.
        id_string_0 = agent.get_observation_strings(infos)[0]
        provide_prev_action = self.config['general']['provide_prev_action']
        dones = [False] * batch_size
        rewards = [0]
        prev_actions = ["" for _ in range(batch_size)] if provide_prev_action else None
        input_description, description_id_list, desc, _ = \
            agent.get_game_step_info(obs, infos, prev_actions, prune=prune,
                                     teacher_actions=action_dist, ret_desc=True)
        curr_ras_hidden, curr_ras_cell = None, None  # ras: recurrent action scorer
        if prune:
            desc_strings, desc_disc = agent.get_similarity_scores(
                obs, infos, prev_actions, prune=prune,
                teacher_actions=action_dist, ret_desc=True)
            self.desc.append(list(desc_disc.keys()))
            self.desc_scores.append(list(desc_disc.values()))
            self.desc_strings.append(desc_strings)
        if id_string_0 in self.id_string_list:
            print('Already encountered this game. Skipping...')
            return
        self.id_string_list.append(id_string_0)
        self.game_num += 1

        while not all(dones):
            # Act greedily (epsilon=0.0) from the recurrent action scorer.
            v_idx, n_idx, _, curr_ras_hidden, curr_ras_cell = agent.generate_one_command(
                input_description, curr_ras_hidden, curr_ras_cell,
                epsilon=0.0, return_att=self.args.use_attention)
            if self.args.use_attention:
                softmax_att = agent.get_softmax_attention()
            else:
                softmax_att = None
            qv, qn = agent.get_qvalues()
            _, v_idx_maxq, _, n_idx_maxq = agent.choose_maxQ_command(qv, qn)
            chosen_strings = agent.get_chosen_strings(v_idx_maxq.detach(), n_idx_maxq.detach())
            sorted_tokens = None
            sorted_atts = None
            obs, rewards, dones, infos = env.step(chosen_strings)
            if provide_prev_action:
                prev_actions = chosen_strings
            if type(dones) is bool:
                dones = [dones] * batch_size
            agent.rewards.append(rewards)
            agent.dones.append(dones)
            agent.intermediate_rewards.append([info["intermediate_reward"] for info in infos])
            input_description, description_id_list, desc, _ = \
                agent.get_game_step_info(obs, infos, prev_actions, prune=prune,
                                         teacher_actions=action_dist, ret_desc=True)
            if prune:
                desc_strings, desc_disc = agent.get_similarity_scores(
                    obs, infos, prev_actions, prune=prune,
                    teacher_actions=action_dist, ret_desc=True)
                self.desc.append(list(desc_disc.keys()))
                self.desc_scores.append(list(desc_disc.values()))
                self.desc_strings.append(desc_strings)
                # Log original vs. pruned descriptions side by side.
                _, _, orig_desc, _ = agent.get_game_step_info(obs, infos, prev_actions,
                                                              prune=False, ret_desc=True)
                for x, y in zip(orig_desc, desc):
                    self.orig_data += [' '.join(x), ' '.join(y)]

        agent.finish()
        R = agent.final_rewards.mean()
        S = agent.step_used_before_done.mean()
        IR = agent.final_intermediate_rewards.mean()
        print('====EVAL==== R={:.3f}, IR={:.3f}, S={:.3f}'.format(R, IR, S))
        print("\n")
        self.result_logs['R'].append(R)
        self.result_logs['IR'].append(IR)
        self.result_logs['S'].append(S)
    def infer(self):
        numgames = self.args.num_test_games
        prune = self.args.prune
        self.id_string_list = []
        self.result_logs = {'R': [], 'IR': [], 'S': []}
        if prune:
            self.prune_filename = self.config['training']['scheduling']['model_checkpoint_path'] \
                .replace('.pt', '_level_{}_logs.txt'.format(self.args.level)) \
                .replace('saved_models', 'prune_logs')
            self.score_filename = self.config['training']['scheduling']['model_checkpoint_path'] \
                .replace('.pt', '_level_{}_logs.npz'.format(self.args.level)) \
                .replace('saved_models', 'score_logs')
        self.orig_data = []
        self.desc = []
        self.desc_scores = []
        self.desc_strings = []
        if self.args.method == 'drqn':
            if prune:
                # Collect the union of actions seen in the teacher's recorded
                # action distributions.
                prefix_name = get_prefix(self.args)
                filename = './data/teacher_data/{}.npz'.format(prefix_name)
                teacher_dict = np.load(filename, allow_pickle=True)
                global_action_set = set()
                for k in teacher_dict.keys():
                    if k == 'allow_pickle':
                        continue
                    action_dist = teacher_dict[k][-1]
                    global_action_set.update(action_dist.keys())
            self.game_num = 0
            # inference() skips repeated games, so loop until enough unique
            # games have been evaluated.
            while len(self.id_string_list) < int(numgames):
                print('Game number : ', self.game_num)
                if prune:
                    self.inference(self.agent, self.valid_env, prune=prune,
                                   action_dist=[list(global_action_set)])
                else:
                    self.inference(self.agent, self.valid_env, prune=False)
            if prune:
                save2file(self.prune_filename, self.orig_data)
                np.savez(self.score_filename, desc=self.desc,
                         desc_scores=self.desc_scores, desc_strings=self.desc_strings)
        return self.result_logs

if __name__ == '__main__':
    import os
    import argparse
    import pickle

    parser = argparse.ArgumentParser(description="Evaluate a trained network.")
    parser.add_argument("-c", "--config_dir", default='config', help="the default config directory")
    parser.add_argument("-type", "--type", default=None, help="easy | medium | hard")
    parser.add_argument("-ng", "--num_games", default=None, help="number of games (used to locate logs)")
    parser.add_argument("-v", "--verbose", help="increase output verbosity", action="store_true")
    parser.add_argument("-vv", "--very-verbose", help="print out warnings", action="store_true")
    parser.add_argument("-fr", "--force-remove", help="remove experiments directory to start new", action="store_true")
    parser.add_argument("-att", "--use_attention", help="use attention in the encoder model", action="store_true")
    parser.add_argument("-th", "--threshold", help="filter threshold value for cosine similarity", default=0.3, type=float)
    parser.add_argument("-ea", "--exp_act", help="use expanded vocab list for actions", action="store_true")
    parser.add_argument("-prune", "--prune", help="use observation pruning or not", action="store_true")
    parser.add_argument("-level", "--level", help="how many levels in the game to test", type=int, default=15)
    parser.add_argument("-m", "--method", help="which method to use: drqn | dqn", type=str, default="drqn")
    parser.add_argument("-emb", "--embed", default='cnet', type=str)  # 'cnet' | 'glove' | 'word2vec' | 'bert'
    parser.add_argument("-drop", "--dropout", default=0, type=float)
    parser.add_argument("-student", "--student", help="evaluate the student (vs. teacher) model", action="store_true")
    args = parser.parse_args()
    assert not args.force_remove, "evaluation should not remove the experiments directory"
    config = change_config(args, method=args.method, test=True)
    args.num_test_games = 20
    true_valid_name = config['general']['valid_env_id']

    evaluator = Evaluator(config, args, threshold=args.threshold)
    evaluator.load_valid_env(true_valid_name)
    evaluator.load_agent()

    filename = config['training']['scheduling']['model_checkpoint_path'] \
        .replace('.pt', '_level_{}_logs.txt'.format(args.level)) \
        .replace('saved_models', 'emnlp_logs/logs_{}_{}'.format(args.type, args.num_games))
    dirname = os.path.dirname(filename)
    os.makedirs(dirname, exist_ok=True)

    results = []
    with open(filename, 'w') as fp:
        # Evaluate on three differently seeded test sets and aggregate.
        for k in range(3):
            # Derive the test env id from the validation env id: bigger game
            # set, test split, a 100-step budget, the requested level, and
            # seed k + 1.
            env_id = true_valid_name
            env_id = env_id.replace('gamesize10', 'gamesize20')
            env_id = env_id.replace('_validation', '_test')
            env_id = env_id.replace('_step50', '_step100').replace('_step75', '_step100')
            env_id = env_id.replace('_level15', '_level{}'.format(args.level))
            env_id = env_id.replace('_seed9', '_seed{}'.format(k + 1))
            config['general']['valid_env_id'] = env_id

            fp.write('##' * 30 + '\n')
            fp.write(env_id + '\n')
            fp.write('##' * 30 + '\n')

            evaluator.load_valid_env(env_id)
            result_logs = evaluator.infer()
            R = np.mean(result_logs['R'])
            IR = np.mean(result_logs['IR'])
            S = np.mean(result_logs['S'])
            results.append([R, IR, S])
            fp.write('====FINAL EVAL==== R={:.3f}, IR={:.3f}, S={:.3f}\n'.format(R, IR, S))

        mean_res = np.mean(results, axis=0)
        std_res = np.std(results, axis=0)
        fp.write('##' * 30 + '\n')
        fp.write('Final seeded results : R={}/{}, IR={}/{}, S={}/{}\n'.format(
            mean_res[0], std_res[0], mean_res[1], std_res[1], mean_res[2], std_res[2]))
        fp.write('##' * 30 + '\n')

    # Hard-exit: the parallel TextWorld workers can otherwise keep the
    # process from terminating cleanly.
    pid = os.getpid()
    os.system('kill -9 {}'.format(pid))
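
# Example invocation (hypothetical paths/flags; adjust to your config layout):
#
#   python evaluate_agents_att.py -c config -type easy -ng 10 \
#       -m drqn -att -prune -level 15 -th 0.3 -ea
#
# This evaluates the best checkpoint resolved by change_config() on three
# seeded test sets and writes per-seed and aggregate R/IR/S to the log file
# under emnlp_logs/.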