|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from collections import OrderedDict |
|
|
|
import torch |
|
from packaging import version |
|
from torch import Tensor, nn |
|
|
|
from transformers.utils import logging |
|
|
|
|
|
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
class PytorchGELUTanh(nn.Module):
    """
    Tanh approximation of the GeLU activation (https://arxiv.org/abs/1606.08415), delegated to
    `nn.functional.gelu` with `approximate="tanh"`.

    As noted by the original authors, this is equivalent to NewGELU and FastGELU but much faster; results are not
    an exact numerical match because of rounding errors.
    """

    def __init__(self):
        """
        Refuse construction on torch versions older than 1.12.0.

        Raises:
            ImportError: if the installed torch version is below 1.12.0.
        """
        super().__init__()
        torch_too_old = version.parse(torch.__version__) < version.parse("1.12.0")
        if torch_too_old:
            raise ImportError(
                f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
                "PytorchGELUTanh. Please upgrade torch."
            )

    def forward(self, input: Tensor) -> Tensor:
        """Apply the tanh-approximated GeLU to `input`."""
        return nn.functional.gelu(input, approximate="tanh")
|
|
|
|
|
class NewGELUActivation(nn.Module):
    """
    Tanh-based GELU as found in the Google BERT repo (identical to OpenAI GPT's version). Reference: the Gaussian
    Error Linear Units paper, https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:
        """Return 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))) elementwise."""
        tanh_arg = math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))
        return 0.5 * input * (1.0 + torch.tanh(tanh_arg))
|
|
|
|
|
class GELUActivation(nn.Module):
    """
    Erf-based GELU, matching the original implementation in the Google BERT repo. OpenAI GPT's GELU differs
    slightly (and yields slightly different values): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x +
    0.044715 * torch.pow(x, 3)))). The default path uses `nn.functional.gelu`; a pure-Python erf formula can be
    selected instead. See the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        """
        Args:
            use_gelu_python: when True, use the Python erf formula instead of `nn.functional.gelu`.
        """
        super().__init__()
        self.act = self._gelu_python if use_gelu_python else nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        """Compute x * 0.5 * (1 + erf(x / sqrt(2))) elementwise."""
        erf_term = torch.erf(input / math.sqrt(2.0))
        return input * 0.5 * (1.0 + erf_term)

    def forward(self, input: Tensor) -> Tensor:
        """Apply whichever GELU implementation was selected at construction."""
        return self.act(input)
|
|
|
|
|
class FastGELUActivation(nn.Module):
    """
    Tanh-based GELU approximation; per the original note it is slower than QuickGELU but more accurate. See:
    https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        """Return 0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) elementwise."""
        correction = 1.0 + 0.044715 * input * input
        tanh_arg = input * 0.7978845608 * correction
        return 0.5 * input * (1.0 + torch.tanh(tanh_arg))
|
|
|
|
|
class QuickGELUActivation(nn.Module):
    """
    Sigmoid-based GELU approximation, described by the original authors as fast but somewhat inaccurate. See:
    https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        """Return x * sigmoid(1.702 * x) elementwise."""
        gate = torch.sigmoid(1.702 * input)
        return input * gate
|
|
|
|
|
class ClippedGELUActivation(nn.Module):
    """
    GeLU whose output is clipped to the range [min, max]. The original authors note this is especially useful for
    quantization, as it allows mapping negative values in the GeLU spectrum; see https://arxiv.org/abs/2004.09602.

    The underlying GeLU is the module-level `gelu` activation (the original Google BERT formulation). For
    reference, OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
    """

    def __init__(self, min: float, max: float):
        """
        Args:
            min: lower clipping bound.
            max: upper clipping bound.

        Raises:
            ValueError: if `min` is greater than `max` (equal bounds are accepted).
        """
        # Validate before initialising the module so an invalid range never yields a half-built instance.
        bounds_inverted = min > max
        if bounds_inverted:
            raise ValueError(f"min should be < max (got min: {min}, max: {max})")

        super().__init__()
        self.min, self.max = min, max

    def forward(self, x: Tensor) -> Tensor:
        """Apply the module-level `gelu` to `x`, then clip the result to [self.min, self.max]."""
        activated = gelu(x)
        return torch.clip(activated, self.min, self.max)
|
|
|
|
|
class AccurateGELUActivation(nn.Module):
    """
    Tanh-based GELU approximation, described by the original authors as faster than the default and more accurate
    than QuickGELU. See: https://github.com/hendrycks/GELUs

    Implemented along with MEGA (Moving Average Equipped Gated Attention)
    """

    def __init__(self):
        super().__init__()
        # sqrt(2 / pi) is computed once here and reused on every forward call.
        self.precomputed_constant = math.sqrt(2 / math.pi)

    def forward(self, input: Tensor) -> Tensor:
        """Return 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))) elementwise."""
        cubic_poly = input + 0.044715 * torch.pow(input, 3)
        tanh_term = torch.tanh(self.precomputed_constant * cubic_poly)
        return 0.5 * input * (1 + tanh_term)
|
|
|
|
|
class MishActivation(nn.Module):
    """
    Mish activation. See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra.,
    https://arxiv.org/abs/1908.08681) and the paper's official repository: https://github.com/digantamisra98/Mish
    """

    def __init__(self):
        super().__init__()
        # `nn.functional.mish` is used from torch 1.9.0 onwards; older versions use the Python formula below.
        native_available = version.parse(torch.__version__) >= version.parse("1.9.0")
        self.act = nn.functional.mish if native_available else self._mish_python

    def _mish_python(self, input: Tensor) -> Tensor:
        """Compute x * tanh(softplus(x)) elementwise."""
        softplus_out = nn.functional.softplus(input)
        return input * torch.tanh(softplus_out)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the Mish implementation selected at construction."""
        return self.act(input)
|
|
|
|
|
class LinearActivation(nn.Module):
    """
    Identity activation: the input is handed back as the output, unchanged.
    """

    def forward(self, input: Tensor) -> Tensor:
        """Return `input` itself (the same object, not a copy)."""
        return input
|
|
|
|
|
class LaplaceActivation(nn.Module):
    """
    Elementwise activation based on the Laplace function, introduced in MEGA as an attention activation. See
    https://arxiv.org/abs/2209.10655

    Per the original note: inspired by squared relu, but with bounded range and gradient for better stability
    """

    def forward(self, input, mu=0.707107, sigma=0.282095):
        """Return 0.5 * (1 + erf((input - mu) / (sigma * sqrt(2)))) elementwise."""
        scale = sigma * math.sqrt(2.0)
        standardized = (input - mu).div(scale)
        return 0.5 * (1.0 + torch.erf(standardized))
|
|
|
|
|
class ReLUSquaredActivation(nn.Module):
    """
    The relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward(self, input):
        """Apply ReLU to `input`, then square the result elementwise."""
        return torch.square(nn.functional.relu(input))
|
|
|
|
|
class ClassInstantier(OrderedDict):
    """
    Ordered mapping whose item lookup instantiates the stored value instead of returning it.

    A stored value is either a callable (called with no arguments) or a `(callable, kwargs)` tuple (called with
    `**kwargs`). Each lookup performs a new call.
    """

    def __getitem__(self, key):
        entry = super().__getitem__(key)
        if isinstance(entry, tuple):
            factory, init_kwargs = entry
        else:
            factory, init_kwargs = entry, {}
        return factory(**init_kwargs)
|
|
|
|
|
# Activation name -> module class, or (module class, constructor kwargs) when arguments are needed.
ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    # "silu" and "swish" both resolve to nn.SiLU.
    "silu": nn.SiLU,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
}

# Same entries, wrapped so that item lookup returns an instantiated module rather than the class.
ACT2FN = ClassInstantier(ACT2CLS)
|
|
|
|
|
def get_activation(activation_string):
    """
    Look up `activation_string` in `ACT2FN` and return the resulting activation instance.

    Raises:
        KeyError: if `activation_string` is not a key of `ACT2FN`; the message lists the available keys.
    """
    if activation_string not in ACT2FN:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
    return ACT2FN[activation_string]
|
|
|
|
|
|
|
# Module-level activation instances, built once at import time.
# `gelu` is also the function that ClippedGELUActivation.forward calls.
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
|
|