"""Implementation of the CNN Decoder part of

"Convolutional Sequence to Sequence Learning"
"""
import torch
import torch.nn as nn

from onmt.modules import ConvMultiStepAttention, GlobalAttention
from onmt.utils.cnn_factory import shape_transform, GatedConv
from onmt.decoders.decoder import DecoderBase

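# Sums of two branches below are scaled by sqrt(0.5); per the ConvS2S paper
# this keeps the variance of a sum of two equally scaled terms roughly
# constant across layers.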
SCALE_WEIGHT = 0.5**0.5


class CNNDecoder(DecoderBase):
    """Decoder based on "Convolutional Sequence to Sequence Learning"
    :cite:`DBLP:journals/corr/GehringAGYD17`.

    Consists of residual convolutional layers, with ConvMultiStepAttention.
    """

    def __init__(
        self,
        num_layers,
        hidden_size,
        attn_type,
        copy_attn,
        cnn_kernel_width,
        dropout,
        embeddings,
        copy_attn_type,
    ):
        super(CNNDecoder, self).__init__()

        self.cnn_kernel_width = cnn_kernel_width
        self.embeddings = embeddings

        # Decoder state: encoder context plus previously fed target tokens.
        self.state = {}

        input_size = self.embeddings.embedding_size
        self.linear = nn.Linear(input_size, hidden_size)
        self.conv_layers = nn.ModuleList(
            [
                GatedConv(hidden_size, cnn_kernel_width, dropout, True)
                for i in range(num_layers)
            ]
        )
        self.attn_layers = nn.ModuleList(
            [ConvMultiStepAttention(hidden_size) for i in range(num_layers)]
        )

        assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
        if copy_attn:
            self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type)
        else:
            self.copy_attn = None

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.dec_layers,
            opt.dec_hid_size,
            opt.global_attention,
            opt.copy_attn,
            opt.cnn_kernel_width,
            opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout,
            embeddings,
            opt.copy_attn_type,
        )

    def init_state(self, _, enc_out, enc_hidden):
        """Init decoder state."""
        # Scaled sum of the two encoder representations, used as attention values.
        self.state["src"] = (enc_out + enc_hidden) * SCALE_WEIGHT
        self.state["previous_input"] = None

    def map_state(self, fn):
        """Apply `fn` (e.g. beam reordering) along dim 0 of each cached state tensor."""
        self.state["src"] = fn(self.state["src"], 0)
        if self.state["previous_input"] is not None:
            self.state["previous_input"] = fn(self.state["previous_input"], 0)

    def detach_state(self):
        """Detach the cached state from the computation graph."""
        self.state["previous_input"] = self.state["previous_input"].detach()

    def forward(self, tgt, enc_out, step=None, **kwargs):
        """See :obj:`onmt.modules.RNNDecoderBase.forward()`"""

        if self.state["previous_input"] is not None:
            tgt = torch.cat([self.state["previous_input"], tgt], 1)

        dec_outs = []
        attns = {"std": []}
        if self.copy_attn is not None:
            attns["copy"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # (batch, tgt_len, emb_dim)

        tgt_emb = emb
        # Output of the CNN encoder.
        enc_out_t = enc_out
        # Combination of the encoder output and the source embeddings.
        enc_out_c = self.state["src"]

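        # Project the target embeddings to hidden_size and reshape for the 2D
        # gated convolutions; assuming shape_transform (onmt.utils.cnn_factory)
        # transposes and unsqueezes, x ends up as (batch, hidden_size, tgt_len, 1).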
        emb_reshape = tgt_emb.view(tgt_emb.size(0) * tgt_emb.size(1), -1)
        linear_out = self.linear(emb_reshape)
        x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
        x = shape_transform(x)

        pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)
        pad = pad.type_as(x)
        base_target_emb = x

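        # Prepending (cnn_kernel_width - 1) zeros along the time axis before each
        # convolution keeps the stack causal (assuming GatedConv adds no extra
        # time padding): a target position only sees itself and earlier positions.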
        for conv, attention in zip(self.conv_layers, self.attn_layers):
            new_target_input = torch.cat([pad, x], 2)
            out = conv(new_target_input)
            c, attn = attention(base_target_emb, out, enc_out_t, enc_out_c)
            x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT

        dec_outs = x.squeeze(3).transpose(1, 2)

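        # When decoding step by step, previously generated positions were re-fed
        # above; slice them off so only the outputs for the new steps are returned.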
        if self.state["previous_input"] is not None:
            dec_outs = dec_outs[:, self.state["previous_input"].size(1) :, :]
            attn = attn[:, self.state["previous_input"].size(1) :].squeeze()
            attn = torch.stack([attn])
        attns["std"] = attn
        if self.copy_attn is not None:
            attns["copy"] = attn

        self.state["previous_input"] = tgt

        return dec_outs, attns

    def update_dropout(self, dropout, attention_dropout=None):
        """Update the dropout probability of every gated conv layer."""
        for layer in self.conv_layers:
            layer.dropout.p = dropout