|
import torch.nn as nn |
|
from torch import Tensor |
|
|
|
from modules.wrapper import Linear |
|
|
|
class PositionwiseFeedForwardNetwork(nn.Module):
    """
    Position-wise Feed-Forward Network (section 3.3)

    Applies, in this order: ``Linear(d_model, d_ff)`` -> dropout -> ReLU ->
    ``Linear(d_ff, d_model)`` -> dropout. ``Linear`` is the project wrapper
    imported from ``modules.wrapper``.

    Args:
        - d_model (int): dimension of input and output
        - d_ff (int): dimension of inner-layer
        - dropout_p (float): dropout probability, shared by both dropout layers
    """

    def __init__(self, d_model: int, d_ff: int, dropout_p: float) -> None:
        super().__init__()

        # The position of each stage inside the Sequential determines its
        # index-based name (and therefore its state-dict keys), so the list
        # order below must be kept as is.
        stages = [
            Linear(d_model, d_ff),
            nn.Dropout(dropout_p),
            nn.ReLU(),
            Linear(d_ff, d_model),
            nn.Dropout(dropout_p),
        ]
        self.feed_forward = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Pass ``x`` through the stacked layers and return the result."""
        # NOTE(review): presumably the last dimension of x is d_model -- confirm with callers.
        transformed = self.feed_forward(x)
        return transformed
|
|