model.py

import torch
import torch.nn as nn
import torch.nn.init as init


class AttrProxy(object):
    """
    Translates index lookups into attribute lookups.
    A trick that makes it possible to use a list of nn.Module inside an nn.Module,
    see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2
    """
    def __init__(self, module, prefix):
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        return getattr(self.module, self.prefix + str(i))


class Propogator(nn.Module):
    """
    Gated propagator for GGNN.
    Uses a GRU-like gating mechanism (reset and update gates).
    """
    def __init__(self, state_dim, n_node, n_edge_types):
        super(Propogator, self).__init__()

        self.n_node = n_node
        self.n_edge_types = n_edge_types

        self.reset_gate = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim),
            nn.Sigmoid()
        )
        self.update_gate = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim),
            nn.Sigmoid()
        )
        self.transform = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim),
            nn.Tanh()
        )

    def forward(self, state_in, state_out, state_cur, A):
        # A: [batch, n_node, 2 * n_node * n_edge_types]; the first half of the last
        # dimension multiplies the incoming-edge states, the second half the outgoing ones.
        A_in = A[:, :, :self.n_node * self.n_edge_types]
        A_out = A[:, :, self.n_node * self.n_edge_types:]

        # Aggregate messages from neighbours: [batch, n_node, state_dim]
        a_in = torch.bmm(A_in, state_in)
        a_out = torch.bmm(A_out, state_out)
        a = torch.cat((a_in, a_out, state_cur), 2)

        r = self.reset_gate(a)
        z = self.update_gate(a)
        joined_input = torch.cat((a_in, a_out, r * state_cur), 2)
        h_hat = self.transform(joined_input)

        output = (1 - z) * state_cur + z * h_hat
        return output
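

# The slicing in Propogator.forward implies a flattened adjacency layout of shape
# [batch, n_node, 2 * n_node * n_edge_types]. The helper below is a hypothetical
# sketch (not part of the original file) of one common way to build such a matrix
# from an edge list; the project's actual data loader may use a different
# convention for which half encodes incoming vs. outgoing edges.
def build_adjacency_matrix(edges, n_node, n_edge_types):
    """edges: iterable of (src, edge_type, tgt) triples with 0-based indices."""
    a = torch.zeros(n_node, n_node * n_edge_types * 2)
    for src, e_type, tgt in edges:
        # incoming block: target node aggregates in_fc[e_type](h_src)
        a[tgt][e_type * n_node + src] = 1
        # outgoing block: source node aggregates out_fc[e_type](h_tgt)
        a[src][(e_type + n_edge_types) * n_node + tgt] = 1
    return a  # stack per-graph matrices to add the batch dimension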


class GGNN(nn.Module):
    """
    Gated Graph Sequence Neural Networks (GGNN)
    Mode: SelectNode
    Implementation based on https://arxiv.org/abs/1511.05493
    """
    def __init__(self, opt):
        super(GGNN, self).__init__()

        # assert opt.state_dim >= opt.annotation_dim, 'state_dim must be no less than annotation_dim'
        self.is_training_ggnn = opt.is_training_ggnn
        self.state_dim = opt.state_dim
        self.n_edge_types = opt.n_edge_types
        self.n_node = opt.n_node
        self.n_steps = opt.n_steps
        self.n_classes = opt.n_classes

        for i in range(self.n_edge_types):
            # Incoming and outgoing edge embeddings, one pair per edge type
            in_fc = nn.Linear(self.state_dim, self.state_dim)
            out_fc = nn.Linear(self.state_dim, self.state_dim)
            self.add_module("in_{}".format(i), in_fc)
            self.add_module("out_{}".format(i), out_fc)
        self.in_fcs = AttrProxy(self, "in_")
        self.out_fcs = AttrProxy(self, "out_")

        # Propagation model
        self.propogator = Propogator(self.state_dim, self.n_node, self.n_edge_types)

        # Output model (defined but not used in forward below)
        self.out = nn.Sequential(
            nn.Linear(self.state_dim, self.state_dim),
            nn.LeakyReLU(),
            nn.Linear(self.state_dim, 1),
            nn.Tanh(),
        )

        # Soft attention over nodes, used to build the graph-level representation
        self.soft_attention = nn.Sequential(
            nn.Linear(self.state_dim, self.state_dim),
            nn.LeakyReLU(),
            nn.Linear(self.state_dim, 1),
            nn.Sigmoid(),
        )

        self.class_prediction = nn.Sequential(
            nn.Linear(opt.state_dim, opt.n_hidden),
            nn.LeakyReLU(),
            nn.Linear(opt.n_hidden, opt.n_classes),
            # nn.Softmax(dim=1)
        )
        # self.class_prediction = nn.Sequential(
        #     nn.Linear(opt.n_node, opt.n_classes),
        #     nn.Softmax(dim=1)
        # )

        self._initialization()

    def _initialization(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    init.normal_(m.bias.data)

    def forward(self, prop_state, A):
        # prop_state: [batch, n_node, state_dim]
        # A:          [batch, n_node, 2 * n_node * n_edge_types]
        for i_step in range(self.n_steps):
            in_states = []
            out_states = []
            for i in range(self.n_edge_types):
                in_states.append(self.in_fcs[i](prop_state))
                out_states.append(self.out_fcs[i](prop_state))
            # Reshape to [batch, n_node * n_edge_types, state_dim] for the propagator
            in_states = torch.stack(in_states).transpose(0, 1).contiguous()
            in_states = in_states.view(-1, self.n_node * self.n_edge_types, self.state_dim)
            out_states = torch.stack(out_states).transpose(0, 1).contiguous()
            out_states = out_states.view(-1, self.n_node * self.n_edge_types, self.state_dim)

            prop_state = self.propogator(in_states, out_states, prop_state, A)

        soft_attention_output = self.soft_attention(prop_state)

        # Element-wise (Hadamard) product to get the graph representation,
        # see Equation 7 in the GGNN paper for details
        output = torch.mul(prop_state, soft_attention_output)
        output = output.sum(1)  # [batch, state_dim]

        if self.is_training_ggnn:
            output = self.class_prediction(output)
        return output
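

# A minimal usage sketch (not part of the original file). The `opt` fields below are
# assumptions inferred from the attributes this module reads; substitute the option
# parser used by the surrounding project.
if __name__ == "__main__":
    from types import SimpleNamespace

    opt = SimpleNamespace(
        is_training_ggnn=True,
        state_dim=4,
        n_edge_types=2,
        n_node=5,
        n_steps=3,
        n_classes=3,
        n_hidden=16,
    )
    net = GGNN(opt)

    batch_size = 2
    prop_state = torch.zeros(batch_size, opt.n_node, opt.state_dim)
    # Random 0/1 adjacency, just to exercise the expected shapes
    A = torch.randint(0, 2, (batch_size, opt.n_node, 2 * opt.n_node * opt.n_edge_types)).float()

    out = net(prop_state, A)
    print(out.shape)  # [batch_size, n_classes] when is_training_ggnn is True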