# ReactSeq / onmt/decoders/decoder.py
# Uploaded by Oopstom ("Upload 313 files"), commit c668e80 (verified).
# Page links: raw | history | blame. File size: 14.2 kB.
import torch
import torch.nn as nn
from onmt.modules.stacked_rnn import StackedLSTM, StackedGRU
from onmt.modules import context_gate_factory, GlobalAttention
from onmt.utils.rnn_factory import rnn_factory
class DecoderBase(nn.Module):
    """Abstract base class shared by every decoder in this module.

    Args:
        attentional (bool): whether the decoder returns non-empty attention.
    """

    def __init__(self, attentional=True):
        super().__init__()
        # Read by subclasses to decide whether an attention module is built.
        self.attentional = attentional

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor that builds a decoder from parsed options.

        Concrete subclasses are expected to override this.

        Raises:
            NotImplementedError: always, for the base class.
        """
        raise NotImplementedError
class RNNDecoderBase(DecoderBase):
    """Base recurrent attention-based decoder class.

    Specifies the interface used by different decoder types
    and required by :class:`~onmt.models.NMTModel`.

    Args:
        rnn_type (str):
            style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
        bidirectional_encoder (bool) : use with a bidirectional encoder
        num_layers (int) : number of stacked layers
        hidden_size (int) : hidden size of each layer
        attn_type (str) : see :class:`~onmt.modules.GlobalAttention`
        attn_func (str) : see :class:`~onmt.modules.GlobalAttention`
        coverage_attn (str): see :class:`~onmt.modules.GlobalAttention`
        context_gate (str): see :class:`~onmt.modules.ContextGate`
        copy_attn (bool): setup a separate copy attention mechanism
        dropout (float) : dropout value for :class:`torch.nn.Dropout`
        embeddings (onmt.modules.Embeddings): embedding module to use
        reuse_copy_attn (bool): reuse the attention for copying
        copy_attn_type (str): The copy attention style. See
            :class:`~onmt.modules.GlobalAttention`.

    Subclasses must provide ``_build_rnn``, ``_input_size`` and
    ``_run_forward_pass``.
    """

    def __init__(
        self,
        rnn_type,
        bidirectional_encoder,
        num_layers,
        hidden_size,
        attn_type="general",
        attn_func="softmax",
        coverage_attn=False,
        context_gate=None,
        copy_attn=False,
        dropout=0.0,
        embeddings=None,
        reuse_copy_attn=False,
        copy_attn_type="general",
    ):
        super(RNNDecoderBase, self).__init__(
            attentional=attn_type != "none" and attn_type is not None
        )

        self.bidirectional_encoder = bidirectional_encoder
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embeddings = embeddings
        self.dropout = nn.Dropout(dropout)

        # Decoder state (hidden, input_feed, coverage); filled by init_state.
        self.state = {}

        # Build the RNN. `_input_size` is a subclass property (it depends on
        # whether input feeding is used).
        self.rnn = self._build_rnn(
            rnn_type,
            input_size=self._input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )

        # Set up the context gate.
        self.context_gate = None
        if context_gate is not None:
            self.context_gate = context_gate_factory(
                context_gate, self._input_size, hidden_size, hidden_size, hidden_size
            )

        # Set up the standard attention.
        self._coverage = coverage_attn
        if not self.attentional:
            if self._coverage:
                raise ValueError("Cannot use coverage term with no attention.")
            self.attn = None
        else:
            self.attn = GlobalAttention(
                hidden_size,
                coverage=coverage_attn,
                attn_type=attn_type,
                attn_func=attn_func,
            )

        # A separate copy attention is built only when it is not reused
        # from the standard attention.
        if copy_attn and not reuse_copy_attn:
            if copy_attn_type == "none" or copy_attn_type is None:
                raise ValueError("Cannot use copy_attn with copy_attn_type none")
            self.copy_attn = GlobalAttention(
                hidden_size, attn_type=copy_attn_type, attn_func=attn_func
            )
        else:
            self.copy_attn = None

        self._reuse_copy_attn = reuse_copy_attn and copy_attn
        if self._reuse_copy_attn and not self.attentional:
            raise ValueError("Cannot reuse copy attention with no attention.")

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor building the decoder from parsed options."""
        # `opt.dropout` may be a per-step schedule (list); use its first value.
        dropout = opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout
        return cls(
            opt.rnn_type,
            opt.brnn,
            opt.dec_layers,
            opt.dec_hid_size,
            opt.global_attention,
            opt.global_attention_function,
            opt.coverage_attn,
            opt.context_gate,
            opt.copy_attn,
            dropout,
            embeddings,
            opt.reuse_copy_attn,
            opt.copy_attn_type,
        )

    def init_state(self, src, _, enc_final_hs):
        """Initialize decoder state with last state of the encoder.

        Args:
            src: unused here; kept for interface compatibility.
            _: unused.
            enc_final_hs: final encoder hidden state. A tuple of tensors
                (LSTM) or a single tensor (GRU); each tensor is laid out as
                ``(layers * directions, batch, dim)``.
        """

        def _fix_enc_hidden(hidden):
            # The encoder hidden is (layers*directions) x batch x dim.
            # We need to convert it to layers x batch x (directions*dim).
            if self.bidirectional_encoder:
                hidden = torch.cat([hidden[0::2], hidden[1::2]], 2)
            return hidden

        if isinstance(enc_final_hs, tuple):  # LSTM
            self.state["hidden"] = tuple(
                _fix_enc_hidden(enc_hid) for enc_hid in enc_final_hs
            )
        else:  # GRU
            self.state["hidden"] = (_fix_enc_hidden(enc_final_hs),)

        # Init the input feed with zeros: (1, batch, hidden_size), same
        # dtype/device as the hidden state.
        first_hidden = self.state["hidden"][0]
        batch_size = first_hidden.size(1)
        self.state["input_feed"] = first_hidden.new_zeros(
            batch_size, self.hidden_size
        ).unsqueeze(0)
        self.state["coverage"] = None

    def map_state(self, fn):
        """Apply ``fn(tensor, 1)`` to every state tensor (dim 1 is batch)."""
        self.state["hidden"] = tuple(fn(h, 1) for h in self.state["hidden"])
        self.state["input_feed"] = fn(self.state["input_feed"], 1)
        if self._coverage and self.state["coverage"] is not None:
            self.state["coverage"] = fn(self.state["coverage"], 1)

    def detach_state(self):
        """Detach every state tensor from the autograd graph."""
        self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
        self.state["input_feed"] = self.state["input_feed"].detach()
        if self._coverage and self.state["coverage"] is not None:
            self.state["coverage"] = self.state["coverage"].detach()

    def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
        """
        Args:
            tgt (LongTensor): sequences of padded tokens
                 ``(batch, tgt_len, nfeats)``.
            enc_out (FloatTensor): vectors from the encoder
                 ``(batch, src_len, hidden)``.
            src_len (LongTensor): the padded source lengths
                ``(batch,)``.
            step: unused by RNN decoders; accepted for interface compatibility.

        Returns:
            (FloatTensor, dict[str, FloatTensor]):

            * dec_outs: output from the decoder (after attn)
              ``(batch, tgt_len, hidden)``.
            * attns: distribution over src at each tgt
              ``(batch, tgt_len, src_len)``.
        """
        dec_state, dec_outs, attns = self._run_forward_pass(
            tgt, enc_out, src_len=src_len
        )

        # Update the state with the result.
        if not isinstance(dec_state, tuple):
            dec_state = (dec_state,)
        self.state["hidden"] = dec_state

        # Step-wise decoders return Python lists with one entry per target
        # step. Stack along dim=1 so dec_outs AND every attention are
        # batch-first, as documented. Stacking attention on dim 0 would give
        # (tgt_len, batch, src_len) and misalign rows when both are flattened
        # with view(-1, ...). Tensors, e.g. from StdRNNDecoder or SRU, are
        # left untouched.
        if isinstance(dec_outs, list):
            dec_outs = torch.stack(dec_outs, dim=1)
        for k in attns:
            if isinstance(attns[k], list):
                attns[k] = torch.stack(attns[k], dim=1)

        self.state["input_feed"] = dec_outs[:, -1, :].unsqueeze(0)
        self.state["coverage"] = None
        if "coverage" in attns:
            # Coverage accumulated through the last step: (1, batch, src_len).
            self.state["coverage"] = attns["coverage"][:, -1, :].unsqueeze(0)

        return dec_outs, attns

    def update_dropout(self, dropout, attention_dropout=None):
        """Set the output and embedding dropout to ``dropout``.

        ``attention_dropout`` is accepted for interface compatibility and
        ignored.
        """
        self.dropout.p = dropout
        self.embeddings.update_dropout(dropout)
class StdRNNDecoder(RNNDecoderBase):
    """Standard fully batched RNN decoder with attention.

    Runs the whole target sequence through the RNN in one call, which makes
    it faster (it can use the CuDNN kernels). See
    :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.

    Based around the approach from
    "Neural Machine Translation By Jointly Learning To Align and Translate"
    :cite:`Bahdanau2015`

    Implemented without input_feeding and currently with no `coverage_attn`
    or `copy_attn` support.
    """

    def _run_forward_pass(self, tgt, enc_out, src_len=None):
        """
        Private helper for running the specific RNN forward pass.
        Must be overriden by all subclasses.

        Args:
            tgt (LongTensor): a sequence of input tokens tensors
                ``(batch, tgt_len, nfeats)``.
            enc_out (FloatTensor): output(tensor sequence) from the
                encoder RNN of size ``(batch, src_len, hidden_size)``.
            src_len (LongTensor): the source enc_out lengths.

        Returns:
            (Tensor, FloatTensor, Dict[str, FloatTensor]):

            * dec_state: final hidden state from the decoder.
            * dec_outs: decoder output for every time step.
            * attns: dictionary of attention tensors; ``"std"`` is set only
              when the decoder is attentional.
        """
        assert self.copy_attn is None  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        attns = {}
        emb = self.embeddings(tgt)

        # nn.GRU takes a bare hidden tensor; other RNNs get the state tuple.
        hidden = self.state["hidden"]
        rnn_hidden = hidden[0] if isinstance(self.rnn, nn.GRU) else hidden
        rnn_out, dec_state = self.rnn(emb, rnn_hidden)

        batch_size, seq_len, _ = tgt.size()

        # Calculate the attention (or pass RNN output straight through).
        if self.attentional:
            dec_outs, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
            attns["std"] = p_attn
        else:
            dec_outs = rnn_out

        # The context gate works on flattened (batch * seq_len, dim) rows.
        if self.context_gate is not None:
            gated = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_out.view(-1, rnn_out.size(2)),
                dec_outs.view(-1, dec_outs.size(2)),
            )
            dec_outs = gated.view(batch_size, seq_len, self.hidden_size)

        return dec_state, self.dropout(dec_outs), attns

    def _build_rnn(self, rnn_type, **kwargs):
        # rnn_factory returns a pair; only the module itself is needed here.
        return rnn_factory(rnn_type, **kwargs)[0]

    @property
    def _input_size(self):
        """Without input feeding the RNN consumes only the embeddings."""
        return self.embeddings.embedding_size
class InputFeedRNNDecoder(RNNDecoderBase):
    """Input feeding based decoder.

    At every step the previous attentional output is concatenated to the
    current embedding before it enters the RNN. See
    :class:`~onmt.decoders.decoder.RNNDecoderBase` for options.

    Based around the input feeding approach from
    "Effective Approaches to Attention-based Neural Machine Translation"
    :cite:`Luong2015`
    """

    def _run_forward_pass(self, tgt, enc_out, src_len=None):
        """Decode one target step at a time with input feeding.

        Arguments match StdRNNDecoder._run_forward_pass(). Returns
        ``(dec_state, dec_outs, attns)``, where ``dec_outs`` and each entry
        of ``attns`` are Python lists with one tensor per target step.
        """
        feed = self.state["input_feed"].squeeze(0)
        step_outputs = []

        attns = {}
        if self.attn is not None:
            attns["std"] = []
        if self.copy_attn is not None or self._reuse_copy_attn:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # batch x len x embedding_dim

        hidden = self.state["hidden"]
        prev_coverage = self.state["coverage"]
        coverage = None if prev_coverage is None else prev_coverage.squeeze(0)

        for step_emb in emb.split(1, dim=1):
            # Input feed: concatenate this step's embedding with the
            # previous attentional output.
            rnn_in = torch.cat([step_emb.squeeze(1), feed], 1)
            rnn_out, hidden = self.rnn(rnn_in, hidden)

            if self.attentional:
                step_out, p_attn = self.attn(rnn_out, enc_out, src_len=src_len)
                attns["std"].append(p_attn)
            else:
                step_out = rnn_out

            if self.context_gate is not None:
                # TODO: context gate should be employed
                # instead of second RNN transform.
                step_out = self.context_gate(rnn_in, rnn_out, step_out)

            step_out = self.dropout(step_out)
            feed = step_out
            step_outputs.append(step_out)

            # Update the coverage attention.
            # attns["coverage"] is actually c^(t+1) of See et al(2017)
            # 1-index shifted
            if self._coverage:
                coverage = p_attn if coverage is None else p_attn + coverage
                attns["coverage"].append(coverage)

            if self.copy_attn is not None:
                # NOTE(review): unlike the standard attention above, copy
                # attention is called without src_len, so padded source
                # positions are not masked here. Confirm this is intended.
                _, copy_p_attn = self.copy_attn(step_out, enc_out)
                attns["copy"].append(copy_p_attn)
            elif self._reuse_copy_attn:
                # Share the standard attention list as the copy attention.
                attns["copy"] = attns["std"]

        return hidden, step_outputs, attns

    def _build_rnn(self, rnn_type, input_size, hidden_size, num_layers, dropout):
        assert rnn_type != "SRU", "SRU doesn't support input feed! Please set -input_feed 0!"
        # Step-wise stacked cells: LSTM cells for "LSTM", GRU cells otherwise.
        cell_cls = StackedGRU if rnn_type != "LSTM" else StackedLSTM
        return cell_cls(num_layers, input_size, hidden_size, dropout)

    @property
    def _input_size(self):
        """Using input feed by concatenating input with attention vectors."""
        return self.hidden_size + self.embeddings.embedding_size

    def update_dropout(self, dropout, attention_dropout=None):
        """Set output, RNN-internal and embedding dropout to ``dropout``."""
        self.dropout.p = dropout
        self.rnn.dropout.p = dropout
        self.embeddings.update_dropout(dropout)