from typing import Tuple

import torch.nn as nn
from torch import Tensor

from layers.decoder_layer import DecoderLayer


class Decoder(nn.Module):
    """
    A stack of transformer decoder layers. Token and positional embeddings
    are handled outside this module.

    Args:
        d_model (int): dimensionality of the model's hidden states
        num_heads (int): number of attention heads per layer
        d_ff (int): dimensionality of the feed-forward inner layer
        dropout_p (float): dropout probability
        num_layers (int): number of stacked decoder layers

    Inputs:
        x (batch, seq_len, d_model): target-side input representations
        encoder_output (batch, src_len, d_model): encoder output
        src_mask: mask over encoder positions for cross-attention
        tgt_mask: causal/padding mask over decoder positions

    Outputs:
        - (batch, seq_len, d_model): decoder output
        - (batch, seq_len, seq_len): decoder attention (from the last layer)
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,
        num_layers: int,
    ) -> None:
        super(Decoder, self).__init__()
        # Stack of identical decoder layers applied in sequence.
        self.layers = nn.ModuleList(
            [
                DecoderLayer(
                    d_model=d_model,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout_p=dropout_p,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self, x: Tensor, encoder_output: Tensor, src_mask: Tensor, tgt_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Feed the input through each decoder layer in turn; only the
        # attention weights of the final layer are returned.
        for layer in self.layers:
            x, attn = layer(x, encoder_output, src_mask, tgt_mask)
        return x, attn
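
# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal smoke test, assuming DecoderLayer follows the interface used in
# forward() above: layer(x, encoder_output, src_mask, tgt_mask) -> (x, attn).
# The tensor shapes and mask conventions below are assumptions for the
# example; adjust them to match how DecoderLayer applies its masks.
if __name__ == "__main__":
    import torch

    batch, src_len, tgt_len, d_model = 2, 7, 5, 64
    decoder = Decoder(
        d_model=d_model, num_heads=8, d_ff=256, dropout_p=0.1, num_layers=4
    )

    x = torch.randn(batch, tgt_len, d_model)               # target-side inputs
    encoder_output = torch.randn(batch, src_len, d_model)  # encoder states
    # Assumed mask shapes: all source positions visible, causal target mask.
    src_mask = torch.ones(batch, 1, src_len, dtype=torch.bool)
    tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool))

    out, attn = decoder(x, encoder_output, src_mask, tgt_mask)
    print(out.shape)   # expected: torch.Size([2, 5, 64])
    print(attn.shape)  # attention weights from the last decoder layer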