import math
from collections.abc import Callable
from typing import Optional

import torch
import torch.nn as nn


class Lambda(nn.Module):
    """
    Overview:
        A custom lambda module for constructing custom layers.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, f: Callable):
        """
        Overview:
            Initialize the lambda module with a given function.
        Arguments:
            - f (:obj:`Callable`): A python function wrapped by this module.
        """
        super(Lambda, self).__init__()
        self.f = f

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Apply the wrapped function to the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - output (:obj:`torch.Tensor`): The result of ``f(x)``.
        """
        return self.f(x)


class GLU(nn.Module):
    """
    Overview:
        Gated Linear Unit (GLU), a gating mechanism first proposed in
        [Language Modeling with Gated Convolutional Networks](https://arxiv.org/pdf/1612.08083.pdf).
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None:
        """
        Overview:
            Initialize the GLU module.
        Arguments:
            - input_dim (:obj:`int`): The dimension of the input tensor.
            - output_dim (:obj:`int`): The dimension of the output tensor.
            - context_dim (:obj:`int`): The dimension of the context tensor.
            - input_type (:obj:`str`): The type of input, currently supports ['fc', 'conv2d'].
        """
        super(GLU, self).__init__()
        assert input_type in ['fc', 'conv2d']
        if input_type == 'fc':
            self.layer1 = nn.Linear(context_dim, input_dim)
            self.layer2 = nn.Linear(input_dim, output_dim)
        elif input_type == 'conv2d':
            self.layer1 = nn.Conv2d(context_dim, input_dim, 1, 1, 0)
            self.layer2 = nn.Conv2d(input_dim, output_dim, 1, 1, 0)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the GLU transformation of the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
            - context (:obj:`torch.Tensor`): The context tensor used to compute the gate.
        Returns:
            - x (:obj:`torch.Tensor`): The output tensor after GLU transformation.
        """
        gate = self.layer1(context)
        gate = torch.sigmoid(gate)
        x = gate * x
        x = self.layer2(x)
        return x


class Swish(nn.Module):
    """
    Overview:
        Swish activation function, a smooth, non-monotonic activation function. For more details,
        please refer to [Searching for Activation Functions](https://arxiv.org/pdf/1710.05941.pdf).
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self):
        """
        Overview:
            Initialize the Swish module.
        """
        super(Swish, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the Swish transformation of the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - x (:obj:`torch.Tensor`): The output tensor after Swish transformation.
        """
        return x * torch.sigmoid(x)


class GELU(nn.Module):
    """
    Overview:
        Gaussian Error Linear Units (GELU) activation function, which is widely used in NLP models
        such as GPT and BERT. For more details, please refer to the original paper:
        https://arxiv.org/pdf/1606.08415.pdf.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self):
        """
        Overview:
            Initialize the GELU module.
        """
        super(GELU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the GELU transformation of the input tensor (tanh approximation).
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - x (:obj:`torch.Tensor`): The output tensor after GELU transformation.
        """
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


def build_activation(activation: str, inplace: Optional[bool] = None) -> nn.Module:
    """
    Overview:
        Build and return the activation module according to the given type.
    Arguments:
        - activation (:obj:`str`): The type of activation module, currently supports \
            ['relu', 'glu', 'prelu', 'swish', 'gelu', 'tanh', 'sigmoid', 'softplus', 'elu', 'square', 'identity'].
        - inplace (:obj:`Optional[bool]`): Whether to execute the operation in-place; only valid for 'relu'. \
            Defaults to None, which is treated as False.
    Returns:
        - act_func (:obj:`nn.Module`): The corresponding activation module. Note that 'glu' returns the \
            ``GLU`` class itself rather than an instance, since it requires extra constructor arguments.
    """
    if inplace is not None:
        assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
    else:
        inplace = False
    act_func = {
        'relu': nn.ReLU(inplace=inplace),
        'glu': GLU,
        'prelu': nn.PReLU(),
        'swish': Swish(),
        'gelu': GELU(),
        'tanh': nn.Tanh(),
        'sigmoid': nn.Sigmoid(),
        'softplus': nn.Softplus(),
        'elu': nn.ELU(),
        'square': Lambda(lambda x: x ** 2),
        'identity': Lambda(lambda x: x),
    }
    if activation.lower() in act_func.keys():
        return act_func[activation.lower()]
    else:
        raise KeyError("invalid key for activation: {}".format(activation))