generic.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import os
  2. import numpy as np
  3. from os.path import join as pjoin
  4. from nltk.tokenize import word_tokenize as wt
  5. import torch
  6. from torch.autograd import Variable
  7. from textworld.utils import maybe_mkdir
  8. import math
  9. class SlidingAverage(object):
  10. def __init__(self, name, steps=100):
  11. self.name = name
  12. self.steps = steps
  13. self.t = 0
  14. self.ns = []
  15. self.avgs = []
  16. def add(self, n):
  17. if math.isnan(n):
  18. n = 0
  19. self.ns.append(n)
  20. if len(self.ns) > self.steps:
  21. self.ns.pop(0)
  22. self.t += 1
  23. if self.t % self.steps == 0:
  24. self.avgs.append(self.value)
  25. @property
  26. def value(self):
  27. if len(self.ns) == 0: return 0
  28. return sum(self.ns) / (len(self.ns) + 0.0000000001)
  29. @property
  30. def std(self):
  31. if len(self.ns) == 0: return 0
  32. std = np.std(self.ns)
  33. return std if not np.isnan(std) else 0
  34. def __str__(self):
  35. return "%s=%.4f" % (self.name, self.value)
  36. def __gt__(self, value): return self.value > value
  37. def __lt__(self, value): return self.value < value
  38. def state_dict(self):
  39. return {'t': self.t,
  40. 'ns': tuple(self.ns),
  41. 'avgs': tuple(self.avgs)}
  42. def load_state_dict(self, state):
  43. self.t = state["t"]
  44. self.ns = list(state["ns"])
  45. self.avgs = list(state["avgs"])
  46. def to_np(x):
  47. if isinstance(x, np.ndarray):
  48. return x
  49. return x.data.cpu().numpy()
  50. def to_pt(np_matrix, enable_cuda=False, type='long'):
  51. if type == 'long':
  52. if enable_cuda:
  53. return torch.autograd.Variable(torch.from_numpy(np_matrix).type(torch.LongTensor).cuda())
  54. else:
  55. return torch.autograd.Variable(torch.from_numpy(np_matrix).type(torch.LongTensor))
  56. elif type == 'float':
  57. if enable_cuda:
  58. return torch.autograd.Variable(torch.from_numpy(np_matrix).type(torch.FloatTensor).cuda())
  59. else:
  60. return torch.autograd.Variable(torch.from_numpy(np_matrix).type(torch.FloatTensor))
  61. def get_experiment_dir(config, makedir=True):
  62. env_id = config['general']['env_id']
  63. exps_dir = config['general']['experiments_dir']
  64. exp_tag = config['general']['experiment_tag']
  65. exp_dir = pjoin(exps_dir, env_id + "_" + exp_tag)
  66. if makedir:
  67. return maybe_mkdir(exp_dir)
  68. else:
  69. return exp_dir
  70. def dict2list(id2w_dict):
  71. res = []
  72. for item in id2w_dict:
  73. res.append(id2w_dict[item])
  74. return res
  75. def _words_to_ids(words, word2id):
  76. ids = []
  77. for word in words:
  78. try:
  79. ids.append(word2id[word])
  80. except KeyError:
  81. ids.append(1)
  82. return ids
  83. def preproc(s, str_type='None', lower_case=False):
  84. s = s.replace("\n", ' ')
  85. if s.strip() == "":
  86. return ["nothing"]
  87. if str_type == 'description':
  88. s = s.split("=-")[1]
  89. elif str_type == 'inventory':
  90. s = s.split("carrying")[1]
  91. if s[0] == ':':
  92. s = s[1:]
  93. elif str_type == 'feedback':
  94. if "Welcome to Textworld" in s:
  95. s = s.split("Welcome to Textworld")[1]
  96. if "-=" in s:
  97. s = s.split("-=")[0]
  98. s = s.strip()
  99. if len(s) == 0:
  100. return ["nothing"]
  101. tokens = wt(s)
  102. if lower_case:
  103. tokens = [t.lower() for t in tokens]
  104. return tokens
  105. def max_len(list_of_list):
  106. return max(map(len, list_of_list))
  107. def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
  108. '''
  109. FROM KERAS
  110. Pads each sequence to the same length:
  111. the length of the longest sequence.
  112. If maxlen is provided, any sequence longer
  113. than maxlen is truncated to maxlen.
  114. Truncation happens off either the beginning (default) or
  115. the end of the sequence.
  116. Supports post-padding and pre-padding (default).
  117. # Arguments
  118. sequences: list of lists where each element is a sequence
  119. maxlen: int, maximum length
  120. dtype: type to cast the resulting sequence.
  121. padding: 'pre' or 'post', pad either before or after each sequence.
  122. truncating: 'pre' or 'post', remove values from sequences larger than
  123. maxlen either in the beginning or in the end of the sequence
  124. value: float, value to pad the sequences to the desired value.
  125. # Returns
  126. x: numpy array with dimensions (number_of_sequences, maxlen)
  127. '''
  128. lengths = [len(s) for s in sequences]
  129. nb_samples = len(sequences)
  130. if maxlen is None:
  131. maxlen = np.max(lengths)
  132. # take the sample shape from the first non empty sequence
  133. # checking for consistency in the main loop below.
  134. sample_shape = tuple()
  135. for s in sequences:
  136. if len(s) > 0:
  137. sample_shape = np.asarray(s).shape[1:]
  138. break
  139. x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype)
  140. for idx, s in enumerate(sequences):
  141. if len(s) == 0:
  142. continue # empty list was found
  143. if truncating == 'pre':
  144. trunc = s[-maxlen:]
  145. elif truncating == 'post':
  146. trunc = s[:maxlen]
  147. else:
  148. raise ValueError('Truncating type "%s" not understood' % truncating)
  149. # check `trunc` has expected shape
  150. trunc = np.asarray(trunc, dtype=dtype)
  151. if trunc.shape[1:] != sample_shape:
  152. raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
  153. (trunc.shape[1:], idx, sample_shape))
  154. if padding == 'post':
  155. x[idx, :len(trunc)] = trunc
  156. elif padding == 'pre':
  157. x[idx, -len(trunc):] = trunc
  158. else:
  159. raise ValueError('Padding type "%s" not understood' % padding)
  160. return x