jonathanjordan21's picture
67809715652a92b22870c50ad30f6ff38e292006aedc75ddbdc828aa856ef68f
c021d8e verified
raw
history blame
1.37 kB
"""Base class for encoders and generic multi encoders."""
import torch.nn as nn
from .misc import aeq
class EncoderBase(nn.Module):
"""
Base encoder class. Specifies the interface used by different encoder types
and required by :class:`onmt.Models.NMTModel`.
.. mermaid::
graph BT
A[Input]
subgraph RNN
C[Pos 1]
D[Pos 2]
E[Pos N]
end
F[Memory_Bank]
G[Final]
A-->C
A-->D
A-->E
C-->F
D-->F
E-->F
E-->G
"""
@classmethod
def from_opt(cls, opt, embeddings=None):
raise NotImplementedError
def _check_args(self, src, lengths=None, hidden=None):
n_batch = src.size(1)
if lengths is not None:
n_batch_, = lengths.size()
aeq(n_batch, n_batch_)
def forward(self, src, lengths=None):
"""
Args:
src (LongTensor):
padded sequences of sparse indices ``(src_len, batch, nfeat)``
lengths (LongTensor): length of each sequence ``(batch,)``
Returns:
(FloatTensor, FloatTensor):
* final encoder state, used to initialize decoder
* memory bank for attention, ``(src_len, batch, hidden)``
"""
raise NotImplementedError