# FlexBert / normalization.py
# (Hugging Face Hub page header: last commit ce9aa51 by NohTow, "Using dict as input"; 3.47 kB)
# Copyright 2024 **AUTHORS_TODO**
# License: Apache-2.0
# RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation)
# License: LLAMA 2 COMMUNITY LICENSE AGREEMENT
import inspect
import torch
import torch.nn as nn
from torch.nn import init
from .configuration_bert import FlexBertConfig
try:
from flash_attn.ops.triton.layer_norm import RMSNorm as TritonRMSNorm
from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
TritonRMSNorm = None
layer_norm_fn = None
class RMSNorm(nn.Module):
    """Llama2 RMSNorm implementation.

    Normalizes the last dimension of the input by its root-mean-square and
    scales the result by a learnable per-feature weight:
    ``y = x * rsqrt(mean(x**2, dim=-1) + eps) * weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor (size of the normalized last axis).
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter of shape ``(dim,)``, initialized to ones.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization (without the learnable scale) to the input tensor.

        Args:
            x (torch.Tensor): The input tensor; normalization is over its last dimension.

        Returns:
            torch.Tensor: The normalized tensor, same shape and dtype as ``x``.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        # Normalize in float32 for numerical stability (matters for fp16/bf16 inputs),
        # then cast back to the input dtype. The weight multiply happens after the
        # cast, so the final dtype follows PyTorch type promotion with ``self.weight``.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def reset_parameters(self):
        """Reset the scaling weight to ones, its initial value."""
        init.ones_(self.weight)
if layer_norm_fn is None:
    # flash_attn's Triton layer-norm kernel could not be imported, so this
    # variant is unavailable and the name is bound to None.
    TritonLayerNorm = None
else:
    class TritonLayerNorm(nn.LayerNorm):
        """``nn.LayerNorm`` whose forward pass runs flash_attn's Triton ``layer_norm_fn``.

        Parameters (``weight``, ``bias``) and ``eps`` are inherited unchanged from
        ``nn.LayerNorm``; only the computation in ``forward`` is replaced.
        """

        def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
            """Normalize ``x`` with the Triton kernel.

            Args:
                x: Input tensor.
                residual: Optional tensor forwarded to ``layer_norm_fn``
                    (presumably added to ``x`` before normalizing — confirm
                    against flash_attn).
                prenorm: Forwarded to ``layer_norm_fn``; presumably makes it
                    also return the pre-normalization sum — TODO confirm.
                residual_in_fp32: Forwarded to ``layer_norm_fn``.

            Returns:
                Whatever ``layer_norm_fn`` returns for these arguments.
            """
            return layer_norm_fn(
                x,
                self.weight,
                self.bias,
                eps=self.eps,
                residual=residual,
                residual_in_fp32=residual_in_fp32,
                prenorm=prenorm,
            )
# Registry from a `config.normalization` name to its layer class. The Triton
# entries degrade to the pure-PyTorch classes when flash_attn is unavailable.
NORM2CLS = dict(
    layernorm=nn.LayerNorm,
    triton_layernorm=nn.LayerNorm if TritonLayerNorm is None else TritonLayerNorm,
    rmsnorm=RMSNorm,
    triton_rmsnorm=RMSNorm if TritonRMSNorm is None else TritonRMSNorm,
)
def get_norm_layer(config: FlexBertConfig, compiled_norm: bool = False) -> nn.Module:
    """Build the normalization layer selected by ``config.normalization``.

    Args:
        config: Model config. Reads ``normalization`` (a key of ``NORM2CLS``),
            ``hidden_size`` (passed as the layer's first positional argument)
            and, if present, ``norm_kwargs`` (a mapping of extra constructor
            kwargs; entries the chosen class does not accept are dropped).
        compiled_norm: If True, a leading ``"triton_"`` is stripped from the
            requested name so the non-Triton layer is used (for compilation).

    Returns:
        A newly constructed normalization module.

    Raises:
        ValueError: If the (possibly prefix-stripped) name is not in ``NORM2CLS``.
    """
    norm = config.normalization
    if compiled_norm and norm.startswith("triton_"):
        # Use non-Triton norms when compiling; strip only the prefix we checked.
        norm = norm[len("triton_"):]

    # Only the registry lookup maps to "invalid type". A KeyError raised while
    # building the layer is a real bug and must not be reported as a bad name.
    try:
        norm_cls = NORM2CLS[norm]
    except KeyError:
        raise ValueError(
            f"Invalid normalization layer type: {config.normalization}, must be one of {NORM2CLS.keys()}."
        ) from None

    # Forward only the kwargs the chosen constructor accepts, so a single
    # config can carry options for several norm types. A missing or None
    # `norm_kwargs` means "no extra kwargs".
    accepted = inspect.signature(norm_cls).parameters
    norm_kwargs = getattr(config, "norm_kwargs", None) or {}
    filtered_kwargs = {k: v for k, v in norm_kwargs.items() if k in accepted}
    return norm_cls(config.hidden_size, **filtered_kwargs)