import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    """
    Layer Normalization (https://arxiv.org/abs/1607.06450).

    Normalizes the last dimension to zero mean and unit standard deviation,
    then applies a learned per-feature scale (gamma) and shift (beta).
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        hidden_states = self.gamma * (x - mean) / (std + self.eps)
        return hidden_states + self.beta


class T5LayerNorm(nn.Module):
    """
    Layer normalization in the T5 style: no bias and no subtraction of the mean,
    i.e. the activations are only rescaled by their root mean square.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Layer norm should always be calculated in float32 for numerical stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to the parameter dtype (e.g. fp16/bf16) before applying the scale.
        return self.weight * hidden_states.type_as(self.weight)


class RMSNorm(nn.Module):
    """
    Root-mean-square layer normalization (https://arxiv.org/abs/1910.07467):
    rescale by the RMS of the last dimension with a learned per-feature weight
    and no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        # Normalize by the root mean square over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Compute the normalization in float32, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
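

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal sanity check of the three normalization layers on random input.
# The hidden size, batch shape, and tolerance below are arbitrary assumptions.
# T5LayerNorm and RMSNorm compute the same RMS-only normalization, so at
# initialization (all-ones weights) their outputs should match closely.
if __name__ == "__main__":
    torch.manual_seed(0)
    hidden_size = 8
    x = torch.randn(2, 4, hidden_size)

    ln = LayerNorm(hidden_size)
    t5_ln = T5LayerNorm(hidden_size)
    rms = RMSNorm(hidden_size)

    print(ln(x).shape, t5_ln(x).shape, rms(x).shape)
    print(torch.allclose(t5_ln(x), rms(x), atol=1e-5))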