# homemade_lo_vi/modules/positionwise_feed_forward.py
# Uploaded by moiduy04 (commit bc1ada8, "Upload 18 files"; 800 bytes).
import torch.nn as nn
from torch import Tensor
from modules.wrapper import Linear
class PositionwiseFeedForwardNetwork(nn.Module):
    """Position-wise feed-forward sub-layer (section 3.3 of "Attention Is All You Need").

    A two-layer MLP built as a single ``nn.Sequential``:
    Linear(d_model -> d_ff) -> Dropout -> ReLU -> Linear(d_ff -> d_model) -> Dropout.

    Args:
        d_model (int): size of the input and output features.
        d_ff (int): size of the hidden (inner) layer.
        dropout_p (float): drop probability shared by both dropout layers.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_p: float) -> None:
        super().__init__()
        # Projections are created in the same order as they are applied,
        # so parameter initialisation consumes the RNG in that order.
        expand = Linear(d_model, d_ff)
        contract = Linear(d_ff, d_model)
        # Keep this exact layer order: nn.Sequential registers children under
        # indices "0".."4", and those indices appear in state_dict keys.
        # Dropout placed before ReLU is numerically the same as after it,
        # because dropout only zeroes values or scales them by 1/(1-p) > 0.
        self.feed_forward = nn.Sequential(
            expand,
            nn.Dropout(dropout_p),
            nn.ReLU(),
            contract,
            nn.Dropout(dropout_p),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the feed-forward stack to ``x``.

        Args:
            x (Tensor): input features. Presumably the last dimension is
                ``d_model`` (assumes the project ``Linear`` wrapper acts on the
                last dimension like ``nn.Linear`` -- TODO confirm).

        Returns:
            Tensor: result of the stack (last dimension presumably ``d_model``).
        """
        out = self.feed_forward(x)
        return out