"""Define GGNN-based encoders.""" import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from onmt.encoders.encoder import EncoderBase class GGNNAttrProxy(object): """ Translates index lookups into attribute lookups. To implement some trick which able to use list of nn.Module in a nn.Module see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2 """ def __init__(self, module, prefix): self.module = module self.prefix = prefix def __getitem__(self, i): return getattr(self.module, self.prefix + str(i)) class GGNNPropogator(nn.Module): """ Gated Propogator for GGNN Using LSTM gating mechanism """ def __init__(self, state_dim, n_node, n_edge_types): super(GGNNPropogator, self).__init__() self.n_node = n_node self.n_edge_types = n_edge_types self.reset_gate = nn.Sequential( nn.Linear(state_dim * 3, state_dim), nn.Sigmoid() ) self.update_gate = nn.Sequential( nn.Linear(state_dim * 3, state_dim), nn.Sigmoid() ) self.tansform = nn.Sequential( nn.Linear(state_dim * 3, state_dim), nn.LeakyReLU() ) def forward(self, state_in, state_out, state_cur, edges, nodes): edges_in = edges[:, :, : nodes * self.n_edge_types] edges_out = edges[:, :, nodes * self.n_edge_types :] a_in = torch.bmm(edges_in, state_in) a_out = torch.bmm(edges_out, state_out) a = torch.cat((a_in, a_out, state_cur), 2) r = self.reset_gate(a) z = self.update_gate(a) joined_input = torch.cat((a_in, a_out, r * state_cur), 2) h_hat = self.tansform(joined_input) prop_out = (1 - z) * state_cur + z * h_hat return prop_out class GGNNEncoder(EncoderBase): """A gated graph neural network configured as an encoder. Based on github.com/JamesChuanggg/ggnn.pytorch.git, which is based on the paper "Gated Graph Sequence Neural Networks" by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel. Args: rnn_type (str): style of recurrent unit to use, one of [LSTM] src_ggnn_size (int) : Size of token-to-node embedding input src_word_vec_size (int) : Size of token-to-node embedding output state_dim (int) : Number of state dimensions in nodes n_edge_types (int) : Number of edge types bidir_edges (bool): True if reverse edges should be autocreated n_node (int) : Max nodes in graph bridge_extra_node (bool): True indicates only 1st extra node (after token listing) should be used for decoder init. n_steps (int): Steps to advance graph encoder for stabilization src_vocab (int): Path to source vocabulary.(The ggnn uses src_vocab during training because the graph is built using edge information which requires parsing the input sequence.) """ def __init__( self, rnn_type, src_word_vec_size, src_ggnn_size, state_dim, bidir_edges, n_edge_types, n_node, bridge_extra_node, n_steps, src_vocab, ): super(GGNNEncoder, self).__init__() self.src_word_vec_size = src_word_vec_size self.src_ggnn_size = src_ggnn_size self.state_dim = state_dim self.n_edge_types = n_edge_types self.n_node = n_node self.n_steps = n_steps self.bidir_edges = bidir_edges self.bridge_extra_node = bridge_extra_node for i in range(self.n_edge_types): # incoming and outgoing edge embedding in_fc = nn.Linear(self.state_dim, self.state_dim) out_fc = nn.Linear(self.state_dim, self.state_dim) self.add_module("in_{}".format(i), in_fc) self.add_module("out_{}".format(i), out_fc) self.in_fcs = GGNNAttrProxy(self, "in_") self.out_fcs = GGNNAttrProxy(self, "out_") # Find vocab data for tree builting f = open(src_vocab, "r") idx = 0 self.COMMA = -1 self.DELIMITER = -1 self.idx2num = [] found_n_minus_one = False for ln in f: ln = ln.strip("\n") ln = ln.split("\t")[0] if idx == 0 and ln != "": idx += 1 self.idx2num.append(-1) if idx == 1 and ln != "": idx += 1 self.idx2num.append(-1) if ln == ",": self.COMMA = idx if ln == "": self.DELIMITER = idx if ln.isdigit(): self.idx2num.append(int(ln)) if int(ln) == n_node - 1: found_n_minus_one = True else: self.idx2num.append(-1) idx += 1 assert self.COMMA >= 0, "GGNN src_vocab must include ',' character" assert self.DELIMITER >= 0, "GGNN src_vocab must include token" assert ( found_n_minus_one ), "GGNN src_vocab must include node numbers for edge connections" # Propogation Model self.propogator = GGNNPropogator(self.state_dim, self.n_node, self.n_edge_types) self._initialization() # Initialize the bridge layer self._initialize_bridge(rnn_type, self.state_dim, 1) # Token embedding if src_ggnn_size > 0: self.embed = nn.Sequential( nn.Linear(src_ggnn_size, src_word_vec_size), nn.LeakyReLU() ) assert ( self.src_ggnn_size >= self.DELIMITER ), "Embedding input must be larger than vocabulary" assert ( self.src_word_vec_size < self.state_dim ), "Embedding size must be smaller than state_dim" else: assert ( self.DELIMITER < self.state_dim ), "Vocabulary too large, consider -src_ggnn_size" @classmethod def from_opt(cls, opt, embeddings): """Alternate constructor.""" return cls( opt.rnn_type, opt.src_word_vec_size, opt.src_ggnn_size, opt.state_dim, opt.bidir_edges, opt.n_edge_types, opt.n_node, opt.bridge_extra_node, opt.n_steps, opt.src_vocab, ) def _initialization(self): for m in self.modules(): if isinstance(m, nn.Linear): m.weight.data.normal_(0.0, 0.02) m.bias.data.fill_(0) def forward(self, src, src_len=None): """See :func:`EncoderBase.forward()`""" nodes = self.n_node batch_size = src.size()[0] first_extra = np.zeros(batch_size, dtype=np.int32) token_onehot = np.zeros( ( batch_size, nodes, self.src_ggnn_size if self.src_ggnn_size > 0 else self.state_dim, ), dtype=np.int32, ) edges = np.zeros( (batch_size, nodes, nodes * self.n_edge_types * 2), dtype=np.int32 ) npsrc = src[:, :, 0].cpu().data.numpy().astype(np.int32) # Initialize graph using formatted input sequence for i in range(batch_size): tokens_done = False # Number of flagged nodes defines node count for this sample # (Nodes can have no flags on them, but must be in 'flags' list). flag_node = 0 flags_done = False edge = 0 source_node = -1 for j in range(len(npsrc)): token = npsrc[i][j] if not tokens_done: if token == self.DELIMITER: tokens_done = True first_extra[i] = j else: token_onehot[i][j][token] = 1 elif token == self.DELIMITER: flag_node += 1 flags_done = True assert flag_node <= nodes, "Too many nodes with flags" elif not flags_done: # The total number of integers in the vocab should allow # for all features and edges to be defined. if token == self.COMMA: flag_node = 0 else: num = self.idx2num[token] if num >= 0: token_onehot[i][flag_node][num + self.DELIMITER] = 1 flag_node += 1 elif token == self.COMMA: edge += 1 assert ( source_node == -1 ), f"Error in graph edge input: {source_node} unpaired" assert edge < self.n_edge_types, "Too many edge types in input" else: num = self.idx2num[token] if source_node < 0: source_node = num else: edges[i][source_node][num + nodes * edge] = 1 if self.bidir_edges: edges[i][num][ nodes * (edge + self.n_edge_types) + source_node ] = 1 source_node = -1 token_onehot = torch.from_numpy(token_onehot).float().to(src.device) if self.src_ggnn_size > 0: token_embed = self.embed(token_onehot) prop_state = torch.cat( ( token_embed, torch.zeros( (batch_size, nodes, self.state_dim - self.src_word_vec_size) ) .float() .to(src.device), ), 2, ) else: prop_state = token_onehot edges = torch.from_numpy(edges).float().to(src.device) for i_step in range(self.n_steps): in_states = [] out_states = [] for i in range(self.n_edge_types): in_states.append(self.in_fcs[i](prop_state)) out_states.append(self.out_fcs[i](prop_state)) in_states = torch.stack(in_states).transpose(0, 1).contiguous() in_states = in_states.view(-1, nodes * self.n_edge_types, self.state_dim) out_states = torch.stack(out_states).transpose(0, 1).contiguous() out_states = out_states.view(-1, nodes * self.n_edge_types, self.state_dim) prop_state = self.propogator( in_states, out_states, prop_state, edges, nodes ) if self.bridge_extra_node: # Use first extra node as only source for decoder init join_state = prop_state[first_extra, torch.arange(batch_size)] else: # Average all nodes to get bridge input join_state = prop_state.mean(0) join_state = torch.stack((join_state, join_state, join_state, join_state)) join_state = (join_state, join_state) enc_final_hs = self._bridge(join_state) return prop_state, enc_final_hs, src_len def _initialize_bridge(self, rnn_type, hidden_size, num_layers): # LSTM has hidden and cell state, other only one number_of_states = 2 if rnn_type == "LSTM" else 1 # Total number of states self.total_hidden_dim = hidden_size * num_layers # Build a linear layer for each self.bridge = nn.ModuleList( [ nn.Linear(self.total_hidden_dim, self.total_hidden_dim, bias=True) for _ in range(number_of_states) ] ) def _bridge(self, hidden): """Forward hidden state through bridge.""" def bottle_hidden(linear, states): """ Transform from 3D to 2D, apply linear and return initial size """ size = states.size() result = linear(states.view(-1, self.total_hidden_dim)) return F.leaky_relu(result).view(size) if isinstance(hidden, tuple): # LSTM outs = tuple( [ bottle_hidden(layer, hidden[ix]) for ix, layer in enumerate(self.bridge) ] ) else: outs = bottle_hidden(self.bridge[0], hidden) return outs