# agent.py

import logging
import random
from collections import namedtuple, deque

import numpy as np
import torch
import torch.nn.functional as F

import gym
import gym_textworld  # Register all textworld environments.

from crest.helper.model import LSTM_DQN, LSTM_DQN_ATT
from crest.helper.bootstrap_utils import CREST
from crest.helper.generic import to_np, to_pt, preproc, _words_to_ids, pad_sequences, max_len

logger = logging.getLogger(__name__)

Transition = namedtuple('Transition', ('observation_id_list', 'v_idx', 'n_idx', 'reward',
                                       'mask', 'done', 'is_final', 'observation_str'))


class ReplayMemory(object):

    def __init__(self, capacity=100000):
        # vanilla replay memory
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def get_batch(self, batch_size, history_size):
        if len(self.memory) <= history_size:
            return None
        res = []
        tried_times = 0
        while len(res) < batch_size:
            tried_times += 1
            if tried_times >= 500:
                break
            idx = np.random.randint(history_size - 1, len(self.memory) - 1)
            # only last frame can be (is_final == True)
            if np.any([item.is_final for item in self.memory[idx - (history_size - 1): idx]]):
                continue
            res.append(self.memory[idx - (history_size - 1): idx + 1])
        if len(res) == 0:
            return None
        res = list(map(list, zip(*res)))  # list (history size) of list (batch) of tuples
        return res

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
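

# Minimal usage sketch for ReplayMemory (variable names below are placeholders, not part of
# this module): the memory stores Transition tuples in insertion order, and get_batch()
# samples windows of `history_size` consecutive transitions, rejecting any window in which a
# transition before the last one is marked is_final (i.e. it would cross an episode boundary).
# The result is a list of length history_size, each entry a list of `batch_size` Transitions.
#
#   memory = ReplayMemory(capacity=100000)
#   memory.push(obs_ids, v_idx, n_idx, reward, mask, done, is_final, obs_str)
#   batch = memory.get_batch(batch_size=32, history_size=4)  # None if not enough data yet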


class PrioritizedReplayMemory(object):

    def __init__(self, capacity=100000, priority_fraction=0.0):
        # prioritized replay memory
        self.priority_fraction = priority_fraction
        self.alpha_capacity = int(capacity * priority_fraction)
        self.beta_capacity = capacity - self.alpha_capacity
        self.alpha_memory, self.beta_memory = [], []
        self.alpha_position, self.beta_position = 0, 0

    def push(self, is_prior=False, *args):
        """Saves a transition."""
        if is_prior:
            if len(self.alpha_memory) < self.alpha_capacity:
                self.alpha_memory.append(None)
            self.alpha_memory[self.alpha_position] = Transition(*args)
            self.alpha_position = (self.alpha_position + 1) % self.alpha_capacity
        else:
            if len(self.beta_memory) < self.beta_capacity:
                self.beta_memory.append(None)
            self.beta_memory[self.beta_position] = Transition(*args)
            self.beta_position = (self.beta_position + 1) % self.beta_capacity

    def _get_batch(self, batch_size, history_size, which_memory):
        if len(which_memory) <= history_size:
            return None
        res = []
        tried_times = 0
        while len(res) < batch_size:
            tried_times += 1
            if tried_times >= 500:
                break
            idx = np.random.randint(history_size - 1, len(which_memory) - 1)
            # only last frame can be (is_final == True)
            if np.any([item.is_final for item in which_memory[idx - (history_size - 1): idx]]):
                continue
            res.append(which_memory[idx - (history_size - 1): idx + 1])
        if len(res) == 0:
            return None
        return res

    def get_batch(self, batch_size, history_size):
        from_alpha = min(int(self.priority_fraction * batch_size), len(self.alpha_memory))
        from_beta = min(batch_size - int(self.priority_fraction * batch_size), len(self.beta_memory))
        res = []
        res_alpha = self._get_batch(from_alpha, history_size, self.alpha_memory)
        res_beta = self._get_batch(from_beta, history_size, self.beta_memory)
        if res_alpha is None and res_beta is None:
            return None
        if res_alpha is not None:
            res += res_alpha
        if res_beta is not None:
            res += res_beta
        random.shuffle(res)
        res = list(map(list, zip(*res)))  # list (history size) of list (batch) of tuples
        return res

    def __len__(self):
        return len(self.alpha_memory) + len(self.beta_memory)
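

# Worked example of how the two pools are mixed: push(is_prior=True, ...) routes a transition
# into the alpha (prioritized) buffer and push(is_prior=False, ...) into the beta buffer.
# With priority_fraction=0.25 and batch_size=32, get_batch() draws up to int(0.25 * 32) = 8
# sampled windows from alpha and up to 24 from beta, provided each pool already holds more
# than `history_size` transitions. Which transitions count as prioritized is decided entirely
# by the caller via the is_prior flag.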


class ObservationHistoryCache(object):

    def __init__(self, capacity=1):
        # sliding cache of the most recent observation id lists (one slot per step)
        self.capacity = capacity
        self.memory = []
        self.reset()

    def push(self, stuff):
        """stuff is list."""
        for i in range(1, self.capacity):
            self.memory[i - 1] = self.memory[i]
        self.memory[-1] = stuff

    def get_all(self):
        res = []
        for b in range(len(self.memory[-1])):
            tmp = []
            for i in range(self.capacity):
                if self.memory[i] == []:
                    continue
                tmp += self.memory[i][b]
            res.append(tmp)
        return res

    def reset(self):
        self.memory = []
        for i in range(self.capacity):
            self.memory.append([])

    def __len__(self):
        return len(self.memory)
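

# Minimal sketch of the cache behaviour (values below are placeholders): with capacity=2,
# pushing one list of per-game token ids each step makes get_all() return, for every game in
# the batch, the previous step's ids concatenated with the current step's ids.
#
#   cache = ObservationHistoryCache(capacity=2)
#   cache.push([[1, 2], [3]])     # step t-1, batch of two games
#   cache.push([[4], [5, 6]])     # step t
#   cache.get_all()               # -> [[1, 2, 4], [3, 5, 6]]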


class RLAgent(object):

    def __init__(self, config, word_vocab, verb_map, noun_map, att=False, bootstrap=False,
                 replay_memory_capacity=100000, replay_memory_priority_fraction=0.0,
                 load_pretrained=False, embed='cnet'):
        # print('Creating RL agent...')
        self.use_dropout_exploration = True
        self.config = config
        self.use_cuda = config['general']['use_cuda']
        self.word_vocab = word_vocab
        self.verb_map = verb_map
        self.noun_map = noun_map
        self.word2id = {}
        self.att = att
        for i, w in enumerate(word_vocab):
            self.word2id[w] = i
        self.bootstrap = bootstrap
        if bootstrap:
            print('##' * 30)
            print('Using Bootstrapping...')
            print('##' * 30)
            self.bs_thres = config['bootstrap']['threshold']
            self.bs_obj = CREST(threshold=self.bs_thres, embeddings=embed)
        if att:
            print('##' * 30)
            print('Using attention...')
            print('##' * 30)
            self.model = LSTM_DQN_ATT(model_config=config["model"], word_vocab=self.word_vocab,
                                      verb_map=verb_map, noun_map=noun_map, enable_cuda=self.use_cuda)
        else:
            print('##' * 30)
            print('NOT using attention...')
            print('##' * 30)
            self.model = LSTM_DQN(model_config=config["model"], word_vocab=self.word_vocab,
                                  verb_map=verb_map, noun_map=noun_map, enable_cuda=self.use_cuda)
        self.action_scorer_hidden_dim = config['model']['lstm_dqn']['action_scorer_hidden_dim']
        if load_pretrained:
            self.load_pretrained_model(config["model"]['global']['pretrained_model_save_path'])
        if self.use_cuda:
            self.model.cuda()
        if replay_memory_priority_fraction > 0.0:
            self.replay_memory = PrioritizedReplayMemory(replay_memory_capacity,
                                                         priority_fraction=replay_memory_priority_fraction)
        else:
            self.replay_memory = ReplayMemory(replay_memory_capacity)
        self.observation_cache_capacity = config['general']['observation_cache_capacity']
        self.observation_cache = ObservationHistoryCache(self.observation_cache_capacity)

    def load_pretrained_model(self, load_from):
        # load model, if there is any
        print("------------------------------------loading best model------------------------------\n")
        try:
            save_f = open(load_from, 'rb')
            self.model = torch.load(save_f)
        except:
            print("failed...")

    def reset(self, infos):
        self.rewards = []
        self.dones = []
        self.intermediate_rewards = []
        self.revisit_counting_rewards = []
        self.observation_cache.reset()

    def get_chosen_strings(self, v_idx, n_idx):
        v_idx_np = to_np(v_idx)
        n_idx_np = to_np(n_idx)
        res_str = []
        for i in range(n_idx_np.shape[0]):
            v, n = self.verb_map[v_idx_np[i]], self.noun_map[n_idx_np[i]]
            res_str.append(self.word_vocab[v] + " " + self.word_vocab[n])
        return res_str

    def choose_random_command(self, verb_rank, noun_rank):
        batch_size = verb_rank.size(0)
        vr, nr = to_np(verb_rank), to_np(noun_rank)
        v_idx, n_idx = [], []
        for i in range(batch_size):
            v_idx.append(np.random.choice(len(vr[i]), 1)[0])
            n_idx.append(np.random.choice(len(nr[i]), 1)[0])
        v_qvalue, n_qvalue = [], []
        for i in range(batch_size):
            v_qvalue.append(verb_rank[i][v_idx[i]])
            n_qvalue.append(noun_rank[i][n_idx[i]])
        v_qvalue, n_qvalue = torch.stack(v_qvalue), torch.stack(n_qvalue)
        v_idx, n_idx = to_pt(np.array(v_idx), self.use_cuda), to_pt(np.array(n_idx), self.use_cuda)
        return v_qvalue, v_idx, n_qvalue, n_idx

    def choose_maxQ_command(self, verb_rank, noun_rank):
        batch_size = verb_rank.size(0)
        vr, nr = to_np(verb_rank), to_np(noun_rank)
        v_idx = np.argmax(vr, -1)
        n_idx = np.argmax(nr, -1)
        v_qvalue, n_qvalue = [], []
        for i in range(batch_size):
            v_qvalue.append(verb_rank[i][v_idx[i]])
            n_qvalue.append(noun_rank[i][n_idx[i]])
        v_qvalue, n_qvalue = torch.stack(v_qvalue), torch.stack(n_qvalue)
        v_idx, n_idx = to_pt(v_idx, self.use_cuda), to_pt(n_idx, self.use_cuda)
        return v_qvalue, v_idx, n_qvalue, n_idx

    def get_ranks(self, input_description, prev_hidden=None, prev_cell=None, return_att=False, att_mask=None):
        if return_att:
            state_representation, softmax_att = self.model.representation_generator(
                input_description, return_att=True, att_mask=att_mask)
            self.softmax_att = softmax_att
        else:
            state_representation = self.model.representation_generator(input_description)
        verb_rank, noun_rank, curr_hidden, curr_cell = self.model.recurrent_action_scorer(
            state_representation, prev_hidden, prev_cell)
        self.verb_rank = verb_rank
        self.noun_rank = noun_rank
        return verb_rank, noun_rank, curr_hidden, curr_cell

    def get_qvalues_att(self, input_description, prev_hidden=None, prev_cell=None, T=0.1):
        assert self.att, "Attention module must be turned on"
        state_representation, softmax_att = self.model.representation_generator(input_description, return_att=True)
        self.softmax_att = softmax_att
        verb_rank, noun_rank, curr_hidden, curr_cell = \
            self.model.recurrent_action_scorer(state_representation, prev_hidden, prev_cell)
        # temperature-scaled softmax over the verb/noun Q values (dim=-1 matches the implicit
        # dim used for these 2-D batch x n tensors)
        verb_softmax = F.softmax(verb_rank / T, dim=-1)
        noun_softmax = F.softmax(noun_rank / T, dim=-1)
        return verb_softmax, noun_softmax, curr_hidden, curr_cell

    def get_softmax_attention(self):
        return self.softmax_att

    def get_qvalues(self):
        return self.verb_rank, self.noun_rank

    def get_similarity_scores(self, obs, infos, prev_actions=None, prune=False,
                              ret_desc=False, teacher_actions=None):
        # concat d/i/q/f together as one string
        info = infos[0]
        inventory_strings, inv_dict = self.bs_obj.prune_state(info["inventory"], teacher_actions[0],
                                                              add_prefix=False, return_details=True)
        desc_strings, desc_disc = self.bs_obj.prune_state(info["description"], teacher_actions[0],
                                                          add_prefix=False, return_details=True)
        obj_strings, obj_disc = self.bs_obj.prune_state(info["objective"], teacher_actions[0],
                                                        add_prefix=False, return_details=True)
        return info["description"], desc_disc

    def generate_one_command(self, input_description, prev_hidden=None,
                             prev_cell=None, epsilon=0.2, return_att=False, att_mask=None):
        verb_rank, noun_rank, curr_hidden, curr_cell = \
            self.get_ranks(input_description, prev_hidden,
                           prev_cell, return_att=return_att, att_mask=att_mask)  # batch x n_verb, batch x n_noun
        curr_hidden = curr_hidden.detach()
        curr_cell = curr_cell.detach()
        v_qvalue_maxq, v_idx_maxq, n_qvalue_maxq, n_idx_maxq = self.choose_maxQ_command(verb_rank, noun_rank)
        v_qvalue_random, v_idx_random, n_qvalue_random, n_idx_random = self.choose_random_command(verb_rank, noun_rank)
        # random number for epsilon greedy
        rand_num = np.random.uniform(low=0.0, high=1.0, size=(input_description.size(0),))
        less_than_epsilon = (rand_num < epsilon).astype("float32")  # batch
        greater_than_epsilon = 1.0 - less_than_epsilon
        less_than_epsilon = to_pt(less_than_epsilon, self.use_cuda, type='float')
        greater_than_epsilon = to_pt(greater_than_epsilon, self.use_cuda, type='float')
        less_than_epsilon, greater_than_epsilon = less_than_epsilon.long(), greater_than_epsilon.long()
        v_idx = less_than_epsilon * v_idx_random + greater_than_epsilon * v_idx_maxq
        n_idx = less_than_epsilon * n_idx_random + greater_than_epsilon * n_idx_maxq
        v_idx, n_idx = v_idx.detach(), n_idx.detach()
        chosen_strings = self.get_chosen_strings(v_idx, n_idx)
        return v_idx, n_idx, chosen_strings, curr_hidden, curr_cell
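
    # Epsilon-greedy selection above is applied per batch element: each game draws its own
    # uniform random number, and the 0/1 masks (less_than_epsilon / greater_than_epsilon)
    # pick either the randomly sampled verb/noun indices or the argmax ones. With the default
    # epsilon=0.2, roughly one game in five acts randomly on a given step.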

    def get_game_step_info(self, obs, infos, prev_actions=None, prune=False,
                           ret_desc=False, teacher_actions=None):
        # concat d/i/q/f together as one string
        if prune:
            inventory_strings = [self.bs_obj.prune_state(info["inventory"], teacher_actions[k], add_prefix=False)
                                 for k, info in enumerate(infos)]
        else:
            inventory_strings = [info["inventory"] for info in infos]
        inventory_token_list = [preproc(item, str_type='inventory', lower_case=True) for item in inventory_strings]
        inventory_id_list = [_words_to_ids(tokens, self.word2id) for tokens in inventory_token_list]
        if prune:
            feedback_strings = [self.bs_obj.prune_state(info["command_feedback"], teacher_actions[k], add_prefix=False)
                                for k, info in enumerate(infos)]
        else:
            feedback_strings = [info["command_feedback"] for info in infos]
        feedback_token_list = [preproc(item, str_type='feedback', lower_case=True) for item in feedback_strings]
        feedback_id_list = [_words_to_ids(tokens, self.word2id) for tokens in feedback_token_list]
        orig_quest_string = [info["objective"] for info in infos]
        if prune:
            quest_strings = [self.bs_obj.prune_state(info["objective"], teacher_actions[k], add_prefix=False)
                             for k, info in enumerate(infos)]
        else:
            quest_strings = [info["objective"] for info in infos]
        quest_token_list = [preproc(item, str_type='None', lower_case=True) for item in quest_strings]
        quest_id_list = [_words_to_ids(tokens, self.word2id) for tokens in quest_token_list]
        if prev_actions is not None:
            prev_action_token_list = [preproc(item, str_type='None', lower_case=True) for item in prev_actions]
            prev_action_id_list = [_words_to_ids(tokens, self.word2id) for tokens in prev_action_token_list]
        else:
            prev_action_token_list = [[] for _ in infos]
            prev_action_id_list = [[] for _ in infos]
        if prune:
            description_strings = [self.bs_obj.prune_state(info["description"], teacher_actions[k])
                                   for k, info in enumerate(infos)]
        else:
            description_strings = [info["description"] for info in infos]
        description_token_list = [preproc(item, str_type='description', lower_case=True) for item in description_strings]
        for i, d in enumerate(description_token_list):
            if len(d) == 0:
                description_token_list[i] = ["end"]  # hack here, if empty description, insert word "end"
        description_id_list = [_words_to_ids(tokens, self.word2id) for tokens in description_token_list]
        description_id_list = [_d + _i + _q + _f + _pa for (_d, _i, _q, _f, _pa) in
                               zip(description_id_list, inventory_id_list, quest_id_list,
                                   feedback_id_list, prev_action_id_list)]
        description_str_list = [_d + _i + _q + _f + _pa for (_d, _i, _q, _f, _pa) in
                                zip(description_token_list, inventory_token_list,
                                    quest_token_list, feedback_token_list, prev_action_token_list)]
        self.observation_cache.push(description_id_list)
        description_with_history_id_list = self.observation_cache.get_all()
        input_description = pad_sequences(description_with_history_id_list,
                                          maxlen=max_len(description_with_history_id_list),
                                          padding='post').astype('int32')
        input_description = to_pt(input_description, self.use_cuda)
        if ret_desc:
            return input_description, description_with_history_id_list, description_str_list, orig_quest_string
        else:
            return input_description, description_with_history_id_list
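
    # The method above builds the model input per game as the concatenation
    #   description ids + inventory ids + objective ids + feedback ids + previous-action ids,
    # optionally pruned by the CREST bootstrap object, then prepends the cached observations
    # from earlier steps and pads them ('post' padding) into a single batch x max_len int32
    # tensor.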

    def get_observation_strings(self, infos):
        # concat game_id_d/i/q together as one string
        game_file_names = [info["game_file"] for info in infos]
        inventory_strings = [info["inventory"] for info in infos]
        description_strings = [info["description"] for info in infos]
        observation_strings = [_n + _d + _i for (_n, _d, _i) in zip(game_file_names, description_strings, inventory_strings)]
        return observation_strings

    def compute_reward(self, revisit_counting_lambda=0.0, revisit_counting=True):
        if len(self.dones) == 1:
            mask = [1.0 for _ in self.dones[-1]]
        else:
            assert len(self.dones) > 1
            mask = [1.0 if not self.dones[-2][i] else 0.0 for i in range(len(self.dones[-1]))]
        mask = np.array(mask, dtype='float32')
        mask_pt = to_pt(mask, self.use_cuda, type='float')
        # self.rewards: list of list, max_game_length x batch_size
        rewards = np.array(self.rewards[-1], dtype='float32')  # batch
        if revisit_counting:
            # rewards += np.array(self.intermediate_rewards[-1], dtype='float32')
            if len(self.revisit_counting_rewards) > 0:
                rewards = rewards + np.array(self.revisit_counting_rewards[-1], dtype='float32') * revisit_counting_lambda
        rewards_pt = to_pt(rewards, self.use_cuda, type='float')
        # memory mask: play one more step after done
        if len(self.dones) < 3:
            memory_mask = [1.0 for _ in self.dones[-1]]
        else:
            memory_mask = [1.0 if mask[i] == 1 or ((not self.dones[-3][i]) and self.dones[-2][i])
                           else 0.0 for i in range(len(self.dones[-1]))]
        return rewards, rewards_pt, mask, mask_pt, memory_mask

    def update(self, replay_batch_size, history_size, update_from=0, discount_gamma=0.0):
        if len(self.replay_memory) < replay_batch_size:
            return None
        transitions = self.replay_memory.get_batch(replay_batch_size, history_size + 1)  # list (history_size + 1) of list (batch) of tuples
        # the extra last transition is only used to compute the final (bootstrap) Q value
        if transitions is None:
            return None
        sequences = [Transition(*zip(*batch)) for batch in transitions]
        losses = []
        prev_ras_hidden, prev_ras_cell = None, None  # ras: recurrent action scorer
        observation_id_list = pad_sequences(sequences[0].observation_id_list,
                                            maxlen=max_len(sequences[0].observation_id_list),
                                            padding='post').astype('int32')
        input_observation = to_pt(observation_id_list, self.use_cuda)
        v_idx = torch.stack(sequences[0].v_idx, 0)  # batch x 1
        n_idx = torch.stack(sequences[0].n_idx, 0)  # batch x 1
        verb_rank, noun_rank, curr_ras_hidden, curr_ras_cell = self.get_ranks(input_observation, prev_ras_hidden, prev_ras_cell)
        v_qvalue, n_qvalue = verb_rank.gather(1, v_idx.unsqueeze(-1)).squeeze(-1), \
            noun_rank.gather(1, n_idx.unsqueeze(-1)).squeeze(-1)  # batch
        prev_qvalue = torch.mean(torch.stack([v_qvalue, n_qvalue], -1), -1)  # batch
        if update_from > 0:
            prev_qvalue, curr_ras_hidden, curr_ras_cell = prev_qvalue.detach(), curr_ras_hidden.detach(), curr_ras_cell.detach()
        for i in range(1, len(sequences)):
            observation_id_list = pad_sequences(sequences[i].observation_id_list,
                                                maxlen=max_len(sequences[i].observation_id_list),
                                                padding='post').astype('int32')
            input_observation = to_pt(observation_id_list, self.use_cuda)
            v_idx = torch.stack(sequences[i].v_idx, 0)  # batch x 1
            n_idx = torch.stack(sequences[i].n_idx, 0)  # batch x 1
            verb_rank, noun_rank, curr_ras_hidden, curr_ras_cell = self.get_ranks(input_observation,
                                                                                  curr_ras_hidden,
                                                                                  curr_ras_cell)
            v_qvalue_max, _, n_qvalue_max, _ = self.choose_maxQ_command(verb_rank, noun_rank)
            q_value_max = torch.mean(torch.stack([v_qvalue_max, n_qvalue_max], -1), -1)  # batch
            q_value_max = q_value_max.detach()
            v_qvalue, n_qvalue = verb_rank.gather(1, v_idx.unsqueeze(-1)).squeeze(-1), \
                noun_rank.gather(1, n_idx.unsqueeze(-1)).squeeze(-1)  # batch
            q_value = torch.mean(torch.stack([v_qvalue, n_qvalue], -1), -1)  # batch
            if i < update_from or i == len(sequences) - 1:
                q_value, curr_ras_hidden, curr_ras_cell = q_value.detach(), curr_ras_hidden.detach(), \
                    curr_ras_cell.detach()
            if i > update_from:
                prev_rewards = torch.stack(sequences[i - 1].reward)  # batch
                prev_not_done = 1.0 - np.array(sequences[i - 1].done, dtype='float32')  # batch
                prev_not_done = to_pt(prev_not_done, self.use_cuda, type='float')
                prev_rewards = prev_rewards + prev_not_done * q_value_max * discount_gamma  # batch
                prev_mask = torch.stack(sequences[i - 1].mask)  # batch
                prev_loss = F.smooth_l1_loss(prev_qvalue * prev_mask, prev_rewards * prev_mask)  # huber_loss
                losses.append(prev_loss)
            prev_qvalue = q_value
        return torch.stack(losses).mean()
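
    # In the loop above, each step regresses the averaged verb/noun Q value Q(s_t, a_t)
    # toward the one-step target
    #     r_t + (1 - done_t) * discount_gamma * max_a Q(s_{t+1}, a),
    # using the Huber (smooth L1) loss, with `mask` zeroing out games that had already
    # finished before this step.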

    def finish(self):
        self.final_rewards = np.array(self.rewards[-1], dtype='float32')  # batch
        self.final_counting_rewards = np.sum(np.array(self.revisit_counting_rewards), 0)  # batch
        dones = []
        for d in self.dones:
            d = np.array([float(dd) for dd in d], dtype='float32')
            dones.append(d)
        dones = np.array(dones)
        step_used = 1.0 - dones
        self.step_used_before_done = np.sum(step_used, 0)  # batch
        self.final_intermediate_rewards = []
        intermediate_rewards = np.array(self.intermediate_rewards)  # step x batch
        intermediate_rewards = np.transpose(intermediate_rewards, (1, 0))  # batch x step
        for i in range(intermediate_rewards.shape[0]):
            self.final_intermediate_rewards.append(np.sum(intermediate_rewards[i][:int(self.step_used_before_done[i]) + 1]))
        self.final_intermediate_rewards = np.array(self.final_intermediate_rewards)

    def reset_binarized_counter(self, batch_size):
        self.binarized_counter_dict = [{} for _ in range(batch_size)]

    def get_binarized_count(self, observation_strings, update=True):
        batch_size = len(observation_strings)
        count_rewards = []
        for i in range(batch_size):
            concat_string = observation_strings[i]
            if concat_string not in self.binarized_counter_dict[i]:
                self.binarized_counter_dict[i][concat_string] = 0.0
            if update:
                self.binarized_counter_dict[i][concat_string] += 1.0
            r = self.binarized_counter_dict[i][concat_string]
            r = float(r == 1.0)
            count_rewards.append(r)
        return count_rewards

    def state_dict(self):
        # NOTE: self.optimizer is not created in __init__; it is expected to be attached to
        # the agent externally (e.g. by the training loop) before calling state_dict().
        return {'model': self.model.state_dict(), 'optimizer': self.optimizer.state_dict()}

    def load_state_dict(self, state):
        self.model.load_state_dict(state['model'])
        self.optimizer.load_state_dict(state['optimizer'])
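

# Minimal usage sketch, assuming an already-loaded `config`, vocabularies, and a gym-textworld
# environment providing `obs` and `infos` (all names below are placeholders, not part of this
# module's API):
#
#   agent = RLAgent(config, word_vocab, verb_map, noun_map)
#   agent.reset(infos)
#   input_description, desc_id_list = agent.get_game_step_info(obs, infos)
#   v_idx, n_idx, commands, hidden, cell = agent.generate_one_command(input_description, epsilon=0.2)
#   # ... step the environment with `commands`, push transitions into agent.replay_memory ...
#   loss = agent.update(replay_batch_size=32, history_size=1, discount_gamma=0.5)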