import torch
import numpy as np
import torch.nn.functional as F
def masked_mean(x, m=None, dim=-1):
    """
    Mean pooling over time when there are paddings.
    input:  x: batch x time x h
            m: batch x time
    output: tensor: batch x h
    """
    if m is None:
        return torch.mean(x, dim=dim)
    mask_sum = torch.sum(m, dim=-1)  # batch
    # Zero out padded positions before summing so they do not contribute to the mean.
    res = torch.sum(x * m.unsqueeze(-1), dim=1)  # batch x h
    res = res / (mask_sum.unsqueeze(-1) + 1e-6)
    return res
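
# Illustrative usage sketch (not part of the original file): pool a padded batch
# with masked_mean. Shapes follow the docstring above; the helper name below is
# hypothetical and added for demonstration only.
def _masked_mean_example():
    x = torch.randn(2, 4, 8)                      # batch x time x h
    m = torch.tensor([[1., 1., 0., 0.],
                      [1., 1., 1., 1.]])          # batch x time, 0 = padding
    pooled = masked_mean(x, m)                    # batch x h
    assert pooled.size() == (2, 8)
    return pooled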
class LayerNorm(torch.nn.Module):
    def __init__(self, input_dim):
        super(LayerNorm, self).__init__()
        self.gamma = torch.nn.Parameter(torch.ones(input_dim))
        self.beta = torch.nn.Parameter(torch.zeros(input_dim))
        self.eps = 1e-6

    def forward(self, x, mask):
        # x: batch x hidden
        # mask: batch
        mean = x.mean(-1, keepdim=True)
        std = torch.sqrt(x.var(-1, keepdim=True) + self.eps)
        output = self.gamma * (x - mean) / (std + self.eps) + self.beta
        # Zero out rows that correspond to padded positions.
        return output * mask.unsqueeze(1)
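
# Illustrative usage sketch (not part of the original file): normalize a batch of
# hidden vectors, zeroing rows whose mask entry is 0. The helper name and sizes
# are hypothetical, added for demonstration only.
def _layernorm_example():
    ln = LayerNorm(input_dim=8)
    x = torch.randn(4, 8)                         # batch x hidden
    mask = torch.tensor([1., 1., 1., 0.])         # batch; last row is padding
    out = ln(x, mask)                             # batch x hidden, last row all zeros
    return out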
class Embedding(torch.nn.Module):
    '''
    inputs:  x: batch x seq (x is post-padded with 0s)
    outputs: embedding: batch x seq x emb
             mask:      batch x seq
    '''
    def __init__(self, embedding_size, vocab_size, enable_cuda=False):
        super(Embedding, self).__init__()
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.enable_cuda = enable_cuda
        self.embedding_layer = torch.nn.Embedding(self.vocab_size, self.embedding_size,
                                                  padding_idx=0)

    def compute_mask(self, x):
        # Non-zero token ids are real tokens; zeros are padding.
        mask = torch.ne(x, 0).float()
        if self.enable_cuda:
            mask = mask.cuda()
        return mask

    def forward(self, x):
        embeddings = self.embedding_layer(x)  # batch x time x emb
        mask = self.compute_mask(x)           # batch x time
        return embeddings, mask
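
# Illustrative usage sketch (not part of the original file): embed a 0-padded
# batch of token ids and recover the padding mask. The helper name, vocabulary
# size, and ids are made up for demonstration only.
def _embedding_example():
    emb = Embedding(embedding_size=16, vocab_size=100)
    x = torch.tensor([[5, 8, 2, 0, 0],
                      [7, 1, 3, 4, 9]])           # batch x seq, 0 = padding
    embeddings, mask = emb(x)                     # batch x seq x emb, batch x seq
    assert embeddings.size() == (2, 5, 16) and mask.size() == (2, 5)
    return embeddings, mask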
class LSTMCell(torch.nn.Module):
    """A basic LSTM cell with optional layer normalization and masking."""

    def __init__(self, input_size, hidden_size, use_layernorm=False, use_bias=True):
        """
        Most parts are copied from torch.nn.LSTMCell.
        """
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.use_layernorm = use_layernorm
        self.weight_ih = torch.nn.Parameter(torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = torch.nn.Parameter(torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias_f = torch.nn.Parameter(torch.FloatTensor(hidden_size))
            self.bias_iog = torch.nn.Parameter(torch.FloatTensor(3 * hidden_size))
        else:
            self.register_parameter('bias_f', None)
            self.register_parameter('bias_iog', None)
        if self.use_layernorm:
            self.layernorm_i = LayerNorm(input_dim=self.hidden_size * 4)
            self.layernorm_h = LayerNorm(input_dim=self.hidden_size * 4)
            self.layernorm_c = LayerNorm(input_dim=self.hidden_size)
        self.reset_parameters()

    def reset_parameters(self):
        # Orthogonal recurrent weights, Xavier input weights;
        # the forget-gate bias starts at 1 so the cell remembers by default.
        torch.nn.init.orthogonal_(self.weight_hh.data)
        torch.nn.init.xavier_uniform_(self.weight_ih.data, gain=1)
        if self.use_bias:
            self.bias_f.data.fill_(1.0)
            self.bias_iog.data.fill_(0.0)

    def get_init_hidden(self, bsz, use_cuda):
        h_0 = torch.zeros(bsz, self.hidden_size)
        c_0 = torch.zeros(bsz, self.hidden_size)
        if use_cuda:
            h_0, c_0 = h_0.cuda(), c_0.cuda()
        return h_0, c_0

    def forward(self, input_, mask_, h_0=None, c_0=None, dropped_h_0=None):
        # input_: batch x input_size, mask_: batch
        if h_0 is None or c_0 is None:
            h_init, c_init = self.get_init_hidden(input_.size(0), use_cuda=input_.is_cuda)
            if h_0 is None:
                h_0 = h_init
            if c_0 is None:
                c_0 = c_init
        if dropped_h_0 is None:
            dropped_h_0 = h_0
        wh = torch.mm(dropped_h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        if self.use_layernorm:
            wi = self.layernorm_i(wi, mask_)
            wh = self.layernorm_h(wh, mask_)
        pre_act = wi + wh
        if self.use_bias:
            pre_act = pre_act + torch.cat([self.bias_f, self.bias_iog]).unsqueeze(0)
        f, i, o, g = torch.split(pre_act, split_size_or_sections=self.hidden_size, dim=1)
        expand_mask_ = mask_.unsqueeze(1)  # batch x 1
        c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        # Keep the previous state at padded positions.
        c_1 = c_1 * expand_mask_ + c_0 * (1 - expand_mask_)
        if self.use_layernorm:
            h_1 = torch.sigmoid(o) * torch.tanh(self.layernorm_c(c_1, mask_))
        else:
            h_1 = torch.sigmoid(o) * torch.tanh(c_1)
        h_1 = h_1 * expand_mask_ + h_0 * (1 - expand_mask_)
        return h_1, c_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)
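
# Illustrative usage sketch (not part of the original file): run one masked
# LSTMCell step per time slice of a post-padded batch. The helper name and
# tensor sizes are made up for demonstration only.
def _lstm_cell_example():
    cell = LSTMCell(input_size=16, hidden_size=32)
    inputs = torch.randn(5, 2, 16)                # time x batch x input_size
    mask = torch.tensor([[1., 1.],
                         [1., 1.],
                         [1., 0.],
                         [1., 0.],
                         [0., 0.]])               # time x batch, 0 = padding
    h, c = None, None
    for t in range(inputs.size(0)):
        h, c = cell(inputs[t], mask[t], h_0=h, c_0=c)
    return h                                      # batch x hidden_size; padded steps keep the old state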
class FastUniLSTM(torch.nn.Module):
    """
    Stack of unidirectional single-layer LSTMs run over packed, length-sorted batches.
    """
    def __init__(self, ninp, nhids, dropout_between_rnn_layers=0.):
        super(FastUniLSTM, self).__init__()
        self.ninp = ninp
        self.nhids = nhids
        self.nlayers = len(self.nhids)
        self.dropout_between_rnn_layers = dropout_between_rnn_layers
        self.stack_rnns()
        if self.dropout_between_rnn_layers > 0:
            print('##' * 30)
            print('Using Dropout')
            print('##' * 30)
        else:
            print('##' * 30)
            print('Not Using Dropout')
            print('##' * 30)

    def stack_rnns(self):
        rnns = [torch.nn.LSTM(self.ninp if i == 0 else self.nhids[i - 1],
                              self.nhids[i],
                              num_layers=1,
                              bidirectional=False) for i in range(self.nlayers)]
        self.rnns = torch.nn.ModuleList(rnns)
    def forward(self, x, mask):
        # x: batch x time x ninp, mask: batch x time

        def pad_(tensor, n):
            # Append n all-zero rows along the batch dimension.
            if n > 0:
                zero_pad = torch.zeros((n,) + tensor.size()[1:])
                if x.is_cuda:
                    zero_pad = zero_pad.cuda()
                tensor = torch.cat([tensor, zero_pad])
            return tensor

        # Compute sorted sequence lengths
        batch_size = x.size(0)
        lengths = mask.eq(1).long().sum(1)  # batch
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        lengths = lengths[idx_sort].tolist()
        # Sort x by descending length
        x = x.index_select(0, idx_sort)
        # Remove all-padding rows, and remember how many there were
        n_nonzero = np.count_nonzero(lengths)
        n_zero = batch_size - n_nonzero
        if n_zero != 0:
            lengths = lengths[:n_nonzero]
            x = x[:n_nonzero]
        # Transpose batch and sequence dims
        x = x.transpose(0, 1)
        # Pack it up
        rnn_input = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
        # Encode all layers
        outputs = [rnn_input]
        for i in range(self.nlayers):
            rnn_input = outputs[-1]
            # Dropout between rnn layers
            if self.dropout_between_rnn_layers > 0:
                dropout_input = F.dropout(rnn_input.data,
                                          p=self.dropout_between_rnn_layers,
                                          training=self.training)
                rnn_input = torch.nn.utils.rnn.PackedSequence(dropout_input,
                                                              rnn_input.batch_sizes)
            seq, last = self.rnns[i](rnn_input)
            outputs.append(seq)
            if i == self.nlayers - 1:
                # Last layer: keep its final hidden state
                last_state = last[0]  # (num_layers * num_directions) x batch x hidden_size
                last_state = last_state[0]  # batch x hidden_size
        # Unpack everything
        for i, o in enumerate(outputs[1:], 1):
            outputs[i] = torch.nn.utils.rnn.pad_packed_sequence(o)[0]
        output = outputs[-1]
        # Transpose back and unsort
        output = output.transpose(0, 1)  # batch x time x enc
        # Re-append the all-padding rows that were dropped earlier
        output = pad_(output, n_zero)
        last_state = pad_(last_state, n_zero)
        output = output.index_select(0, idx_unsort)
        last_state = last_state.index_select(0, idx_unsort)
        # Pad up to the original batch sequence length
        if output.size(1) != mask.size(1):
            padding = torch.zeros(output.size(0),
                                  mask.size(1) - output.size(1),
                                  output.size(2)).type(output.data.type())
            output = torch.cat([output, padding], 1)
        # Zero out outputs at padded positions
        output = output.contiguous() * mask.unsqueeze(-1)
        return output, last_state, mask
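
# Illustrative end-to-end sketch (not part of the original file): embed a padded
# batch, encode it with FastUniLSTM, then mean-pool with masked_mean. The helper
# name and all sizes are made up for demonstration only.
def _encoder_example():
    emb = Embedding(embedding_size=16, vocab_size=100)
    encoder = FastUniLSTM(ninp=16, nhids=[32, 32])
    x = torch.tensor([[5, 8, 2, 0, 0],
                      [7, 1, 3, 4, 9]])           # batch x seq, 0 = padding
    embeddings, mask = emb(x)                     # batch x seq x emb, batch x seq
    output, last_state, mask = encoder(embeddings, mask)  # batch x seq x 32, batch x 32
    pooled = masked_mean(output, mask)            # batch x 32
    return pooled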