# FlexBert / mlp.py
# (Header residue from a web file viewer: user "NohTow", commit ce9aa51 "Using dict as input", 9.1 kB.)
# Copyright 2024 **AUTHORS_TODO**
# License: Apache-2.0
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
# Copyright 2023 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, Tri Dao.
from typing import Optional
import torch
import torch.nn as nn
from .configuration_bert import FlexBertConfig
from .activation import get_act_fn
from .normalization import get_norm_layer
from .initialization import ModuleType, init_weights
class BertResidualGLU(nn.Module):
    """Applies the FFN at the end of each Mosaic BERT layer.

    Compared to the default BERT architecture, this block replaces
    :class:`~transformers.model.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has
    similar functionality, but introduces Gated Linear Units.

    Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep the
    parameter count consistent with that of a standard Hugging Face BERT, scale down
    ``config.intermediate_size`` by 2/3. For example, a Mosaic BERT constructed with
    ``config.intermediate_size=2048`` will have the same parameter footprint as its Hugging Face
    BERT counterpart constructed with ``config.intermediate_size=3072``.

    However, in most cases it will not be necessary to adjust ``config.intermediate_size``
    since, despite the increased parameter size, Mosaic BERT typically offers a net higher
    throughput than a Hugging Face BERT built from the same ``config``.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.config = config
        # One bias-free projection produces both GLU halves (gated + linear), hence 2x width.
        self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
        self.act = get_act_fn(config.hidden_act)
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = get_norm_layer(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): Hidden states from the attention layer whose last
                dimension is ``hidden_size`` -- e.g. unpadded ``[nnz, dim]`` or padded
                ``[batch, seq, dim]``.

        Returns:
            torch.Tensor: ``layernorm(wo(dropout(act(gated) * non_gated)) + hidden_states)``;
            same shape as ``hidden_states`` assuming the configured norm layer is
            shape-preserving.
        """
        residual_connection = hidden_states
        # compute the activation
        hidden_states = self.gated_layers(hidden_states)
        # Split on the feature (last) axis, not axis 1, so inputs with any number of leading
        # dimensions work; for 2-D [nnz, 2 * intermediate] inputs this is the same split.
        gated = hidden_states[..., : self.config.intermediate_size]
        non_gated = hidden_states[..., self.config.intermediate_size :]
        hidden_states = self.act(gated) * non_gated
        hidden_states = self.dropout(hidden_states)
        # multiply by the second matrix
        hidden_states = self.wo(hidden_states)
        # add the residual connection and post-LN
        return self.layernorm(hidden_states + residual_connection)
class FlexBertMLPBase(nn.Module):
    """Shared parent of the FlexBERT MLP variants, used as the common type in type hints.

    Concrete subclasses override ``_init_weights`` and ``forward``; calling either on the base
    class raises ``NotImplementedError``.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        # Stored for subclasses, which consult both when initializing their weights.
        self.config, self.layer_id = config, layer_id

    def _init_weights(self, reset_params: bool = False):
        message = "This is a base class and should not be used directly."
        raise NotImplementedError(message)

    def reset_parameters(self):
        """Re-initialize weights by delegating to ``_init_weights(reset_params=True)``."""
        self._init_weights(reset_params=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        message = "This is a base class and should not be used directly."
        raise NotImplementedError(message)
class FlexBertMLP(FlexBertMLPBase):
    """Plain two-projection MLP applied at the end of each FlexBERT layer.

    Takes the place of both :class:`~transformers.model.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.model.bert.modeling_bert.SelfOutput` in a standard BERT. It projects up
    with ``Wi``, applies the activation and optional dropout, then projects back down with ``Wo``.
    No residual connection or normalization is applied inside this block.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        dropout_p = config.mlp_dropout_prob
        # Registration order (Wi, act, drop, Wo) is kept stable for state_dict/parameter order.
        self.Wi = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_in_bias)
        self.act = get_act_fn(config.hidden_act)
        if dropout_p > 0.0:
            self.drop = nn.Dropout(dropout_p)
        else:
            # Skip the dropout module entirely when it would be a no-op.
            self.drop = nn.Identity()
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        # The input projection is initialized with layer_id=None, while the output projection
        # receives this block's layer_id (presumably for depth-dependent init -- see
        # initialization.init_weights to confirm).
        init_specs = (
            (self.Wi, self.config.hidden_size, None, ModuleType.in_module),
            (self.Wo, self.config.intermediate_size, self.layer_id, ModuleType.out_module),
        )
        for module, dim, module_layer_id, module_type in init_specs:
            init_weights(
                self.config,
                module,
                layer_dim=dim,
                layer_id=module_layer_id,
                type_of_module=module_type,
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer [nnz, dim].

        Returns:
            torch.Tensor: ``Wo(drop(act(Wi(hidden_states))))``.
        """
        projected = self.Wi(hidden_states)
        activated = self.drop(self.act(projected))
        return self.Wo(activated)
class FlexBertGLU(FlexBertMLPBase):
    """Applies the GLU at the end of each FlexBERT layer.

    Compared to the default BERT architecture, this block replaces
    :class:`~transformers.model.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has
    similar functionality. ``Wi`` produces the activation input and the gate concatenated along
    the last dimension, ``act(input) * gate`` is passed through optional dropout, and ``Wo``
    projects back to ``hidden_size``.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        # Cast once and use the same integer width for both projections: nn.Linear requires int
        # sizes, and casting only Wi's width let a float intermediate_size break Wo.
        intermediate_size = int(config.intermediate_size)
        # Twice the width: first half feeds the activation, second half is the gate.
        self.Wi = nn.Linear(config.hidden_size, intermediate_size * 2, bias=config.mlp_in_bias)
        self.act = get_act_fn(config.hidden_act)
        self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
        self.Wo = nn.Linear(intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        # Input projection initialized with layer_id=None; output projection receives this
        # block's layer_id (presumably for depth-dependent init -- confirm in init_weights).
        init_weights(
            self.config,
            self.Wi,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.intermediate_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): Hidden states whose last dimension is ``hidden_size``.

        Returns:
            torch.Tensor: ``Wo(drop(act(x_in) * gate))`` where ``x_in, gate`` are the two halves
            of ``Wi(hidden_states)`` along the last dimension.
        """
        # Named x_in rather than `input` to avoid shadowing the builtin.
        x_in, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        return self.Wo(self.drop(self.act(x_in) * gate))
class FlexBertParallelGLU(FlexBertMLPBase):
    """GLU output stage for a FlexBERT layer whose input projection runs in parallel with attention.

    Replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate` and
    :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module of similar
    functionality. The block has no input projection of its own. ``forward`` receives the
    already-computed ``intermediate_ff`` (activation input and gate concatenated on the last
    dimension), gates it, applies optional dropout and projects with ``Wo``.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.act = get_act_fn(config.hidden_act)
        dropout_p = config.mlp_dropout_prob
        if dropout_p > 0.0:
            self.drop = nn.Dropout(dropout_p)
        else:
            self.drop = nn.Identity()
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        # Only the output projection is owned (and therefore initialized) here.
        init_weights(
            self.config,
            self.Wo,
            type_of_module=ModuleType.out_module,
            layer_dim=self.config.intermediate_size,
            layer_id=self.layer_id,
        )

    def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor:
        activation_in, gate = torch.chunk(intermediate_ff, 2, dim=-1)
        gated = self.act(activation_in) * gate
        return self.Wo(self.drop(gated))
# Registry from the MLP layer name stored in the config (``mlp_layer`` / ``initial_mlp_layer``)
# to the class that implements it; consumed by ``get_mlp_layer``.
MLP2CLS = dict(
    mlp=FlexBertMLP,
    glu=FlexBertGLU,
    parallel_glu=FlexBertParallelGLU,
)
def get_mlp_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertMLPBase:
    """Instantiate the MLP block configured for a given FlexBERT layer.

    Layers with ``layer_id < config.num_initial_layers`` use ``config.initial_mlp_layer`` when that
    attribute exists and is not ``None``. Every other layer, including calls with
    ``layer_id=None``, uses ``config.mlp_layer``. The selected name is looked up in ``MLP2CLS``.

    Args:
        config: Model configuration.
        layer_id: Index of the layer the MLP belongs to; forwarded to the MLP constructor.

    Returns:
        FlexBertMLPBase: A new instance of the selected MLP class.

    Raises:
        ValueError: If the selected MLP layer name is not a key of ``MLP2CLS``.
    """
    initial_mlp_layer = getattr(config, "initial_mlp_layer", None)
    # Check layer_id for None before ordering it against an int: the default layer_id=None
    # previously made `layer_id < config.num_initial_layers` raise TypeError on every call.
    use_initial = (
        initial_mlp_layer is not None
        and layer_id is not None
        and layer_id < config.num_initial_layers
    )
    if use_initial:
        attr_name, mlp_layer = "initial_mlp_layer", initial_mlp_layer
    else:
        attr_name, mlp_layer = "mlp_layer", config.mlp_layer

    # Keep the try narrow: a KeyError raised while *constructing* the layer (e.g. a dict lookup
    # inside a helper) must propagate as-is rather than be misreported as an invalid layer type.
    try:
        mlp_cls = MLP2CLS[mlp_layer]
    except KeyError as e:
        raise ValueError(
            f"Invalid MLP layer type: config.{attr_name}={mlp_layer!r}, must be one of {MLP2CLS.keys()}. {e}"
        ) from e
    return mlp_cls(config, layer_id=layer_id)