# test_agent.py
  1. import logging
  2. import os
  3. import numpy as np
  4. import argparse
  5. import warnings
  6. import yaml
  7. from os.path import join as pjoin
  8. import torch
  9. from crest.agents.lstm_drqn.agent import RLAgent
  10. from crest.helper.generic import get_experiment_dir
  11. logger = logging.getLogger(__name__)
  12. import gym
  13. import gym_textworld # Register all textworld environments.
  14. import textworld
  15. def test(config, env, agent, batch_size, word2id, prune=False, teacher_actions=None):
  16. agent.model.eval()
  17. obs, infos = env.reset()
  18. agent.reset(infos)
  19. print_command_string, print_rewards = [[] for _ in infos], [[] for _ in infos]
  20. print_interm_rewards = [[] for _ in infos]
  21. provide_prev_action = config['general']['provide_prev_action']
  22. dones = [False] * batch_size
  23. rewards = None
  24. prev_actions = ["" for _ in range(batch_size)] if provide_prev_action else None
  25. if prune:
  26. input_description, description_id_list, desc, _ = \
  27. agent.get_game_step_info(obs, infos, prev_actions, prune=prune,
  28. teacher_actions=teacher_actions, ret_desc=True, )
  29. else:
  30. input_description, _ = agent.get_game_step_info(obs, infos, prev_actions)
  31. curr_ras_hidden, curr_ras_cell = None, None # ras: recurrent action scorer
  32. while not all(dones):
  33. v_idx, n_idx, chosen_strings, curr_ras_hidden, curr_ras_cell = agent.generate_one_command(input_description,
  34. curr_ras_hidden,
  35. curr_ras_cell,
  36. epsilon=0.0)
  37. obs, rewards, dones, infos = env.step(chosen_strings)
  38. if provide_prev_action:
  39. prev_actions = chosen_strings
  40. for i in range(len(infos)):
  41. print_command_string[i].append(chosen_strings[i])
  42. print_rewards[i].append(rewards[i])
  43. print_interm_rewards[i].append(infos[i]["intermediate_reward"])
  44. if type(dones) is bool:
  45. dones = [dones] * batch_size
  46. agent.rewards.append(rewards)
  47. agent.dones.append(dones)
  48. agent.intermediate_rewards.append([info["intermediate_reward"] for info in infos])
  49. if prune:
  50. input_description, description_id_list, desc, _ = \
  51. agent.get_game_step_info(obs, infos, prev_actions, prune=prune,
  52. teacher_actions=teacher_actions, ret_desc=True, )
  53. else:
  54. input_description, _ = agent.get_game_step_info(obs, infos, prev_actions)
  55. agent.finish()
  56. R = agent.final_rewards.mean()
  57. S = agent.step_used_before_done.mean()
  58. IR = agent.final_intermediate_rewards.mean()
  59. msg = '====EVAL==== R={:.3f}, IR={:.3f}, S={:.3f}'
  60. msg = msg.format(R, IR, S)
  61. print(msg)
  62. print("\n")
  63. return R, IR, S