import torch
import torch.nn as nn
import torch.nn.init as init


class AttrProxy(object):
    """
    Translates index lookups into attribute lookups.

    A trick that makes it possible to keep a list of nn.Module inside an nn.Module;
    see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2
    """
    def __init__(self, module, prefix):
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        return getattr(self.module, self.prefix + str(i))
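
# Note: nn.ModuleList offers the same indexable-submodule behaviour
# (e.g. nn.ModuleList([nn.Linear(d, d) for _ in range(n)])); AttrProxy
# achieves it manually through add_module() and getattr().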


class Propagator(nn.Module):
    """
    Gated propagator for the GGNN, using a GRU-style gating mechanism.
    """
    def __init__(self, state_dim, n_node, n_edge_types):
        super(Propagator, self).__init__()

        self.n_node = n_node
        self.n_edge_types = n_edge_types

        self.reset_gate = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim),
            nn.Sigmoid()
        )
        self.update_gate = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim),
            nn.Sigmoid()
        )
        self.transform = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim),
            nn.Tanh()
        )

    def forward(self, state_in, state_out, state_cur, A):
        # Split the adjacency matrix into its incoming and outgoing halves.
        A_in = A[:, :, :self.n_node * self.n_edge_types]
        A_out = A[:, :, self.n_node * self.n_edge_types:]

        # Aggregate messages over incoming and outgoing edges.
        a_in = torch.bmm(A_in, state_in)
        a_out = torch.bmm(A_out, state_out)
        a = torch.cat((a_in, a_out, state_cur), 2)

        # GRU-style gated update of the node states.
        r = self.reset_gate(a)
        z = self.update_gate(a)
        joined_input = torch.cat((a_in, a_out, r * state_cur), 2)
        h_hat = self.transform(joined_input)

        output = (1 - z) * state_cur + z * h_hat
        return output
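
# Shape reference for Propagator.forward:
#   state_in, state_out : (batch, n_node * n_edge_types, state_dim)
#   state_cur           : (batch, n_node, state_dim)
#   A                   : (batch, n_node, n_node * n_edge_types * 2), incoming
#                         adjacency in the first half of the last dimension,
#                         outgoing adjacency in the second half
#   returns             : (batch, n_node, state_dim)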


class GGNN(nn.Module):
    """
    Gated Graph Sequence Neural Networks (GGNN)
    Mode: SelectNode
    Implementation based on https://arxiv.org/abs/1511.05493
    """
    def __init__(self, opt):
        super(GGNN, self).__init__()

        # assert opt.state_dim >= opt.annotation_dim, 'state_dim must be no less than annotation_dim'

        self.is_training_ggnn = opt.is_training_ggnn
        self.state_dim = opt.state_dim
        self.n_edge_types = opt.n_edge_types
        self.n_node = opt.n_node
        self.n_steps = opt.n_steps
        self.n_classes = opt.n_classes

        # One incoming and one outgoing edge embedding per edge type.
        for i in range(self.n_edge_types):
            in_fc = nn.Linear(self.state_dim, self.state_dim)
            out_fc = nn.Linear(self.state_dim, self.state_dim)
            self.add_module("in_{}".format(i), in_fc)
            self.add_module("out_{}".format(i), out_fc)
        self.in_fcs = AttrProxy(self, "in_")
        self.out_fcs = AttrProxy(self, "out_")

        # Propagation model
        self.propagator = Propagator(self.state_dim, self.n_node, self.n_edge_types)

        # Output model
        self.out = nn.Sequential(
            nn.Linear(self.state_dim, self.state_dim),
            nn.LeakyReLU(),
            nn.Linear(self.state_dim, 1),
            nn.Tanh(),
        )

        self.soft_attention = nn.Sequential(
            nn.Linear(self.state_dim, self.state_dim),
            nn.LeakyReLU(),
            nn.Linear(self.state_dim, 1),
            nn.Sigmoid(),
        )

        self.class_prediction = nn.Sequential(
            nn.Linear(opt.state_dim, opt.n_hidden),
            nn.LeakyReLU(),
            nn.Linear(opt.n_hidden, opt.n_classes),
            # nn.Softmax(dim=1)
        )
        # Alternative classifier head (unused):
        # self.class_prediction = nn.Sequential(
        #     nn.Linear(opt.n_node, opt.n_classes),
        #     nn.Softmax(dim=1)
        # )

        self._initialization()

    def _initialization(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    init.normal_(m.bias.data)

    def forward(self, prop_state, A):
        # Run n_steps rounds of message passing over the graph.
        for i_step in range(self.n_steps):
            in_states = []
            out_states = []
            for i in range(self.n_edge_types):
                in_states.append(self.in_fcs[i](prop_state))
                out_states.append(self.out_fcs[i](prop_state))
            in_states = torch.stack(in_states).transpose(0, 1).contiguous()
            in_states = in_states.view(-1, self.n_node * self.n_edge_types, self.state_dim)
            out_states = torch.stack(out_states).transpose(0, 1).contiguous()
            out_states = out_states.view(-1, self.n_node * self.n_edge_types, self.state_dim)

            prop_state = self.propagator(in_states, out_states, prop_state, A)

        # output = self.out(prop_state)

        # Soft attention over nodes: the element-wise (Hadamard) product with the
        # node states, summed over nodes, gives the graph representation
        # (see Equation 7 in the GGNN paper).
        soft_attention_output = self.soft_attention(prop_state)
        output = torch.mul(prop_state, soft_attention_output)
        output = output.sum(1)

        if self.is_training_ggnn:
            output = self.class_prediction(output)
        return output
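

if __name__ == "__main__":
    # Minimal usage sketch. The hyper-parameter values below are illustrative
    # assumptions only; the model just requires that `opt` exposes the
    # attributes read in GGNN.__init__.
    from types import SimpleNamespace

    opt = SimpleNamespace(
        state_dim=4,        # per-node hidden state size
        annotation_dim=1,   # size of the initial node annotation
        n_edge_types=2,
        n_node=8,
        n_steps=5,
        n_classes=3,
        n_hidden=16,
        is_training_ggnn=True,
    )
    net = GGNN(opt)

    batch_size = 2
    # Initial node states: annotations zero-padded up to state_dim.
    annotation = torch.rand(batch_size, opt.n_node, opt.annotation_dim)
    padding = torch.zeros(batch_size, opt.n_node, opt.state_dim - opt.annotation_dim)
    prop_state = torch.cat((annotation, padding), 2)

    # Adjacency matrix: incoming blocks in the first half of the last dimension,
    # outgoing blocks in the second half (all zeros here, i.e. no edges).
    A = torch.zeros(batch_size, opt.n_node, opt.n_node * opt.n_edge_types * 2)

    logits = net(prop_state, A)
    print(logits.shape)  # torch.Size([2, 3]) -> (batch_size, n_classes)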