import torch.nn as nn
from torch import Tensor

from layers.encoder_layer import EncoderLayer
|
|
class Encoder(nn.Module):
    """
    A transformer encoder stack (without token embeddings or positional
    encodings).

    Args:
        d_model: dimensionality of the model's hidden states
        num_heads: number of attention heads in each layer
        d_ff: hidden dimensionality of the position-wise feed-forward sublayer
        dropout_p: dropout probability
        num_layers: number of stacked encoder layers
    """
|
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,
        num_layers: int,
    ) -> None:
        super().__init__()
        # Stack of identical encoder layers, applied in sequence.
        self.layers = nn.ModuleList(
            [
                EncoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout_p=dropout_p,
                )
                for _ in range(num_layers)
            ]
        )
|
    def forward(self, x: Tensor, src_mask: Tensor) -> Tensor:
        # Run the input through each layer in turn; the per-layer
        # attention weights returned by EncoderLayer are discarded.
        for layer in self.layers:
            x, _attn = layer(x, src_mask)
        return x
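

if __name__ == "__main__":
    # Minimal usage sketch, assuming the EncoderLayer imported above takes
    # (x, src_mask) and returns (output, attention_weights), as the forward
    # loop expects. The hyperparameters and mask shape below are illustrative
    # assumptions: a mask of ones that broadcasts over heads and query
    # positions, i.e. no padding anywhere.
    import torch

    encoder = Encoder(d_model=512, num_heads=8, d_ff=2048, dropout_p=0.1, num_layers=6)
    x = torch.randn(2, 10, 512)  # (batch_size, seq_len, d_model)
    src_mask = torch.ones(2, 1, 1, 10)  # (batch_size, 1, 1, seq_len)
    out = encoder(x, src_mask)
    print(out.shape)  # expected: torch.Size([2, 10, 512])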