import logging
import numpy as np
from collections import namedtuple
import random
# from matplotlib import pyplot as plt
import torch
import torch.nn.functional as F

from crest.helper.config_utils import change_config, get_prefix
from crest.helper.utils import read_file, save2file
from crest.helper.bootstrap_utils import CREST
from crest.helper.nlp_utils import compact_text
from crest.helper.generic import dict2list

logger = logging.getLogger(__name__)

import gym
import gym_textworld  # Register all textworld environments.

from crest.agents.lstm_drqn.agent import RLAgent
from crest.agents.lstm_drqn.test_agent import test


def get_agent(config, env):
    word_vocab = dict2list(env.observation_space.id2w)
    word2id = {}
    for i, w in enumerate(word_vocab):
        word2id[w] = i

    if config['general']['exp_act']:
        verb_list = read_file("data/vocabs/trial_run_custom_tw/verb_vocab.txt")
        object_name_list = read_file("data/vocabs/common_nouns.txt")
    else:
        verb_list = ["go", "take", "unlock", "lock", "drop", "look",
                     "insert", "open", "inventory", "close"]
        object_name_list = ["east", "west", "north", "south", "coin", "apple",
                            "carrot", "textbook", "passkey", "keycard"]

    # Add missing words in word2id
    for w in verb_list:
        if w not in word2id.keys():
            word2id[w] = len(word2id)
            word_vocab += [w, ]
    for w in object_name_list:
        if w not in word2id.keys():
            word2id[w] = len(word2id)
            word_vocab += [w, ]

    verb_map = [word2id[w] for w in verb_list if w in word2id]
    noun_map = [word2id[w] for w in object_name_list if w in word2id]
    print('Loaded {} verbs'.format(len(verb_map)))
    print('Loaded {} nouns'.format(len(noun_map)))
    print('##' * 30)
    print('Missing verbs and objects:')
    print([w for w in verb_list if w not in word2id])
    print([w for w in object_name_list if w not in word2id])

    print('Loading DRQN agent')
    if config['general']['student']:
        agent = RLAgent(config, word_vocab, verb_map, noun_map,
                        att=config['general']['use_attention'],
                        bootstrap=config['general']['student'],
                        embed=config['bootstrap']['embed'])
    else:
        agent = RLAgent(config, word_vocab, verb_map, noun_map,
                        att=config['general']['use_attention'],
                        bootstrap=config['general']['student'])
    return agent
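
# A minimal usage sketch (an assumption, not part of the original script): the
# Evaluator defined below is presumably driven by a small entry point that
# builds `config` (e.g. via crest.helper.config_utils.change_config) and an
# argparse `args` namespace, then runs roughly:
#
#     evaluator = Evaluator(config, args)
#     evaluator.load_valid_env(valid_env_name)  # valid_env_name is hypothetical here
#     evaluator.load_agent()
#     evaluator.infer()
#
# The exact argument parsing and environment-name construction are not shown
# in this file.
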
class Evaluator():
    def __init__(self, config, args, threshold=0.3):
        self.config = config
        self.args = args
        teacher_path = config['general']['teacher_model_path']
        print('Setting up TextWorld environment...')

    def load_valid_env(self, valid_env_name):
        test_batch_size = 1
        valid_env_id = gym_textworld.make_batch(env_id=valid_env_name,
                                                batch_size=test_batch_size,
                                                parallel=True)
        self.valid_env = gym.make(valid_env_id)
        self.valid_env.seed(self.config['general']['random_seed'])
        print('Loaded env name: ', valid_env_name)

    def load_agent(self):
        self.agent = get_agent(self.config, self.valid_env)
        model_checkpoint_path = self.config['training']['scheduling']['model_checkpoint_path']
        load_path = model_checkpoint_path.replace('.pt', '_best.pt')
        print('Loading model from : ', load_path)
        self.agent.model.load_state_dict(torch.load(load_path))
        self.hidden_size = self.config['model']['lstm_dqn']['action_scorer_hidden_dim']
        self.hash_features = {}

    def inference(self, agent, env, prune=False, action_dist=None):
        batch_size = 1
        assert batch_size == 1, "Batch size should be 1 during inference"
        agent.model.eval()
        obs, infos = env.reset()
        agent.reset(infos)
        id_string_0 = agent.get_observation_strings(infos)[0]
        provide_prev_action = self.config['general']['provide_prev_action']

        dones = [False] * batch_size
        rewards = [0]
        prev_actions = ["" for _ in range(batch_size)] if provide_prev_action else None
        input_description, description_id_list, desc, _ = \
            agent.get_game_step_info(obs, infos, prev_actions, prune=prune,
                                     teacher_actions=action_dist, ret_desc=True)
        curr_ras_hidden, curr_ras_cell = None, None  # ras: recurrent action scorer

        if prune:
            desc_strings, desc_disc = agent.get_similarity_scores(obs, infos, prev_actions,
                                                                  prune=prune,
                                                                  teacher_actions=action_dist,
                                                                  ret_desc=True)
            self.desc.append(list(desc_disc.keys()))
            self.desc_scores.append(list(desc_disc.values()))
            self.desc_strings.append(desc_strings)

        if id_string_0 in self.id_string_list:
            print('Already encountered this game. Skipping...')
            return
        self.id_string_list.append(id_string_0)
        self.game_num += 1

        while not all(dones):
            v_idx, n_idx, _, curr_ras_hidden, curr_ras_cell = agent.generate_one_command(
                input_description, curr_ras_hidden, curr_ras_cell,
                epsilon=0.0, return_att=self.args.use_attention)
            if self.args.use_attention:
                softmax_att = agent.get_softmax_attention()
            else:
                softmax_att = None
            qv, qn = agent.get_qvalues()
            _, v_idx_maxq, _, n_idx_maxq = agent.choose_maxQ_command(qv, qn)
            chosen_strings = agent.get_chosen_strings(v_idx_maxq.detach(), n_idx_maxq.detach())
            sorted_tokens = None
            sorted_atts = None

            obs, rewards, dones, infos = env.step(chosen_strings)
            if provide_prev_action:
                prev_actions = chosen_strings
            IR = [info["intermediate_reward"] for info in infos]
            if type(dones) is bool:
                dones = [dones] * batch_size
            agent.rewards.append(rewards)
            agent.dones.append(dones)
            agent.intermediate_rewards.append([info["intermediate_reward"] for info in infos])

            input_description, description_id_list, desc, _ = \
                agent.get_game_step_info(obs, infos, prev_actions, prune=prune,
                                         teacher_actions=action_dist, ret_desc=True)
            if prune:
                desc_strings, desc_disc = agent.get_similarity_scores(obs, infos, prev_actions,
                                                                      prune=prune,
                                                                      teacher_actions=action_dist,
                                                                      ret_desc=True)
                self.desc.append(list(desc_disc.keys()))
                self.desc_scores.append(list(desc_disc.values()))
                self.desc_strings.append(desc_strings)
                _, _, orig_desc, _ = agent.get_game_step_info(obs, infos, prev_actions,
                                                              prune=False, ret_desc=True)
                for x, y in zip(orig_desc, desc):
                    self.orig_data += [' '.join(x), ' '.join(y)]

        agent.finish()
        R = agent.final_rewards.mean()
        S = agent.step_used_before_done.mean()
        IR = agent.final_intermediate_rewards.mean()
        msg = '====EVAL==== R={:.3f}, IR={:.3f}, S={:.3f}'
        msg = msg.format(R, IR, S)
        print(msg)
        print("\n")
        self.result_logs['R'].append(R)
        self.result_logs['IR'].append(IR)
        self.result_logs['S'].append(S)

    def infer(self):
        numgames = self.args.num_test_games
        prune = self.args.prune
        save_dict = {}
        count = 0
        self.id_string_list = []
        self.result_logs = {'R': [], 'IR': [], 'S': []}
        if prune:
            self.prune_filename = self.config['training']['scheduling']['model_checkpoint_path'] \
                .replace('.pt', '_level_{}_logs.txt'.format(self.args.level)) \
                .replace('saved_models', 'prune_logs')
            self.score_filename = self.config['training']['scheduling']['model_checkpoint_path'] \
                .replace('.pt', '_level_{}_logs.npz'.format(self.args.level)) \
                .replace('saved_models', 'score_logs')
            self.orig_data = []
            self.desc = []
            self.desc_scores = []
            self.desc_strings = []

        if self.args.method == 'drqn':
            if prune:
                prefix_name = get_prefix(self.args)
                filename = './data/teacher_data/{}.npz'.format(prefix_name)
                teacher_dict = np.load(filename, allow_pickle=True)
                global_action_set = set()
                for k in teacher_dict.keys():
                    if k == 'allow_pickle':
                        continue
                    action_dist = teacher_dict[k][-1]
                    action_dist = [x for x in action_dist.keys()]
                    global_action_set.update(action_dist)

        self.game_num = 0
        print('here')
        while (len(self.id_string_list)