ReactSeq / onmt /encoders /cnn_encoder.py
Oopstom's picture
Upload 313 files
c668e80 verified
"""
Implementation of "Convolutional Sequence to Sequence Learning"
"""
import torch.nn as nn
from onmt.encoders.encoder import EncoderBase
from onmt.utils.cnn_factory import shape_transform, StackedCNN
# Square root of 0.5 (about 0.7071). Not referenced by any code visible in
# this module. NOTE(review): presumably the residual scaling factor used by
# the convolutional seq2seq layers elsewhere in the package -- confirm.
SCALE_WEIGHT = 0.5**0.5
class CNNEncoder(EncoderBase):
    """Encoder based on "Convolutional Sequence to Sequence Learning"
    :cite:`DBLP:journals/corr/GehringAGYD17`.

    The input is embedded, projected from the embedding size to
    ``hidden_size`` by a linear layer, and then run through a
    ``StackedCNN``.

    Args:
        num_layers: forwarded to ``StackedCNN`` as its first argument.
        hidden_size: output width of the linear projection; also
            forwarded to ``StackedCNN``.
        cnn_kernel_width: forwarded to ``StackedCNN``.
        dropout: forwarded to ``StackedCNN``.
        embeddings: callable applied to the encoder input; must expose
            an ``embedding_size`` attribute, which sets the input width
            of the linear projection.
    """

    def __init__(self, num_layers, hidden_size, cnn_kernel_width, dropout, embeddings):
        super(CNNEncoder, self).__init__()

        self.embeddings = embeddings
        input_size = embeddings.embedding_size
        self.linear = nn.Linear(input_size, hidden_size)
        self.cnn = StackedCNN(num_layers, hidden_size, cnn_kernel_width, dropout)

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor.

        Builds the encoder from ``opt.enc_layers``, ``opt.enc_hid_size``,
        ``opt.cnn_kernel_width`` and ``opt.dropout``. When ``opt.dropout``
        is a list, only its first entry is used.
        """
        # isinstance (rather than an exact type comparison) so that list
        # subclasses are also unpacked instead of being passed through whole.
        dropout = opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout
        return cls(
            opt.enc_layers,
            opt.enc_hid_size,
            opt.cnn_kernel_width,
            dropout,
            embeddings,
        )

    def forward(self, input, src_len=None, hidden=None):
        """See :func:`EncoderBase.forward()`

        Args:
            input: passed directly to ``self.embeddings``.
            src_len: returned unchanged as the third element of the result.
            hidden: not used by this method.

        Returns:
            A 3-tuple ``(out, emb_remap, src_len)``: the ``StackedCNN``
            output and the projected, shape-transformed embeddings (both
            with dimension 3 squeezed), followed by ``src_len``.
        """
        # batch x len x dim
        emb = self.embeddings(input)
        batch_size, seq_len = emb.size(0), emb.size(1)

        # Merge the two leading dimensions for the linear projection, then
        # split them again so the projected tensor keeps the leading layout.
        emb_remap = self.linear(emb.view(batch_size * seq_len, -1))
        emb_remap = emb_remap.view(batch_size, seq_len, -1)

        # NOTE(review): shape_transform is defined in onmt.utils.cnn_factory;
        # it presumably yields a 4-D tensor with a size-1 dimension at index 3
        # (hence the squeeze(3) calls below) -- confirm against that helper.
        emb_remap = shape_transform(emb_remap)
        out = self.cnn(emb_remap)

        return out.squeeze(3), emb_remap.squeeze(3), src_len

    def update_dropout(self, dropout, attention_dropout=None):
        """Set the dropout probability on the CNN stack.

        ``attention_dropout`` is accepted but not used by this method.
        """
        # NOTE(review): assumes self.cnn.dropout is an object exposing a
        # writable ``p`` attribute (e.g. nn.Dropout) -- verify in StackedCNN.
        self.cnn.dropout.p = dropout