|
"""Define a minimal encoder.""" |
|
from onmt.encoders.encoder import EncoderBase |
|
from onmt.utils.misc import sequence_mask |
|
import torch |
|
|
|
|
|
class MeanEncoder(EncoderBase): |
|
"""A trivial non-recurrent encoder. Simply applies mean pooling. |
|
|
|
Args: |
|
num_layers (int): number of replicated layers |
|
embeddings (onmt.modules.Embeddings): embedding module to use |
|
""" |
|
|
|
def __init__(self, num_layers, embeddings): |
|
super(MeanEncoder, self).__init__() |
|
self.num_layers = num_layers |
|
self.embeddings = embeddings |
|
|
|
@classmethod |
|
def from_opt(cls, opt, embeddings): |
|
"""Alternate constructor.""" |
|
return cls(opt.enc_layers, embeddings) |
|
|
|
def forward(self, src, src_len=None): |
|
"""See :func:`EncoderBase.forward()`""" |
|
|
|
emb = self.embeddings(src) |
|
batch, _, emb_dim = emb.size() |
|
|
|
if src_len is not None: |
|
|
|
mask = sequence_mask(src_len).float() |
|
mask = mask / src_len.unsqueeze(1).float() |
|
mean = torch.bmm(mask.unsqueeze(1), emb).squeeze(1) |
|
else: |
|
mean = emb.mean(1) |
|
|
|
mean = mean.expand(self.num_layers, batch, emb_dim) |
|
enc_out = emb |
|
enc_final_hs = (mean, mean) |
|
return enc_out, enc_final_hs, src_len |
|
|