"""Define GGNN-based encoders."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from onmt.encoders.encoder import EncoderBase


class GGNNAttrProxy(object):
    """
    Translates index lookups into attribute lookups.
    Implements the trick that makes it possible to use a list of nn.Module
    objects inside an nn.Module; see
    https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2
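
    Example (an illustrative sketch; ``Net`` and the ``fc_`` prefix are
    hypothetical names, not part of this module)::

        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                for i in range(3):
                    self.add_module("fc_{}".format(i), nn.Linear(4, 4))
                # self.fcs[i] now resolves to the submodule named "fc_i"
                self.fcs = GGNNAttrProxy(self, "fc_")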
    """

    def __init__(self, module, prefix):
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        return getattr(self.module, self.prefix + str(i))


class GGNNPropogator(nn.Module):
    """
    Gated propagator for GGNN.
    Uses a GRU-style gating mechanism (reset and update gates).
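
    The update mirrors ``forward`` below; ``W_r``, ``W_z`` and ``W_h`` denote
    the reset gate, update gate and transform layers defined in ``__init__``,
    and ``[a; b]`` denotes concatenation::

        r     = sigmoid(W_r [a_in; a_out; h])
        z     = sigmoid(W_z [a_in; a_out; h])
        h_hat = leaky_relu(W_h [a_in; a_out; r * h])
        h'    = (1 - z) * h + z * h_hat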
    """

    def __init__(self, state_dim, n_node, n_edge_types):
        super(GGNNPropogator, self).__init__()

        self.n_node = n_node
        self.n_edge_types = n_edge_types

        self.reset_gate = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim), nn.Sigmoid()
        )
        self.update_gate = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim), nn.Sigmoid()
        )
        self.transform = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim), nn.LeakyReLU()
        )

    def forward(self, state_in, state_out, state_cur, edges, nodes):
        # Split the adjacency matrix into its incoming and outgoing halves.
        edges_in = edges[:, :, : nodes * self.n_edge_types]
        edges_out = edges[:, :, nodes * self.n_edge_types :]

        # Aggregate messages along incoming and outgoing edges.
        a_in = torch.bmm(edges_in, state_in)
        a_out = torch.bmm(edges_out, state_out)
        a = torch.cat((a_in, a_out, state_cur), 2)

        r = self.reset_gate(a)
        z = self.update_gate(a)
        joined_input = torch.cat((a_in, a_out, r * state_cur), 2)
        h_hat = self.transform(joined_input)

        # GRU-style interpolation between the current state and the candidate.
        prop_out = (1 - z) * state_cur + z * h_hat

        return prop_out


class GGNNEncoder(EncoderBase):
    """A gated graph neural network configured as an encoder.
    Based on github.com/JamesChuanggg/ggnn.pytorch.git,
    which is based on the paper "Gated Graph Sequence Neural Networks"
    by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.

    Args:
        rnn_type (str):
            style of recurrent unit to use, one of [LSTM]
        src_ggnn_size (int) : Size of token-to-node embedding input
        src_word_vec_size (int) : Size of token-to-node embedding output
        state_dim (int) : Number of state dimensions in nodes
        n_edge_types (int) : Number of edge types
        bidir_edges (bool) : True if reverse edges should be automatically
            created
        n_node (int) : Max nodes in graph
        bridge_extra_node (bool) : True indicates only the first extra node
            (after the token listing) should be used for decoder init
        n_steps (int) : Number of steps to advance the graph encoder for
            stabilization
        src_vocab (str) : Path to source vocabulary. (The GGNN uses src_vocab
            during training because the graph is built using edge information,
            which requires parsing the input sequence.)
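
    Example (hypothetical option values; ``srcvocab.txt`` stands in for a
    vocabulary file containing ``,``, ``<EOT>`` and the node numbers)::

        encoder = GGNNEncoder(
            rnn_type="LSTM", src_word_vec_size=64, src_ggnn_size=0,
            state_dim=256, bidir_edges=True, n_edge_types=2, n_node=70,
            bridge_extra_node=True, n_steps=5, src_vocab="srcvocab.txt",
        )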
    """

    def __init__(
        self,
        rnn_type,
        src_word_vec_size,
        src_ggnn_size,
        state_dim,
        bidir_edges,
        n_edge_types,
        n_node,
        bridge_extra_node,
        n_steps,
        src_vocab,
    ):
        super(GGNNEncoder, self).__init__()

        self.src_word_vec_size = src_word_vec_size
        self.src_ggnn_size = src_ggnn_size
        self.state_dim = state_dim
        self.n_edge_types = n_edge_types
        self.n_node = n_node
        self.n_steps = n_steps
        self.bidir_edges = bidir_edges
        self.bridge_extra_node = bridge_extra_node

        # One incoming and one outgoing transformation layer per edge type.
        for i in range(self.n_edge_types):
            in_fc = nn.Linear(self.state_dim, self.state_dim)
            out_fc = nn.Linear(self.state_dim, self.state_dim)
            self.add_module("in_{}".format(i), in_fc)
            self.add_module("out_{}".format(i), out_fc)

        self.in_fcs = GGNNAttrProxy(self, "in_")
        self.out_fcs = GGNNAttrProxy(self, "out_")

        # Scan the source vocabulary to locate the ',' and <EOT> indices and
        # to map vocabulary indices to the node numbers used for edges.
        f = open(src_vocab, "r")
        idx = 0
        self.COMMA = -1
        self.DELIMITER = -1
        self.idx2num = []
        found_n_minus_one = False
        for ln in f:
            ln = ln.strip("\n")
            ln = ln.split("\t")[0]
            if idx == 0 and ln != "<unk>":
                idx += 1
                self.idx2num.append(-1)
            if idx == 1 and ln != "<blank>":
                idx += 1
                self.idx2num.append(-1)
            if ln == ",":
                self.COMMA = idx
            if ln == "<EOT>":
                self.DELIMITER = idx
            if ln.isdigit():
                self.idx2num.append(int(ln))
                if int(ln) == n_node - 1:
                    found_n_minus_one = True
            else:
                self.idx2num.append(-1)
            idx += 1
        f.close()

        assert self.COMMA >= 0, "GGNN src_vocab must include ',' character"
        assert self.DELIMITER >= 0, "GGNN src_vocab must include <EOT> token"
        assert (
            found_n_minus_one
        ), "GGNN src_vocab must include node numbers for edge connections"

        # Propagation model over the graph.
        self.propogator = GGNNPropogator(
            self.state_dim, self.n_node, self.n_edge_types
        )

        # Initialize linear layer weights.
        self._initialization()

        # Initialize the bridge layer that produces the decoder init state.
        self._initialize_bridge(rnn_type, self.state_dim, 1)

        # Optional token embedding: reduce the one-hot token input of width
        # src_ggnn_size to a dense src_word_vec_size vector before it is
        # placed in the node state.
        if src_ggnn_size > 0:
            self.embed = nn.Sequential(
                nn.Linear(src_ggnn_size, src_word_vec_size), nn.LeakyReLU()
            )
            assert (
                self.src_ggnn_size >= self.DELIMITER
            ), "Embedding input must be larger than vocabulary"
            assert (
                self.src_word_vec_size < self.state_dim
            ), "Embedding size must be smaller than state_dim"
        else:
            assert (
                self.DELIMITER < self.state_dim
            ), "Vocabulary too large, consider -src_ggnn_size"

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.rnn_type,
            opt.src_word_vec_size,
            opt.src_ggnn_size,
            opt.state_dim,
            opt.bidir_edges,
            opt.n_edge_types,
            opt.n_node,
            opt.bridge_extra_node,
            opt.n_steps,
            opt.src_vocab,
        )

    def _initialization(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 0.02)
                m.bias.data.fill_(0)

    def forward(self, src, src_len=None):
        """See :func:`EncoderBase.forward()`"""

        nodes = self.n_node
        batch_size = src.size()[0]
        first_extra = np.zeros(batch_size, dtype=np.int32)
        token_onehot = np.zeros(
            (
                batch_size,
                nodes,
                self.src_ggnn_size if self.src_ggnn_size > 0 else self.state_dim,
            ),
            dtype=np.int32,
        )
        # Adjacency matrix: the first nodes * n_edge_types columns are used as
        # the incoming half and the rest as the outgoing half in
        # GGNNPropogator.forward().
        edges = np.zeros(
            (batch_size, nodes, nodes * self.n_edge_types * 2), dtype=np.int32
        )
        npsrc = src[:, :, 0].cpu().data.numpy().astype(np.int32)
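
        # Parse each example. The source sequence is expected to contain the
        # graph tokens up to the first <EOT>, then an optional comma-separated
        # node-flag section terminated by another <EOT>, then the edge list:
        # node-number pairs define edges and a ',' advances to the next edge
        # type (see the loop below for the exact handling).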
        for i in range(batch_size):
            tokens_done = False
            # Number of flagged nodes defined so far in the flags section.
            flag_node = 0
            flags_done = False
            edge = 0
            source_node = -1
            for j in range(len(npsrc[i])):
                token = npsrc[i][j]
                if not tokens_done:
                    if token == self.DELIMITER:
                        tokens_done = True
                        first_extra[i] = j
                    else:
                        token_onehot[i][j][token] = 1
                elif token == self.DELIMITER:
                    flag_node += 1
                    flags_done = True
                    assert flag_node <= nodes, "Too many nodes with flags"
                elif not flags_done:
                    # In the flag section a ',' restarts the node counter and
                    # a numeric token sets an extra one-hot feature on a node.
                    if token == self.COMMA:
                        flag_node = 0
                    else:
                        num = self.idx2num[token]
                        if num >= 0:
                            token_onehot[i][flag_node][num + self.DELIMITER] = 1
                        flag_node += 1
                elif token == self.COMMA:
                    # A ',' in the edge section advances to the next edge type.
                    edge += 1
                    assert (
                        source_node == -1
                    ), f"Error in graph edge input: {source_node} unpaired"
                    assert edge < self.n_edge_types, "Too many edge types in input"
                else:
                    # Edges are given as (source, target) node-number pairs.
                    num = self.idx2num[token]
                    if source_node < 0:
                        source_node = num
                    else:
                        edges[i][source_node][num + nodes * edge] = 1
                        if self.bidir_edges:
                            edges[i][num][
                                nodes * (edge + self.n_edge_types) + source_node
                            ] = 1
                        source_node = -1

        token_onehot = torch.from_numpy(token_onehot).float().to(src.device)
        if self.src_ggnn_size > 0:
            token_embed = self.embed(token_onehot)
            prop_state = torch.cat(
                (
                    token_embed,
                    torch.zeros(
                        (batch_size, nodes, self.state_dim - self.src_word_vec_size)
                    )
                    .float()
                    .to(src.device),
                ),
                2,
            )
        else:
            prop_state = token_onehot
        edges = torch.from_numpy(edges).float().to(src.device)
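
        # Run n_steps rounds of message passing. For each edge type the node
        # states are projected by the per-type in/out layers, stacked to shape
        # (batch, nodes * n_edge_types, state_dim), and combined with the
        # adjacency matrix inside the propagator.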
        for i_step in range(self.n_steps):
            in_states = []
            out_states = []
            for i in range(self.n_edge_types):
                in_states.append(self.in_fcs[i](prop_state))
                out_states.append(self.out_fcs[i](prop_state))
            in_states = torch.stack(in_states).transpose(0, 1).contiguous()
            in_states = in_states.view(-1, nodes * self.n_edge_types, self.state_dim)
            out_states = torch.stack(out_states).transpose(0, 1).contiguous()
            out_states = out_states.view(-1, nodes * self.n_edge_types, self.state_dim)

            prop_state = self.propogator(
                in_states, out_states, prop_state, edges, nodes
            )

        if self.bridge_extra_node:
            # Use the first extra node (the one right after the token listing)
            # as the only source for the decoder init.
            join_state = prop_state[torch.arange(batch_size), first_extra]
        else:
            # Average over all nodes to get the bridge input.
            join_state = prop_state.mean(1)
        # Stack copies of the joined state so the bridged result can
        # initialize a multi-layer decoder.
        join_state = torch.stack((join_state, join_state, join_state, join_state))
        join_state = (join_state, join_state)

        enc_final_hs = self._bridge(join_state)

        return prop_state, enc_final_hs, src_len

    def _initialize_bridge(self, rnn_type, hidden_size, num_layers):
        # LSTM has hidden and cell state; other RNN types have a single state.
        number_of_states = 2 if rnn_type == "LSTM" else 1

        self.total_hidden_dim = hidden_size * num_layers

        # Build a linear bridge layer for each state.
        self.bridge = nn.ModuleList(
            [
                nn.Linear(self.total_hidden_dim, self.total_hidden_dim, bias=True)
                for _ in range(number_of_states)
            ]
        )

    def _bridge(self, hidden):
        """Forward hidden state through bridge."""

        def bottle_hidden(linear, states):
            """
            Transform from 3D to 2D, apply the linear layer, and restore the
            initial size.
            """
            size = states.size()
            result = linear(states.view(-1, self.total_hidden_dim))
            return F.leaky_relu(result).view(size)

        if isinstance(hidden, tuple):
            outs = tuple(
                [
                    bottle_hidden(layer, hidden[ix])
                    for ix, layer in enumerate(self.bridge)
                ]
            )
        else:
            outs = bottle_hidden(self.bridge[0], hidden)
        return outs