File size: 800 Bytes
bc1ada8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch.nn as nn
from torch import Tensor

from modules.wrapper import Linear

class PositionwiseFeedForwardNetwork(nn.Module):
    """
    Position-wise Feed-Forward Network (section 3.3)

    Applies one two-layer projection to the input, in this order:
    Linear(d_model -> d_ff) -> Dropout -> ReLU -> Linear(d_ff -> d_model) -> Dropout.

    Args:
        - d_model (int): dimension of input and output
        - d_ff (int): dimension of inner-layer
        - dropout_p (float): dropout probability (used by both dropout layers)

    Note:
        Dropout sits before the ReLU here. ReLU commutes with dropout's
        non-negative mask-and-rescale, so the output matches the more common
        ReLU -> Dropout ordering for the same dropout mask.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_p: float) -> None:
        super().__init__()

        # Keep this exact order: nn.Sequential registers children by position
        # ("0".."4"), so the layout determines state_dict keys. The two Linear
        # layers are also constructed in the same order as before, so their
        # parameter initialisation draws from the RNG in the same sequence.
        layers = [
            Linear(d_model, d_ff),   # expand to the inner dimension
            nn.Dropout(dropout_p),
            nn.ReLU(),
            Linear(d_ff, d_model),   # project back to the model dimension
            nn.Dropout(dropout_p),
        ]
        self.feed_forward = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run the feed-forward stack on ``x``.

        Args:
            - x (Tensor): input whose last dimension is expected to be d_model
              (presumably (batch, seq_len, d_model) — TODO confirm with callers)

        Returns:
            Tensor produced by ``self.feed_forward``; its last dimension is
            d_model, assuming the wrapper ``Linear`` maps in_features to
            out_features like ``nn.Linear``.
        """
        projected = self.feed_forward(x)
        return projected