"""Define RNN-based encoders."""
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

from onmt.encoders.encoder import EncoderBase
from onmt.utils.rnn_factory import rnn_factory


class RNNEncoder(EncoderBase):
    """A generic recurrent neural network encoder.

    Args:
        rnn_type (str):
            style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
        bidirectional (bool): use a bidirectional RNN
        num_layers (int): number of stacked layers
        hidden_size (int): hidden size of each layer
        dropout (float): dropout value for :class:`torch.nn.Dropout`
        embeddings (onmt.modules.Embeddings): embedding module to use
        use_bridge (bool): pass the final hidden state through an extra
            linear + ReLU layer per state before returning it (see ``_bridge``)
    """
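    # Illustrative usage (a minimal sketch added for clarity, not part of the
    # upstream module). Any callable module that exposes an ``embedding_size``
    # attribute and maps ``src`` of shape ``(batch, seq_len)`` to dense vectors
    # can stand in for onmt.modules.Embeddings; ``_ToyEmbeddings`` below is a
    # hypothetical stand-in used only for this example.
    #
    #   import torch
    #
    #   class _ToyEmbeddings(nn.Module):
    #       def __init__(self, vocab_size, dim):
    #           super().__init__()
    #           self.embedding_size = dim
    #           self.lut = nn.Embedding(vocab_size, dim)
    #
    #       def forward(self, src):            # src: (batch, seq_len)
    #           return self.lut(src)           # -> (batch, seq_len, dim)
    #
    #   enc = RNNEncoder("LSTM", bidirectional=True, num_layers=2,
    #                    hidden_size=512, dropout=0.1,
    #                    embeddings=_ToyEmbeddings(10000, 256))
    #   src = torch.randint(0, 10000, (8, 15))             # (batch, seq_len)
    #   src_len = torch.full((8,), 15, dtype=torch.long)
    #   enc_out, enc_final_hs, _ = enc(src, src_len)
    #   # enc_out: (8, 15, 512); enc_final_hs: tuple of two (4, 8, 256) tensors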

    def __init__(
        self,
        rnn_type,
        bidirectional,
        num_layers,
        hidden_size,
        dropout=0.0,
        embeddings=None,
        use_bridge=False,
    ):
        super(RNNEncoder, self).__init__()
        assert embeddings is not None

        # A bidirectional RNN splits the requested hidden size across the two
        # directions so the concatenated output keeps the requested size.
        num_directions = 2 if bidirectional else 1
        assert hidden_size % num_directions == 0
        hidden_size = hidden_size // num_directions
        self.embeddings = embeddings

        # ``no_pack_padded_seq`` is True for recurrent units that cannot
        # consume a PackedSequence (e.g. SRU); forward() then skips packing.
        self.rnn, self.no_pack_padded_seq = rnn_factory(
            rnn_type,
            input_size=embeddings.embedding_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # Optionally initialize the bridge applied to the final hidden state.
        self.use_bridge = use_bridge
        if self.use_bridge:
            self._initialize_bridge(rnn_type, hidden_size, num_layers)

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.rnn_type,
            opt.brnn,
            opt.enc_layers,
            opt.enc_hid_size,
            opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout,
            embeddings,
            opt.bridge,
        )
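    # A minimal sketch of the option fields ``from_opt`` reads (an
    # argparse.Namespace is used here purely for illustration; real values come
    # from OpenNMT's option parsing):
    #
    #   from argparse import Namespace
    #   opt = Namespace(rnn_type="LSTM", brnn=True, enc_layers=2,
    #                   enc_hid_size=512, dropout=[0.1], bridge=False)
    #   enc = RNNEncoder.from_opt(opt, embeddings)   # embeddings built elsewhere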

    def forward(self, src, src_len=None):
        """See :func:`EncoderBase.forward()`"""
        emb = self.embeddings(src)

        packed_emb = emb
        if src_len is not None and not self.no_pack_padded_seq:
            # Flatten the lengths tensor to a Python list and pack so the RNN
            # does not run over padded positions.
            src_len_list = src_len.view(-1).tolist()
            packed_emb = pack(emb, src_len_list, batch_first=True, enforce_sorted=False)

        enc_out, enc_final_hs = self.rnn(packed_emb)

        if src_len is not None and not self.no_pack_padded_seq:
            # Restore a padded (batch, seq_len, hidden) tensor.
            enc_out = unpack(enc_out, batch_first=True)[0]

        if self.use_bridge:
            enc_final_hs = self._bridge(enc_final_hs)

        return enc_out, enc_final_hs, src_len
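    # For reference, the pack/unpack round-trip used above behaves roughly like
    # this self-contained sketch (plain PyTorch, no OpenNMT pieces involved):
    #
    #   import torch
    #   from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
    #
    #   x = torch.randn(2, 5, 8)              # (batch, seq_len, feat), padded
    #   lengths = [5, 3]                      # true length of each batch entry
    #   packed = pack_padded_sequence(x, lengths, batch_first=True,
    #                                 enforce_sorted=False)
    #   rnn = torch.nn.LSTM(8, 16, batch_first=True)
    #   out, (h_n, c_n) = rnn(packed)         # padded steps are skipped
    #   out, out_lens = pad_packed_sequence(out, batch_first=True)
    #   # out: (2, 5, 16); entries past each true length are zero-filled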

    def _initialize_bridge(self, rnn_type, hidden_size, num_layers):
        # LSTMs carry two final states (h_n, c_n); other RNN types carry one.
        number_of_states = 2 if rnn_type == "LSTM" else 1

        # Total number of hidden features across layers (per direction).
        self.total_hidden_dim = hidden_size * num_layers

        # One linear layer per final state.
        self.bridge = nn.ModuleList(
            [
                nn.Linear(self.total_hidden_dim, self.total_hidden_dim, bias=True)
                for _ in range(number_of_states)
            ]
        )

    def _bridge(self, hidden):
        """Forward the final hidden state through the bridge.

        ``hidden`` is the RNN's final hidden state of shape
        ``(num_layers * num_directions, batch, hidden_size)``, or a tuple of
        two such tensors for LSTMs.
        """

        def bottle_hidden(linear, states):
            """Flatten the 3D state to 2D, apply ``linear`` and a ReLU, and
            restore the original shape."""
            states = states.permute(1, 0, 2).contiguous()
            size = states.size()
            result = linear(states.view(-1, self.total_hidden_dim))
            result = F.relu(result).view(size)
            return result.permute(1, 0, 2).contiguous()

        if isinstance(hidden, tuple):
            # LSTM: one bridge layer for h_n, one for c_n.
            outs = tuple(
                bottle_hidden(layer, hidden[ix])
                for ix, layer in enumerate(self.bridge)
            )
        else:
            outs = bottle_hidden(self.bridge[0], hidden)
        return outs
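    # Shape walk-through for the bridge (illustrative only, reusing the example
    # configuration above: num_layers=2, bidirectional LSTM, per-direction
    # hidden size 256, hence total_hidden_dim = 512, assuming use_bridge=True):
    #
    #   h_n: (4, batch, 256) --permute--> (batch, 4, 256)
    #        --view--> (batch * 2, 512) --Linear + ReLU--> (batch * 2, 512)
    #        --view/permute--> (4, batch, 256)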

    def update_dropout(self, dropout, attention_dropout=None):
        # Only the RNN's inter-layer dropout is updated; ``attention_dropout``
        # is unused here.
        self.rnn.dropout = dropout