# train_policy_qlearn.py

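"""Q-learning training script for a student LSTM-DRQN agent on batched TextWorld games.

Loads pre-computed teacher action data from ./data/teacher_data, optionally prunes the
state descriptions with it, trains the student agent with experience replay, and logs
progress (rewards, steps, loss, validation scores) to TensorBoard.
"""
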
import logging
import os
import numpy as np
import argparse
import warnings
import yaml
from os.path import join as pjoin
import sys
sys.path.append(sys.path[0] + "/..")
import pickle

import torch
from torch import nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter

from crest.agents.lstm_drqn.agent import RLAgent as Agent
from crest.helper.config_utils import change_config, get_prefix
from crest.helper.bootstrap_utils import get_init_hidden
from crest.helper.generic import SlidingAverage, to_np
from crest.helper.generic import get_experiment_dir, dict2list
from crest.agents.lstm_drqn.test_agent import test
from crest.helper.utils import read_file

logger = logging.getLogger(__name__)

import gym
import gym_textworld  # Register all textworld environments.
import textworld

# os.system('rm -r gen_games')

def train(config, prune=False, embed='cnet'):
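    """Train the student agent with Q-learning on batched TextWorld environments.

    Args:
        config: configuration dictionary (parsed from YAML by change_config).
        prune: if True, prune state descriptions using the teacher action lists.
        embed: word-embedding type for the agent ('cnet' | 'glove' | 'word2vec' | 'bert').
    """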
    # train env
    print('Setting up TextWorld environment...')
    batch_size = config['training']['scheduling']['batch_size']
    env_id = gym_textworld.make_batch(env_id=config['general']['env_id'],
                                      batch_size=batch_size,
                                      parallel=True)
    env = gym.make(env_id)
    env.seed(config['general']['random_seed'])

    print("##" * 30)
    if prune:
        print('Using state pruning ...')
    else:
        print('Not using state pruning ...')
    print("##" * 30)

    # valid and test env
    run_test = config['general']['run_test']
    if run_test:
        test_batch_size = config['training']['scheduling']['test_batch_size']
        # valid
        valid_env_name = config['general']['valid_env_id']
        valid_env_id = gym_textworld.make_batch(env_id=valid_env_name, batch_size=test_batch_size, parallel=True)
        valid_env = gym.make(valid_env_id)
        valid_env.seed(config['general']['random_seed'])
        # valid_env.reset()

        # test
        test_env_name_list = config['general']['test_env_id']
        assert isinstance(test_env_name_list, list)
        test_env_id_list = [gym_textworld.make_batch(env_id=item, batch_size=test_batch_size, parallel=True)
                            for item in test_env_name_list]
        test_env_list = [gym.make(test_env_id) for test_env_id in test_env_id_list]
        for i in range(len(test_env_list)):
            test_env_list[i].seed(config['general']['random_seed'])
            # test_env_list[i].reset()
    print('Done.')

    # Set the random seed manually for reproducibility.
    np.random.seed(config['general']['random_seed'])
    torch.manual_seed(config['general']['random_seed'])
    if torch.cuda.is_available():
        if not config['general']['use_cuda']:
            logger.warning("WARNING: CUDA device detected but 'use_cuda: false' found in config.yaml")
        else:
            torch.backends.cudnn.deterministic = True
            torch.cuda.manual_seed(config['general']['random_seed'])
    else:
        config['general']['use_cuda'] = False  # Disable CUDA.
    use_cuda = config['general']['use_cuda']

    revisit_counting = config['general']['revisit_counting']
    replay_batch_size = config['general']['replay_batch_size']
    history_size = config['general']['history_size']
    update_from = config['general']['update_from']
    replay_memory_capacity = config['general']['replay_memory_capacity']
    replay_memory_priority_fraction = config['general']['replay_memory_priority_fraction']
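
    # Build the word vocabulary from the environment's observation space and index it.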
    word_vocab = dict2list(env.observation_space.id2w)
    word2id = {}
    for i, w in enumerate(word_vocab):
        word2id[w] = i

    if config['general']['exp_act']:
        print('##' * 30)
        print('Using expanded verb list')
        verb_list = read_file("data/vocabs/trial_run_custom_tw/verb_vocab.txt")
        object_name_list = read_file("data/vocabs/common_nouns.txt")
    else:
        # This option only works for coin collector.
        verb_list = ["go", "take", "unlock", "lock", "drop", "look", "insert", "open", "inventory", "close"]
        object_name_list = ["east", "west", "north", "south", "coin", "apple", "carrot", "textbook", "passkey",
                            "keycard"]

    # Add missing words to word2id and the vocabulary.
    for w in verb_list:
        if w not in word2id.keys():
            word2id[w] = len(word2id)
            word_vocab += [w, ]
    for w in object_name_list:
        if w not in word2id.keys():
            word2id[w] = len(word2id)
            word_vocab += [w, ]

    verb_map = [word2id[w] for w in verb_list if w in word2id]
    noun_map = [word2id[w] for w in object_name_list if w in word2id]

    # teacher_path = config['general']['teacher_model_path']
    # teacher_agent = Agent(config, word_vocab, verb_map, noun_map,
    #                       att=config['general']['use_attention'],
    #                       bootstrap=False,
    #                       replay_memory_capacity=replay_memory_capacity,
    #                       replay_memory_priority_fraction=replay_memory_priority_fraction)
    # teacher_agent.model.load_state_dict(torch.load(teacher_path))
    # teacher_agent.model.eval()

    student_agent = Agent(config, word_vocab, verb_map, noun_map,
                          att=config['general']['use_attention'],
                          bootstrap=config['general']['student'],
                          replay_memory_capacity=replay_memory_capacity,
                          replay_memory_priority_fraction=replay_memory_priority_fraction,
                          embed=embed)

    init_learning_rate = config['training']['optimizer']['learning_rate']
    exp_dir = get_experiment_dir(config)
    summary = SummaryWriter(exp_dir)

    parameters = filter(lambda p: p.requires_grad, student_agent.model.parameters())
    if config['training']['optimizer']['step_rule'] == 'sgd':
        optimizer = torch.optim.SGD(parameters, lr=init_learning_rate)
    elif config['training']['optimizer']['step_rule'] == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=init_learning_rate)

    log_every = 100
    reward_avg = SlidingAverage('reward avg', steps=log_every)
    step_avg = SlidingAverage('step avg', steps=log_every)
    loss_avg = SlidingAverage('loss avg', steps=log_every)

    # save & reload checkpoint only in 0th agent
    best_avg_reward = -10000
    best_avg_step = 10000

    # step penalty
    discount_gamma = config['general']['discount_gamma']
    provide_prev_action = config['general']['provide_prev_action']

    # epsilon greedy
    epsilon_anneal_epochs = config['general']['epsilon_anneal_epochs']
    epsilon_anneal_from = config['general']['epsilon_anneal_from']
    epsilon_anneal_to = config['general']['epsilon_anneal_to']

    # counting reward
    revisit_counting_lambda_anneal_epochs = config['general']['revisit_counting_lambda_anneal_epochs']
    revisit_counting_lambda_anneal_from = config['general']['revisit_counting_lambda_anneal_from']
    revisit_counting_lambda_anneal_to = config['general']['revisit_counting_lambda_anneal_to']

    model_checkpoint_path = config['training']['scheduling']['model_checkpoint_path']

    epsilon = epsilon_anneal_from
    revisit_counting_lambda = revisit_counting_lambda_anneal_from

    #######################################################################
    ##### Load the teacher data #####
    #######################################################################
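    # The teacher data maps observation-string identifiers to per-game records whose
    # last element is a dict keyed by teacher actions (see its use for `action_dist` below).
    # NOTE: get_prefix(args) here, and return_att=args.use_attention in the step loop,
    # rely on the module-level `args` parsed in the __main__ block.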
    prefix_name = get_prefix(args)
    filename = './data/teacher_data/{}.npz'.format(prefix_name)
    teacher_dict = np.load(filename, allow_pickle=True)
    # import ipdb; ipdb.set_trace()

    global_action_set = set()

    print("##" * 30)
    print("Training for {} epochs".format(config['training']['scheduling']['epoch']))
    print("##" * 30)

    import time
    t0 = time.time()

    for epoch in range(config['training']['scheduling']['epoch']):
        student_agent.model.train()
        obs, infos = env.reset()
        student_agent.reset(infos)

        # String identifiers used to look up the teacher's episodic action distribution.
        id_string = student_agent.get_observation_strings(infos)
        cont_flag = False
        for id_ in id_string:
            if id_ not in teacher_dict.keys():
                cont_flag = True
        if cont_flag:
            print('Skipping this epoch (missing teacher data) ...')
            continue

        # Episodic action list
        action_dist = [teacher_dict[id_string[k]][-1] for k in range(len(id_string))]
        action_dist = [[x for x in item.keys()] for item in action_dist]
        for item in action_dist:
            global_action_set.update(item)

        print_command_string, print_rewards = [[] for _ in infos], [[] for _ in infos]
        print_interm_rewards = [[] for _ in infos]
        print_rc_rewards = [[] for _ in infos]

        dones = [False] * batch_size
        rewards = None
        avg_loss_in_this_game = []

        curr_observation_strings = student_agent.get_observation_strings(infos)
        if revisit_counting:
            student_agent.reset_binarized_counter(batch_size)
            revisit_counting_rewards = student_agent.get_binarized_count(curr_observation_strings)

        current_game_step = 0
        prev_actions = ["" for _ in range(batch_size)] if provide_prev_action else None

        input_description, description_id_list, student_desc, _ = \
            student_agent.get_game_step_info(obs, infos, prev_actions, prune=prune,
                                             teacher_actions=action_dist, ret_desc=True)
        curr_ras_hidden, curr_ras_cell = None, None  # ras: recurrent action scorer
        memory_cache = [[] for _ in range(batch_size)]
        solved = [0 for _ in range(batch_size)]

        while not all(dones):
            student_agent.model.train()
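            # Commands are generated greedily (epsilon=0.0); the annealed `epsilon`
            # variable above is only reported in the log message.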
            v_idx, n_idx, chosen_strings, curr_ras_hidden, curr_ras_cell = \
                student_agent.generate_one_command(input_description, curr_ras_hidden,
                                                   curr_ras_cell, epsilon=0.0,
                                                   return_att=args.use_attention)
            obs, rewards, dones, infos = env.step(chosen_strings)
            curr_observation_strings = student_agent.get_observation_strings(infos)
            # print(chosen_strings)
            if provide_prev_action:
                prev_actions = chosen_strings

            # counting
            if revisit_counting:
                revisit_counting_rewards = student_agent.get_binarized_count(curr_observation_strings, update=True)
            else:
                revisit_counting_rewards = [0.0 for b in range(batch_size)]
            student_agent.revisit_counting_rewards.append(revisit_counting_rewards)
            revisit_counting_rewards = [float(format(item, ".3f")) for item in revisit_counting_rewards]

            for i in range(len(infos)):
                print_command_string[i].append(chosen_strings[i])
                print_rewards[i].append(rewards[i])
                print_interm_rewards[i].append(infos[i]["intermediate_reward"])
                print_rc_rewards[i].append(revisit_counting_rewards[i])

            if type(dones) is bool:
                dones = [dones] * batch_size
            student_agent.rewards.append(rewards)
            student_agent.dones.append(dones)
            student_agent.intermediate_rewards.append([info["intermediate_reward"] for info in infos])

            # compute rewards, and push into replay memory
            rewards_np, rewards_pt, mask_np, \
                mask_pt, memory_mask = student_agent.compute_reward(revisit_counting_lambda=revisit_counting_lambda,
                                                                    revisit_counting=revisit_counting)

            ###############################
            ##### Pruned state desc #####
            ###############################
            curr_description_id_list = description_id_list
            input_description, description_id_list, student_desc, _ = \
                student_agent.get_game_step_info(obs, infos, prev_actions, prune=prune,
                                                 teacher_actions=action_dist, ret_desc=True)

            for b in range(batch_size):
                if memory_mask[b] == 0:
                    continue
                if dones[b] == 1 and rewards[b] == 0:
                    # last possible step
                    is_final = True
                else:
                    is_final = mask_np[b] == 0
                if rewards[b] > 0.0:
                    solved[b] = 1
                # replay memory
                memory_cache[b].append(
                    (curr_description_id_list[b], v_idx[b], n_idx[b], rewards_pt[b], mask_pt[b], dones[b],
                     is_final, curr_observation_strings[b]))

            if current_game_step > 0 and current_game_step % config["general"]["update_per_k_game_steps"] == 0:
                policy_loss = student_agent.update(replay_batch_size, history_size, update_from,
                                                   discount_gamma=discount_gamma)
                if policy_loss is None:
                    continue
                loss = policy_loss
                # Backpropagate
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
                torch.nn.utils.clip_grad_norm_(student_agent.model.parameters(),
                                               config['training']['optimizer']['clip_grad_norm'])
                optimizer.step()  # apply gradients
                avg_loss_in_this_game.append(to_np(policy_loss))
            current_game_step += 1
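
        # Push the transitions cached during this episode into the replay memory;
        # with prioritized replay, each game's `solved` flag is passed as the priority signal.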
        for i, mc in enumerate(memory_cache):
            for item in mc:
                if replay_memory_priority_fraction == 0.0:
                    # vanilla replay memory
                    student_agent.replay_memory.push(*item)
                else:
                    # prioritized replay memory
                    student_agent.replay_memory.push(solved[i], *item)

        student_agent.finish()
        avg_loss_in_this_game = np.mean(avg_loss_in_this_game)
        reward_avg.add(student_agent.final_rewards.mean())
        step_avg.add(student_agent.step_used_before_done.mean())
        loss_avg.add(avg_loss_in_this_game)

        # annealing
        if epoch < epsilon_anneal_epochs:
            epsilon -= (epsilon_anneal_from - epsilon_anneal_to) / float(epsilon_anneal_epochs)
        if epoch < revisit_counting_lambda_anneal_epochs:
            revisit_counting_lambda -= (revisit_counting_lambda_anneal_from - revisit_counting_lambda_anneal_to) / float(revisit_counting_lambda_anneal_epochs)

        # Tensorboard logging #
        # (1) Log some numbers
        if (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            summary.add_scalar('avg_reward', reward_avg.value, epoch + 1)
            summary.add_scalar('curr_reward', student_agent.final_rewards.mean(), epoch + 1)
            summary.add_scalar('curr_interm_reward', student_agent.final_intermediate_rewards.mean(), epoch + 1)
            summary.add_scalar('curr_counting_reward', student_agent.final_counting_rewards.mean(), epoch + 1)
            summary.add_scalar('avg_step', step_avg.value, epoch + 1)
            summary.add_scalar('curr_step', student_agent.step_used_before_done.mean(), epoch + 1)
            summary.add_scalar('loss_avg', loss_avg.value, epoch + 1)
            summary.add_scalar('curr_loss', avg_loss_in_this_game, epoch + 1)
            t1 = time.time()
            summary.add_scalar('time', t1 - t0, epoch + 1)

        msg = 'E#{:03d}, R={:.3f}/{:.3f}/IR{:.3f}/CR{:.3f}, S={:.3f}/{:.3f}, L={:.3f}/{:.3f}, epsilon={:.4f}, lambda_counting={:.4f}'
        msg = msg.format(epoch,
                         np.mean(reward_avg.value), student_agent.final_rewards.mean(), student_agent.final_intermediate_rewards.mean(), student_agent.final_counting_rewards.mean(),
                         np.mean(step_avg.value), student_agent.step_used_before_done.mean(),
                         np.mean(loss_avg.value), avg_loss_in_this_game,
                         epsilon, revisit_counting_lambda)

        if (epoch + 1) % config["training"]["scheduling"]["logging_frequency"] == 0:
            torch.save(student_agent.model.state_dict(), model_checkpoint_path.replace('.pt', '_train.pt'))
            print("=========================================================")
            for prt_cmd, prt_rew, prt_int_rew, prt_rc_rew in zip(print_command_string, print_rewards, print_interm_rewards, print_rc_rewards):
                print("------------------------------")
                print(prt_cmd)
                print(prt_rew)
                print(prt_int_rew)
                print(prt_rc_rew)
        print(msg)

        # test on a different set of games
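        # Validation uses the union of all teacher actions seen so far (global_action_set)
        # as the teacher_actions passed to test().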
        if run_test and epoch % config["training"]["scheduling"]["logging_frequency"] == 0:
            valid_R, valid_IR, valid_S = test(config, valid_env, student_agent, test_batch_size, word2id, prune=prune,
                                              teacher_actions=[list(global_action_set)] * test_batch_size)
            summary.add_scalar('valid_reward', valid_R, epoch + 1)
            summary.add_scalar('valid_interm_reward', valid_IR, epoch + 1)
            summary.add_scalar('valid_step', valid_S, epoch + 1)

            # save & reload checkpoint by best valid performance
            if valid_R > best_avg_reward or (valid_R == best_avg_reward and valid_S < best_avg_step):
                best_avg_reward = valid_R
                best_avg_step = valid_S
                torch.save(student_agent.model.state_dict(), model_checkpoint_path.replace('.pt', '_best.pt'))
                print("========= saved checkpoint =========")


if __name__ == '__main__':
    for _p in ['saved_models']:
        if not os.path.exists(_p):
            os.mkdir(_p)

    parser = argparse.ArgumentParser(description="train network.")
    parser.add_argument("-c", "--config_dir", default='config', help="the default config directory")
    parser.add_argument("-type", "--type", default=None, help="easy | medium | hard")
    parser.add_argument("-ng", "--num_games", default=None, help="number of games")
    parser.add_argument("-v", "--verbose", help="increase output verbosity", action="store_true")
    parser.add_argument("-vv", "--very-verbose", help="print out warnings", action="store_true")
    parser.add_argument("-fr", "--force-remove", help="remove experiments directory to start new", action="store_true")
    parser.add_argument("-att", "--use_attention", help="Use attention in the encoder model", action="store_true")
    parser.add_argument("-student", "--student", help="Train the student model (default: teacher)", action="store_true")
    parser.add_argument("-th", "--threshold", help="Filter threshold value for cosine similarity", default=0.3, type=float)
    parser.add_argument("-ea", "--exp_act", help="Use expanded vocab list for actions", action="store_true")
    parser.add_argument("-prune", "--prune", help="Use pruning or not", action="store_true")
    parser.add_argument("-emb", "--embed", default='cnet', type=str)  # 'cnet' | 'glove' | 'word2vec' | 'bert'
    args = parser.parse_args()

    config = change_config(args)
    print('Threshold: ', config['bootstrap']['threshold'])
    config['training']['scheduling']['epoch'] = 6000
    config['general']['epsilon_anneal_epochs'] = 3600
    config["training"]["scheduling"]["logging_frequency"] = 50

    train(config=config, prune=args.prune, embed=args.embed)
    # pid = os.getpid()
    # os.system('kill -9 {}'.format(pid))