from typing import Tuple

import torch.nn as nn
from torch import Tensor

from modules.multi_head_attention import MultiHeadAttention
from modules.positionwise_feed_forward import PositionwiseFeedForwardNetwork


class DecoderLayer(nn.Module):
    """
    A Decoder layer.

    Args:
        d_model (int): dimension of the model (hidden size of each sub-layer)
        num_heads (int): number of attention heads
        d_ff (int): dimension of the position-wise feed forward network
        dropout_p (float): probability of dropout
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout_p: float,
    ) -> None:
        super(DecoderLayer, self).__init__()
        self.self_attn_prenorm = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=dropout_p)
        self.self_attn_dropout = nn.Dropout(p=dropout_p)

        self.cross_attn_prenorm = nn.LayerNorm(d_model)
        self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=dropout_p)
        self.cross_attn_dropout = nn.Dropout(p=dropout_p)

        self.feed_forward_prenorm = nn.LayerNorm(d_model)
        self.feed_forward = PositionwiseFeedForwardNetwork(d_model=d_model, d_ff=d_ff, dropout_p=dropout_p)

    def forward(
        self,
        decoder_inputs: Tensor,
        encoder_outputs: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        # Masked self-attention sub-layer (pre-norm, then residual connection)
        residual = decoder_inputs
        outputs = self.self_attn_prenorm(decoder_inputs)
        outputs, attn = self.self_attn(outputs, outputs, outputs, tgt_mask)
        outputs = self.self_attn_dropout(outputs) + residual

        # Encoder-decoder (cross) attention sub-layer (pre-norm, then residual connection)
        residual = outputs
        outputs = self.cross_attn_prenorm(outputs)
        outputs, attn = self.cross_attn(outputs, encoder_outputs, encoder_outputs, src_mask)
        outputs = self.cross_attn_dropout(outputs) + residual

        # Position-wise feed forward sub-layer (pre-norm, then residual connection)
        residual = outputs
        outputs = self.feed_forward_prenorm(outputs)
        outputs = self.feed_forward(outputs)
        outputs = outputs + residual

        return outputs, attn
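

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal smoke test, assuming MultiHeadAttention follows the
# (query, key, value, mask) call signature used in forward() above and that
# boolean masks broadcast over the attention scores. The mask construction,
# tensor shapes, and hyperparameter values below are assumptions for
# illustration, not requirements taken from this repository.
if __name__ == "__main__":
    import torch

    batch_size, src_len, tgt_len, d_model = 2, 7, 5, 512
    layer = DecoderLayer(d_model=d_model, num_heads=8, d_ff=2048, dropout_p=0.1)

    decoder_inputs = torch.randn(batch_size, tgt_len, d_model)
    encoder_outputs = torch.randn(batch_size, src_len, d_model)

    # Assumed mask shapes: causal mask for self-attention over target positions,
    # all-ones (no padding) mask for cross-attention over source positions.
    tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).bool().unsqueeze(0)
    src_mask = torch.ones(tgt_len, src_len).bool().unsqueeze(0)

    outputs, attn = layer(decoder_inputs, encoder_outputs, src_mask, tgt_mask)
    print(outputs.shape)  # expected: torch.Size([2, 5, 512])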