prepare_gist.py

import logging
import numpy as np
from collections import namedtuple
import random
# from matplotlib import pyplot as plt
import torch
import torch.nn.functional as F
from sklearn.feature_extraction.text import TfidfVectorizer
import sys
sys.path.append(sys.path[0] + "/..")

from crest.helper.config_utils import change_config, get_prefix
from crest.helper.utils import read_file
from crest.helper.bootstrap_utils import CREST
from crest.helper.nlp_utils import compact_text
from crest.helper.generic import dict2list

logger = logging.getLogger(__name__)

import gym
import gym_textworld  # Register all textworld environments.

from crest.agents.lstm_drqn.agent import RLAgent
from crest.agents.lstm_drqn.test_agent import test

def get_agent(config, env):
    word_vocab = dict2list(env.observation_space.id2w)
    word2id = {}
    for i, w in enumerate(word_vocab):
        word2id[w] = i

    if config['general']['exp_act']:
        print('##' * 30)
        print('Using expanded action list for treasure hunter')
        verb_list = read_file("data/vocabs/trial_run_custom_tw/verb_vocab.txt")
        object_name_list = read_file("data/vocabs/common_nouns.txt")
    else:
        verb_list = ["go", "take", "unlock", "lock", "drop", "look", "insert", "open", "inventory", "close"]
        object_name_list = ["east", "west", "north", "south", "coin", "apple", "carrot", "textbook", "passkey", "keycard"]

    # Add missing words in word2id
    for w in verb_list:
        if w not in word2id.keys():
            word2id[w] = len(word2id)
            word_vocab += [w, ]
    for w in object_name_list:
        if w not in word2id.keys():
            word2id[w] = len(word2id)
            word_vocab += [w, ]

    verb_map = [word2id[w] for w in verb_list if w in word2id]
    noun_map = [word2id[w] for w in object_name_list if w in word2id]
    print('Loaded {} verbs'.format(len(verb_map)))
    print('Loaded {} nouns'.format(len(noun_map)))
    print('##' * 30)
    print('Missing verbs and objects:')
    print([w for w in verb_list if w not in word2id])
    print([w for w in object_name_list if w not in word2id])

    agent = RLAgent(config, word_vocab, verb_map, noun_map,
                    att=config['general']['use_attention'], bootstrap=False)
    return agent

def topk_attention(softmax_att, desc, k=10):
    np_att = softmax_att.detach().cpu().numpy()[0]
    desc = desc[0]
    dtype = [('token', 'S10'), ('att', float)]
    values = [(s, a) for s, a in zip(desc, np_att)]
    val_array = np.array(values, dtype=dtype)
    sorted_values = np.sort(val_array, order='att')[::-1]
    sorted_tokens = [x['token'] for x in sorted_values]
    sorted_atts = [np.round(x['att'], 3) for x in sorted_values]
    return sorted_tokens[:k], sorted_atts[:k]

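# Example (illustrative only): with desc = [["you", "see", "a", "coin"]] and a softmax
# attention over those four tokens, topk_attention returns the tokens reordered by
# descending weight together with the rounded weights; tokens come back as byte strings
# truncated to 10 characters because of the 'S10' dtype.
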
class GISTSaver:
    def __init__(self, config, args, threshold=0.3):
        self.bs_obj = CREST(threshold=threshold)
        self.config = config
        validation_games = 20
        teacher_path = config['general']['teacher_model_path']
        print('Setting up TextWorld environment...')
        self.batch_size = 1

        # Load the training environment (batch size 1 for single-game rollouts).
        print('Making env id {}'.format(config['general']['env_id']))
        env_id = gym_textworld.make_batch(env_id=config['general']['env_id'],
                                          batch_size=self.batch_size,
                                          parallel=True)
        self.env = gym.make(env_id)
        # self.env.seed(config['general']['random_seed'])
        test_batch_size = config['training']['scheduling']['test_batch_size']

        # Validation environment.
        valid_env_name = config['general']['valid_env_id']
        valid_env_id = gym_textworld.make_batch(env_id=valid_env_name,
                                                batch_size=test_batch_size,
                                                parallel=True)
        self.valid_env = gym.make(valid_env_id)
        self.valid_env.seed(config['general']['random_seed'])

        # Teacher agent restored from the pretrained checkpoint.
        self.teacher_agent = get_agent(config, self.env)
        print('Loading teacher from : ', teacher_path)
        self.teacher_agent.model.load_state_dict(torch.load(teacher_path))
        # import time; time.sleep(5)
        self.hidden_size = config['model']['lstm_dqn']['action_scorer_hidden_dim']
        self.hash_features = {}

    def inference_teacher(self, agent, env, noise_std=0):
        assert self.batch_size == 1, "Batchsize should be 1 during inference"
        agent.model.eval()
        obs, infos = env.reset()
        agent.reset(infos)
        id_string_0 = agent.get_observation_strings(infos)[0]
        print_command_string, print_rewards = [[] for _ in infos], [[] for _ in infos]
        print_interm_rewards = [[] for _ in infos]
        provide_prev_action = self.config['general']['provide_prev_action']
        dones = [False] * self.batch_size
        rewards = [0]
        prev_actions = ["" for _ in range(self.batch_size)] if provide_prev_action else None
        input_description, _, desc, _ = agent.get_game_step_info(obs, infos, prev_actions, ret_desc=True)
        curr_ras_hidden, curr_ras_cell = None, None  # ras: recurrent action scorer
        # curr_ras_hidden, curr_ras_cell = get_init_hidden(bsz=self.batch_size,
        #                                                  hidden_size=self.hidden_size, use_cuda=True)
        print("##" * 30)
        print(obs)
        print("##" * 30)

        obs_list = []
        infos_list = []
        act_list = []
        sorted_tokens_list = []
        sorted_att_list = []
        id_string = id_string_0
        new_rooms = 0

        while not all(dones):
            v_idx, n_idx, _, curr_ras_hidden, curr_ras_cell = agent.generate_one_command(
                input_description, curr_ras_hidden, curr_ras_cell,
                epsilon=0.0, return_att=args.use_attention)
            if args.use_attention:
                softmax_att = agent.get_softmax_attention()
            else:
                softmax_att = None
            qv, qn = agent.get_qvalues()
            qv_noisy = qv
            qn_noisy = qn
            _, v_idx_maxq, _, n_idx_maxq = agent.choose_maxQ_command(qv_noisy, qn_noisy)
            chosen_strings = agent.get_chosen_strings(v_idx_maxq.detach(), n_idx_maxq.detach())
            if args.use_attention:
                sorted_tokens, sorted_atts = topk_attention(softmax_att, desc, k=10)
            else:
                sorted_tokens = None
                sorted_atts = None
            print('Action : ', chosen_strings[0])

            obs_list.append(obs[0])
            infos_list.append(infos[0])
            act_list.append(chosen_strings[0])
            sorted_tokens_list.append(sorted_tokens)
            sorted_att_list.append(sorted_atts)

            obs, rewards, dones, infos = env.step(chosen_strings)
            if provide_prev_action:
                prev_actions = chosen_strings
            for i in range(len(infos)):
                print_command_string[i].append(chosen_strings[i])
                print_rewards[i].append(rewards[i])
                print_interm_rewards[i].append(infos[i]["intermediate_reward"])
            IR = [info["intermediate_reward"] for info in infos]

            new_id_string = agent.get_observation_strings(infos)[0]
            if new_id_string != id_string:
                self.hash_features[id_string] = [infos, prev_actions,
                                                 qv.detach().cpu().numpy(),
                                                 qn.detach().cpu().numpy(),
                                                 softmax_att,
                                                 desc, chosen_strings]
                id_string = agent.get_observation_strings(infos)[0]
                new_rooms += 1
                if new_rooms >= 75:
                    break

            if type(dones) is bool:
                dones = [dones] * self.batch_size
            agent.rewards.append(rewards)
            agent.dones.append(dones)
            agent.intermediate_rewards.append([info["intermediate_reward"] for info in infos])
            input_description, _, desc, _ = agent.get_game_step_info(obs, infos, prev_actions, ret_desc=True)

        agent.finish()
        R = agent.final_rewards.mean()
        S = agent.step_used_before_done.mean()
        IR = agent.final_intermediate_rewards.mean()
        msg = '====EVAL==== R={:.3f}, IR={:.3f}, S={:.3f}, new_rooms={}'
        msg = msg.format(R, IR, S, new_rooms)
        print(msg)
        print("\n")
        return R, IR, S, obs_list, infos_list, act_list, sorted_tokens_list, \
            sorted_att_list, id_string_0

    def compute_similarity(self, state_):
        pass

    def compute_action_distribution(self, action_list, normalize=True):
        # Token-frequency distribution over the chosen action strings.
        action_dict = {}
        tot_tokens = 0
        for action in action_list:
            for token in action.split(" "):
                tot_tokens += 1
                if token in action_dict.keys():
                    action_dict[token] += 1
                else:
                    action_dict[token] = 1
        if normalize:
            for token in action_dict.keys():
                action_dict[token] = (action_dict[token] * 1.) / tot_tokens
        return action_dict
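
    # Example (illustrative only): compute_action_distribution(["go east", "go west"])
    # returns {"go": 0.5, "east": 0.25, "west": 0.25} when normalize=True.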

    def infer(self, numgames, noise_std=0):
        save_dict = {}
        count = 0
        for i in range(numgames):
            print('Game number : ', i)
            R, IR, S, obs_list, infos_list, act_list, sorted_tokens_list, \
                sorted_att_list, id_string = \
                self.inference_teacher(self.teacher_agent, self.env, noise_std=noise_std)
            action_dist = self.compute_action_distribution(act_list)
            if R == 1:
                # Keep only trajectories where the teacher solved the game.
                count += 1
                save_dict[id_string] = [obs_list, infos_list, act_list,
                                        sorted_tokens_list, sorted_att_list,
                                        action_dist]
                print('saved ', count)

        prefix_name = get_prefix(args)
        filename = './data/teacher_data/{}.npz'.format(prefix_name)
        hash_filename = './data/teacher_data/teacher_softmax_{}.pkl'.format(prefix_name)
        np.savez(filename, **save_dict, allow_pickle=True)
        with open(hash_filename, 'wb') as fp:
            pickle.dump(self.hash_features, fp, -1)


if __name__ == '__main__':
    import os, argparse, pickle, hickle

    for _p in ['saved_models']:
        if not os.path.exists(_p):
            os.mkdir(_p)

    parser = argparse.ArgumentParser(description="train network.")
    parser.add_argument("-c", "--config_dir", default='config', help="the default config directory")
    parser.add_argument("-type", "--type", default=None, help="easy | medium | hard")
    parser.add_argument("-ng", "--num_games", default=25, type=int)
    parser.add_argument("-v", "--verbose", help="increase output verbosity", action="store_true")
    parser.add_argument("-vv", "--very-verbose", help="print out warnings", action="store_true")
    parser.add_argument("-fr", "--force-remove", help="remove experiments directory to start new",
                        action="store_true")
    parser.add_argument("-att", "--use_attention", help="Use attention in the encoder model",
                        action="store_true")
    parser.add_argument("-student", "--student", help="Use student", action="store_true")
    parser.add_argument("-th", "--threshold", help="Filter threshold value for cosine similarity",
                        default=0.3, type=float)
    parser.add_argument("-ea", "--exp_act", help="Use expanded vocab list for actions", action="store_true")
    parser.add_argument("-drop", "--dropout", default=0, type=float)
    args = parser.parse_args()

    config = change_config(args, method='drqn', wait_time=0, test=True)
    state_pruner = GISTSaver(config, args, threshold=args.threshold)
    state_pruner.infer(args.num_games)

    pid = os.getpid()
    os.system('kill -9 {}'.format(pid))
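
# Example invocation (a sketch; the config layout, game type, and teacher checkpoint at
# config['general']['teacher_model_path'] are assumed to exist locally):
#   python prepare_gist.py -c config -type easy -ng 25 -att -ea -th 0.3
# The script rolls the teacher agent out on the TextWorld games, keeps trajectories of
# solved games in ./data/teacher_data/<prefix>.npz, and stores per-state teacher features
# (Q-values, attention, descriptions) in ./data/teacher_data/teacher_softmax_<prefix>.pkl.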