File size: 902 Bytes
bc1ada8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from typing import Tuple
import torch.nn as nn
from torch import Tensor
from layers.encoder_layer import EncoderLayer
class Encoder(nn.Module):
    """A Transformer encoder stack (no token or positional embeddings).

    Holds ``num_layers`` separately constructed ``EncoderLayer`` modules and
    applies them in sequence, feeding each layer's output into the next.

    Args:
        d_model: Forwarded as ``d_model`` to every ``EncoderLayer``.
        num_heads: Forwarded as ``num_heads`` to every ``EncoderLayer``.
        d_ff: Forwarded as ``d_ff`` to every ``EncoderLayer``.
        dropout_p: Forwarded as ``dropout_p`` to every ``EncoderLayer``
            (presumably a dropout probability in [0, 1], hence ``float``).
        num_layers: Number of stacked ``EncoderLayer`` modules. If it is 0
            (or negative), the stack is empty and ``forward`` returns its
            input unchanged.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,
        num_layers: int,
    ) -> None:
        super().__init__()
        # ModuleList (not a plain list) so every layer's parameters are
        # registered as part of this module.
        self.layers = nn.ModuleList(
            [
                EncoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout_p=dropout_p,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x: Tensor, src_mask: Tensor) -> Tensor:
        """Run ``x`` through every encoder layer in order.

        Args:
            x: Input to the first layer (assumed ``(batch, seq_len, d_model)``
                -- TODO confirm against ``EncoderLayer``).
            src_mask: Mask passed unchanged to every layer.

        Returns:
            The output of the last layer, or ``x`` itself if there are no
            layers.
        """
        for layer in self.layers:
            # Each layer returns a pair. Only the first item is threaded
            # through; the second (presumably attention weights) is discarded.
            x, _attn = layer(x, src_mask)
        return x