# layers.py

import torch
import numpy as np
import torch.nn.functional as F


def masked_mean(x, m=None, dim=-1):
    """
    Mean pooling over time that ignores padded positions.
    input:  x: batch x time x h
            m: batch x time (1 for real tokens, 0 for padding)
    output: batch x h
    """
    if m is None:
        return torch.mean(x, dim=dim)
    mask_sum = torch.sum(m, dim=-1)              # batch
    res = torch.sum(x * m.unsqueeze(-1), dim=1)  # batch x h, padded positions zeroed out
    res = res / (mask_sum.unsqueeze(-1) + 1e-6)
    return res
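
# Usage sketch (illustrative, not part of the original file): pooling a padded
# batch with masked_mean. Shapes follow the docstring above; the tensors are
# made up for the example.
#
#   x = torch.randn(2, 5, 8)                            # batch x time x h
#   m = torch.tensor([[1., 1., 1., 0., 0.],
#                     [1., 1., 1., 1., 1.]])            # batch x time
#   pooled = masked_mean(x, m)                          # batch x h (2 x 8)
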
class LayerNorm(torch.nn.Module):

    def __init__(self, input_dim):
        super(LayerNorm, self).__init__()
        self.gamma = torch.nn.Parameter(torch.ones(input_dim))
        self.beta = torch.nn.Parameter(torch.zeros(input_dim))
        self.eps = 1e-6

    def forward(self, x, mask):
        # x:    nbatch x hidden
        # mask: nbatch
        mean = x.mean(-1, keepdim=True)
        std = torch.sqrt(x.var(-1, keepdim=True) + self.eps)  # same (last) dim as the mean
        output = self.gamma * (x - mean) / (std + self.eps) + self.beta
        # zero out rows that are entirely padding
        return output * mask.unsqueeze(1)
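
# Usage sketch (illustrative, not from the original file): normalizing a batch
# of hidden vectors while zeroing rows whose mask entry is 0.
#
#   ln = LayerNorm(input_dim=8)
#   h = torch.randn(4, 8)                   # nbatch x hidden
#   valid = torch.tensor([1., 1., 0., 1.])  # nbatch; third row is padding
#   normed = ln(h, valid)                   # nbatch x hidden, padded row is all zeros
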
class Embedding(torch.nn.Module):
    '''
    inputs:  x:         batch x seq (x is post-padded by 0s)
    outputs: embedding: batch x seq x emb
             mask:      batch x seq
    '''

    def __init__(self, embedding_size, vocab_size, enable_cuda=False):
        super(Embedding, self).__init__()
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.enable_cuda = enable_cuda
        self.embedding_layer = torch.nn.Embedding(self.vocab_size, self.embedding_size, padding_idx=0)

    def compute_mask(self, x):
        mask = torch.ne(x, 0).float()
        if self.enable_cuda:
            mask = mask.cuda()
        return mask

    def forward(self, x):
        embeddings = self.embedding_layer(x)  # batch x time x emb
        mask = self.compute_mask(x)           # batch x time
        return embeddings, mask
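
# Usage sketch (illustrative): token ids are post-padded with 0, which is also
# the padding_idx of the embedding table, so padded positions embed to zeros.
#
#   emb = Embedding(embedding_size=16, vocab_size=100)
#   ids = torch.tensor([[5, 7, 2, 0, 0],
#                       [3, 1, 4, 9, 6]])  # batch x seq
#   vectors, mask = emb(ids)               # batch x seq x emb, batch x seq
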
class LSTMCell(torch.nn.Module):
    """A basic LSTM cell."""

    def __init__(self, input_size, hidden_size, use_layernorm=False, use_bias=True):
        """
        Most parts are copied from torch.nn.LSTMCell.
        """
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.use_layernorm = use_layernorm
        self.weight_ih = torch.nn.Parameter(torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = torch.nn.Parameter(torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            # separate bias for the forget gate so it can be initialized to 1
            self.bias_f = torch.nn.Parameter(torch.FloatTensor(hidden_size))
            self.bias_iog = torch.nn.Parameter(torch.FloatTensor(3 * hidden_size))
        else:
            self.register_parameter('bias', None)
        if self.use_layernorm:
            self.layernorm_i = LayerNorm(input_dim=self.hidden_size * 4)
            self.layernorm_h = LayerNorm(input_dim=self.hidden_size * 4)
            self.layernorm_c = LayerNorm(input_dim=self.hidden_size)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.orthogonal_(self.weight_hh.data)
        torch.nn.init.xavier_uniform_(self.weight_ih.data, gain=1)
        if self.use_bias:
            self.bias_f.data.fill_(1.0)
            self.bias_iog.data.fill_(0.0)

    def get_init_hidden(self, bsz, use_cuda):
        h_0 = torch.zeros(bsz, self.hidden_size)
        c_0 = torch.zeros(bsz, self.hidden_size)
        if use_cuda:
            h_0, c_0 = h_0.cuda(), c_0.cuda()
        return h_0, c_0

    def forward(self, input_, mask_, h_0=None, c_0=None, dropped_h_0=None):
        # input_: batch x input_size
        # mask_:  batch (1 for real time steps, 0 for padding)
        if h_0 is None or c_0 is None:
            h_init, c_init = self.get_init_hidden(input_.size(0), use_cuda=input_.is_cuda)
            if h_0 is None:
                h_0 = h_init
            if c_0 is None:
                c_0 = c_init
        if dropped_h_0 is None:
            dropped_h_0 = h_0
        wh = torch.mm(dropped_h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        if self.use_layernorm:
            wi = self.layernorm_i(wi, mask_)
            wh = self.layernorm_h(wh, mask_)
        pre_act = wi + wh
        if self.use_bias:
            pre_act = pre_act + torch.cat([self.bias_f, self.bias_iog]).unsqueeze(0)
        f, i, o, g = torch.split(pre_act, split_size_or_sections=self.hidden_size, dim=1)
        expand_mask_ = mask_.unsqueeze(1)  # batch x 1
        c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        # padded time steps carry the previous cell/hidden states through unchanged
        c_1 = c_1 * expand_mask_ + c_0 * (1 - expand_mask_)
        if self.use_layernorm:
            h_1 = torch.sigmoid(o) * torch.tanh(self.layernorm_c(c_1, mask_))
        else:
            h_1 = torch.sigmoid(o) * torch.tanh(c_1)
        h_1 = h_1 * expand_mask_ + h_0 * (1 - expand_mask_)
        return h_1, c_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)
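
# Usage sketch (illustrative): stepping the cell over one time step of a padded
# batch. Rows whose mask entry is 0 keep their previous h_0 / c_0.
#
#   cell = LSTMCell(input_size=16, hidden_size=32)
#   x_t = torch.randn(4, 16)              # batch x input_size
#   m_t = torch.tensor([1., 1., 1., 0.])  # batch; last row is padding
#   h_1, c_1 = cell(x_t, m_t)             # each: batch x hidden_size
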
class FastUniLSTM(torch.nn.Module):

    def __init__(self, ninp, nhids, dropout_between_rnn_layers=0.):
        super(FastUniLSTM, self).__init__()
        self.ninp = ninp
        self.nhids = nhids
        self.nlayers = len(self.nhids)
        self.dropout_between_rnn_layers = dropout_between_rnn_layers
        self.stack_rnns()
        if self.dropout_between_rnn_layers > 0:
            print('##' * 30)
            print('Using Dropout')
            print('##' * 30)
        else:
            print('##' * 30)
            print('Not Using Dropout')
            print('##' * 30)

    def stack_rnns(self):
        rnns = [torch.nn.LSTM(self.ninp if i == 0 else self.nhids[i - 1],
                              self.nhids[i],
                              num_layers=1,
                              bidirectional=False) for i in range(self.nlayers)]
        self.rnns = torch.nn.ModuleList(rnns)

    def forward(self, x, mask):

        def pad_(tensor, n):
            # append n all-zero rows along the batch dimension
            if n > 0:
                zero_pad = torch.zeros((n,) + tensor.size()[1:])
                if x.is_cuda:
                    zero_pad = zero_pad.cuda()
                tensor = torch.cat([tensor, zero_pad])
            return tensor

        # Compute sorted sequence lengths
        batch_size = x.size(0)
        lengths = mask.eq(1).long().sum(1)  # batch
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        lengths = lengths[idx_sort].tolist()

        # Sort x by decreasing length
        x = x.index_select(0, idx_sort)

        # Remove all-zero rows, and remember how many there were
        n_nonzero = np.count_nonzero(lengths)
        n_zero = batch_size - n_nonzero
        if n_zero != 0:
            lengths = lengths[:n_nonzero]
            x = x[:n_nonzero]

        # Transpose batch and sequence dims
        x = x.transpose(0, 1)

        # Pack it up
        rnn_input = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)

        # Encode all layers
        outputs = [rnn_input]
        for i in range(self.nlayers):
            rnn_input = outputs[-1]
            # dropout between rnn layers
            if self.dropout_between_rnn_layers > 0:
                dropout_input = F.dropout(rnn_input.data, p=self.dropout_between_rnn_layers,
                                          training=self.training)
                rnn_input = torch.nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes)
            seq, last = self.rnns[i](rnn_input)
            outputs.append(seq)
            if i == self.nlayers - 1:
                # last layer
                last_state = last[0]        # (num_layers * num_directions, batch, hidden_size)
                last_state = last_state[0]  # batch x hidden_size

        # Unpack everything
        for i, o in enumerate(outputs[1:], 1):
            outputs[i] = torch.nn.utils.rnn.pad_packed_sequence(o)[0]
        output = outputs[-1]

        # Transpose and unsort
        output = output.transpose(0, 1)  # batch x time x enc

        # Re-append the all-zero rows that were removed, then unsort
        output = pad_(output, n_zero)
        last_state = pad_(last_state, n_zero)
        output = output.index_select(0, idx_unsort)
        last_state = last_state.index_select(0, idx_unsort)

        # Pad up to original batch sequence length
        if output.size(1) != mask.size(1):
            padding = torch.zeros(output.size(0),
                                  mask.size(1) - output.size(1),
                                  output.size(2)).type(output.data.type())
            output = torch.cat([output, padding], 1)

        # Zero out padded positions in the output
        output = output.contiguous() * mask.unsqueeze(-1)
        return output, last_state, mask
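
# Usage sketch (illustrative): encoding a padded batch. The module sorts by
# length, packs, runs the stacked LSTMs, then unsorts and re-pads, so callers
# only deal with batch-first padded tensors and a batch x time mask.
#
#   encoder = FastUniLSTM(ninp=16, nhids=[32, 32])
#   x = torch.randn(2, 6, 16)                          # batch x time x ninp
#   mask = torch.tensor([[1., 1., 1., 1., 0., 0.],
#                        [1., 1., 1., 1., 1., 1.]])    # batch x time
#   output, last_state, mask = encoder(x, mask)
#   # output:     batch x time x nhids[-1]
#   # last_state: batch x nhids[-1]
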