# model.py

import logging
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from crest.helper.layers import Embedding, masked_mean, LSTMCell, FastUniLSTM

logger = logging.getLogger(__name__)

class LSTM_DQN(torch.nn.Module):
    model_name = 'lstm_dqn'

    def __init__(self, model_config, word_vocab, verb_map, noun_map, enable_cuda=False):
        super(LSTM_DQN, self).__init__()
        self.model_config = model_config
        self.enable_cuda = enable_cuda
        self.word_vocab_size = len(word_vocab)
        self.id2word = word_vocab
        self.n_actions = len(verb_map)
        self.n_objects = len(noun_map)
        self.read_config()
        self._def_layers()
        self.init_weights()
        # self.print_parameters()

    def print_parameters(self):
        amount = 0
        for p in self.parameters():
            amount += np.prod(p.size())
        print("total number of parameters: %s" % (amount))
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        amount = 0
        for p in parameters:
            amount += np.prod(p.size())
        print("number of trainable parameters: %s" % (amount))

    def read_config(self):
        # model config
        config = self.model_config[self.model_name]
        self.embedding_size = config['embedding_size']
        self.encoder_rnn_hidden_size = config['encoder_rnn_hidden_size']
        self.action_scorer_hidden_dim = config['action_scorer_hidden_dim']
        self.dropout_between_rnn_layers = config['dropout_between_rnn_layers']

    def _def_layers(self):
        # word embeddings
        self.word_embedding = Embedding(embedding_size=self.embedding_size,
                                        vocab_size=self.word_vocab_size,
                                        enable_cuda=self.enable_cuda)
        # lstm encoder
        self.encoder = FastUniLSTM(ninp=self.embedding_size, nhids=self.encoder_rnn_hidden_size,
                                   dropout_between_rnn_layers=self.dropout_between_rnn_layers)
        # Recurrent network for temporal dependencies (a.k.a. history).
        self.action_scorer_shared_recurrent = LSTMCell(input_size=self.encoder_rnn_hidden_size[-1],
                                                       hidden_size=self.action_scorer_hidden_dim)
        self.action_scorer_shared = torch.nn.Linear(self.encoder_rnn_hidden_size[-1], self.action_scorer_hidden_dim)
        self.action_scorer_action = torch.nn.Linear(self.action_scorer_hidden_dim, self.n_actions, bias=False)
        self.action_scorer_object = torch.nn.Linear(self.action_scorer_hidden_dim, self.n_objects, bias=False)
        self.fake_recurrent_mask = None

    def init_weights(self):
        torch.nn.init.xavier_uniform_(self.action_scorer_shared.weight.data, gain=1)
        torch.nn.init.xavier_uniform_(self.action_scorer_action.weight.data, gain=1)
        torch.nn.init.xavier_uniform_(self.action_scorer_object.weight.data, gain=1)
        self.action_scorer_shared.bias.data.fill_(0)

    def representation_generator(self, _input_words):
        embeddings, mask = self.word_embedding.forward(_input_words)      # batch x time x emb
        encoding_sequence, _, _ = self.encoder.forward(embeddings, mask)  # batch x time x h
        mean_encoding = masked_mean(encoding_sequence, mask)              # batch x h
        return mean_encoding

    def recurrent_action_scorer(self, state_representation, last_hidden=None, last_cell=None):
        # state representation: batch x input
        # last hidden / last cell: batch x hid
        if self.fake_recurrent_mask is None or self.fake_recurrent_mask.size(0) != state_representation.size(0):
            self.fake_recurrent_mask = torch.autograd.Variable(torch.ones(state_representation.size(0),))
            if self.enable_cuda:
                self.fake_recurrent_mask = self.fake_recurrent_mask.cuda()
        new_h, new_c = self.action_scorer_shared_recurrent.forward(state_representation, self.fake_recurrent_mask,
                                                                   last_hidden, last_cell)
        action_rank = self.action_scorer_action.forward(new_h)  # batch x n_action
        object_rank = self.action_scorer_object.forward(new_h)  # batch x n_object
        return action_rank, object_rank, new_h, new_c

    def action_scorer(self, state_representation):
        hidden = self.action_scorer_shared.forward(state_representation)  # batch x hid
        hidden = F.relu(hidden)                                           # batch x hid
        action_rank = self.action_scorer_action.forward(hidden)           # batch x n_action
        object_rank = self.action_scorer_object.forward(hidden)           # batch x n_object
        return action_rank, object_rank

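# Illustrative usage sketch, not part of the original module: the feed-forward path
# encodes one batch of padded word ids and ranks verbs and objects independently.
# `model` (an LSTM_DQN instance) and `input_word_ids` (a batch x time LongTensor)
# are hypothetical names supplied by the caller.
def _example_greedy_command(model, input_word_ids):
    state = model.representation_generator(input_word_ids)  # batch x h
    verb_q, noun_q = model.action_scorer(state)              # batch x n_actions, batch x n_objects
    return verb_q.argmax(dim=-1), noun_q.argmax(dim=-1)      # greedy verb / object indices
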
class LSTM_DQN_ATT(LSTM_DQN):
    model_name = 'lstm_dqn'

    def __init__(self, *args, **kwargs):
        super(LSTM_DQN_ATT, self).__init__(*args, **kwargs)
        # self.attn = torch.nn.Linear(self.encoder_rnn_hidden_size[0], 1)
        self.attn_inner = torch.nn.Linear(self.encoder_rnn_hidden_size[0], 32)
        self.attn_outer = torch.nn.Linear(32, 1, bias=False)
        if self.enable_cuda:
            self.attn_inner = self.attn_inner.cuda()
            self.attn_outer = self.attn_outer.cuda()

    def representation_generator(self, _input_words, return_att=False, att_mask=None):
        embeddings, mask = self.word_embedding.forward(_input_words)      # batch x time x emb
        encoding_sequence, _, _ = self.encoder.forward(embeddings, mask)  # batch x time x h
        # Buffer for the per-token attention weights. Gradients flow through the values
        # written into it below, so the buffer itself must not be a leaf tensor with
        # requires_grad=True (in-place writes to such a leaf raise a RuntimeError).
        softmax_att = torch.zeros(encoding_sequence.shape[:-1])
        if self.enable_cuda:
            softmax_att = softmax_att.cuda()
        for i in range(len(encoding_sequence)):
            numel = int(torch.sum(mask[i]).item())  # number of non-padding tokens in example i
            logit_attn = self.attn_outer(torch.tanh(self.attn_inner(encoding_sequence[i][:numel])))
            softmax_att[i, :numel] = F.softmax(logit_attn, 0).squeeze(-1)
        if att_mask is not None:
            softmax_att = softmax_att * att_mask
        # Attention-weighted sum over time: (batch x 1 x time) @ (batch x time x h) -> batch x h
        mean_encoding = torch.bmm(softmax_att.unsqueeze(1), encoding_sequence).squeeze(1)
        if return_att:
            return mean_encoding, softmax_att
        else:
            return mean_encoding
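

# Illustrative usage sketch, not part of the original module: the recurrent scorer
# threads its hidden/cell state across game steps, and the attention variant can
# also return its per-token softmax weights for inspection. `model`, `att_model`
# and `episode_word_ids` (a list of padded batch x time LongTensors, one per step)
# are hypothetical names supplied by the caller.
def _example_recurrent_rollout(model, att_model, episode_word_ids):
    last_h, last_c = None, None
    commands = []
    for input_word_ids in episode_word_ids:
        state = model.representation_generator(input_word_ids)  # batch x h
        verb_q, noun_q, last_h, last_c = model.recurrent_action_scorer(state, last_h, last_c)
        commands.append((verb_q.argmax(dim=-1), noun_q.argmax(dim=-1)))
    # The attention variant exposes its weights via return_att=True.
    att_state, att_weights = att_model.representation_generator(episode_word_ids[-1], return_att=True)
    return commands, att_weights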