homemade_lo_vi / model /encoder.py
moiduy04's picture
Upload 18 files
bc1ada8
raw
history blame
902 Bytes
from typing import Tuple
import torch.nn as nn
from torch import Tensor
from layers.encoder_layer import EncoderLayer
class Encoder(nn.Module):
    """
    A transformer encoder: a sequential stack of ``num_layers`` EncoderLayers.

    No token embeddings or positional embeddings are applied here; the input
    passed to ``forward`` goes straight into the first layer.

    Args:
        d_model: Model width; forwarded to every EncoderLayer.
        num_heads: Number of attention heads; forwarded to every EncoderLayer.
        d_ff: Feed-forward inner width; forwarded to every EncoderLayer.
        dropout_p: Dropout probability (a float such as ``0.1``); forwarded
            to every EncoderLayer.
        num_layers: How many EncoderLayer modules to stack.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,
        num_layers: int,
    ) -> None:
        super().__init__()
        # One independent EncoderLayer per depth level. Holding them in a
        # ModuleList registers their parameters with this module.
        self.layers = nn.ModuleList(
            [
                EncoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout_p=dropout_p,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x: Tensor, src_mask: Tensor) -> Tensor:
        """
        Pass ``x`` through every layer in order.

        Args:
            x: Input to the first layer. Its expected shape is whatever
                EncoderLayer requires (presumably (batch, seq_len, d_model) —
                TODO confirm).
            src_mask: Mask handed unchanged to every layer; its shape and
                meaning are defined by EncoderLayer.

        Returns:
            The output of the last layer, or ``x`` unchanged if the stack is
            empty.
        """
        for layer in self.layers:
            # Each layer returns a pair. The second element (presumably the
            # attention weights) is not needed by the encoder and is discarded.
            x, _ = layer(x, src_mask)
        return x