File size: 800 Bytes
bc1ada8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch.nn as nn
from torch import Tensor
from modules.wrapper import Linear
class PositionwiseFeedForwardNetwork(nn.Module):
    """
    Position-wise Feed-Forward Network (section 3.3).

    The same small MLP is applied to the input, in this order:
    Linear(d_model -> d_ff) -> Dropout -> ReLU -> Linear(d_ff -> d_model) -> Dropout.

    Args:
        - d_model (int): feature size of both the input and the output
        - d_ff (int): feature size of the inner (hidden) layer
        - dropout_p (float): drop probability used by both Dropout layers
    """

    def __init__(self, d_model: int, d_ff: int, dropout_p: float) -> None:
        super().__init__()
        # Sub-layers are created in exactly this order so that parameter
        # initialisation draws from the RNG in the same sequence, and the
        # Sequential indices 0-4 (which prefix the state_dict keys) stay fixed.
        expand = Linear(d_model, d_ff)
        hidden_dropout = nn.Dropout(dropout_p)
        activation = nn.ReLU()
        project = Linear(d_ff, d_model)
        output_dropout = nn.Dropout(dropout_p)
        # NOTE: dropout runs before ReLU here. Dropout only zeroes elements and
        # rescales the survivors by a positive factor, and both ops are
        # elementwise, so this yields the same values as ReLU -> Dropout.
        self.feed_forward = nn.Sequential(
            expand,
            hidden_dropout,
            activation,
            project,
            output_dropout,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Pass ``x`` through the feed-forward stack and return the result.

        Presumably ``x`` has ``d_model`` features in its last dimension
        (assuming the project ``Linear`` wrapper follows ``nn.Linear``
        semantics — TODO confirm).
        """
        output = self.feed_forward(x)
        return output
|