import torch.nn as nn
from torch import Tensor

from modules.wrapper import Linear


class PositionwiseFeedForwardNetwork(nn.Module):
    """Position-wise Feed-Forward Network (section 3.3).

    Runs the input through one fixed stack of layers:
    ``Linear(d_model -> d_ff)``, dropout, ReLU, ``Linear(d_ff -> d_model)``,
    dropout. The output therefore has the same feature size as the input.

    Args:
        d_model (int): dimension of input and output
        d_ff (int): dimension of the inner layer
        dropout_p (float): dropout probability, shared by both dropout layers
    """

    def __init__(self, d_model: int, d_ff: int, dropout_p: float) -> None:
        super().__init__()
        # NOTE(review): ``Linear`` comes from the project-local
        # ``modules.wrapper``; presumably a thin wrapper around ``nn.Linear``
        # acting on the last dimension -- confirm against that module.
        #
        # Dropout sits *before* the ReLU here. nn.Dropout only multiplies each
        # element by 0 or 1 / (1 - p), both non-negative, and ReLU commutes with
        # non-negative scaling. So this is mathematically equivalent to the more
        # common ReLU -> Dropout order. The order is kept so the submodule
        # indices inside the Sequential stay stable.
        layers = [
            Linear(d_model, d_ff),
            nn.Dropout(dropout_p),
            nn.ReLU(),
            Linear(d_ff, d_model),
            nn.Dropout(dropout_p),
        ]
        self.feed_forward = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the feed-forward stack to ``x``.

        Args:
            x (Tensor): input whose trailing dimension is expected to be
                ``d_model`` (assumed from the first Linear's in-features --
                confirm with ``modules.wrapper.Linear``).

        Returns:
            Tensor: the result of the stacked layers; its trailing dimension is
            ``d_model``, matching the last Linear's out-features.
        """
        projected = self.feed_forward(x)
        return projected