File size: 14,776 Bytes
204da06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 |
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
import warnings
from transformers import BertConfig as TransformersBertConfig
class BertConfig(TransformersBertConfig):
    """Configuration class for MosaicBERT.

    Adds MosaicBERT-specific options on top of ``transformers.BertConfig``; every
    other keyword argument is forwarded unchanged to the parent constructor.
    """

    def __init__(
        self,
        alibi_starting_size: int = 512,
        normalization: str = "layernorm",
        attention_probs_dropout_prob: float = 0.0,
        head_pred_act: str = "gelu",
        deterministic_fa2: bool = False,
        allow_embedding_resizing: bool = False,
        **kwargs,
    ):
        """Configuration class for MosaicBert.

        Args:
            alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
                create when initializing the model. You should be able to ignore this parameter in most cases.
                Defaults to 512.
            normalization (str): Normalization layer type. Defaults to "layernorm".
            attention_probs_dropout_prob (float): By default, turn off attention dropout in MosaicBERT.
                Note that the custom Triton Flash Attention with ALiBi implementation does not support dropout.
                However, Flash Attention 2 supports ALiBi and dropout https://github.com/Dao-AILab/flash-attention
                Defaults to 0.0.
            head_pred_act (str): Activation function for the prediction head. Defaults to "gelu".
            deterministic_fa2 (bool): Use Flash Attention 2 deterministic mode, which is slower than the
                default non-deterministic mode. Defaults to False.
            allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller
                than the tokenizer vocab size. Defaults to False.
            **kwargs: Forwarded unchanged to ``transformers.BertConfig.__init__``. This is also how the
                following MosaicBERT options are supplied (they are not named parameters here):
                embed_dropout_prob (float): Dropout probability for the embedding layer.
                attn_out_dropout_prob (float): Dropout probability for the attention output layer.
                mlp_dropout_prob (float): Dropout probability for the MLP layer.
        """
        # attention_probs_dropout_prob is a standard BertConfig field; it is named
        # here only to override the parent's default with 0.0.
        super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
        self.alibi_starting_size = alibi_starting_size
        self.normalization = normalization
        self.head_pred_act = head_pred_act
        self.deterministic_fa2 = deterministic_fa2
        self.allow_embedding_resizing = allow_embedding_resizing
class FlexBertConfig(TransformersBertConfig):
    """Configuration class for FlexBERT.

    Extends ``transformers.BertConfig`` with options selecting the attention,
    embedding, encoder, MLP and head layer types, their bias/dropout/normalization
    settings, initialization, and sliding-window attention. Cross-option
    consistency is validated at construction time.
    """

    def __init__(
        self,
        attention_layer: str = "base",
        attention_probs_dropout_prob: float = 0.0,
        attn_out_bias: bool = False,
        attn_out_dropout_prob: float = 0.0,
        attn_qkv_bias: bool = False,
        bert_layer: str = "prenorm",
        decoder_bias: bool = True,
        embed_dropout_prob: float = 0.0,
        embed_norm: bool = True,
        final_norm: bool = False,
        embedding_layer: str = "absolute_pos",
        encoder_layer: str = "base",
        loss_function: str = "cross_entropy",
        loss_kwargs: dict | None = None,
        mlp_dropout_prob: float = 0.0,
        mlp_in_bias: bool = False,
        mlp_layer: str = "mlp",
        mlp_out_bias: bool = False,
        norm_kwargs: dict | None = None,
        normalization: str = "rmsnorm",
        padding: str = "unpadded",
        head_class_act: str = "silu",
        head_class_bias: bool = False,
        head_class_dropout: float = 0.0,
        head_class_norm: str | bool = False,
        head_pred_act: str = "silu",
        head_pred_bias: bool = False,
        head_pred_dropout: float = 0.0,
        head_pred_norm: bool = True,
        pooling_type: str = "cls",
        rotary_emb_dim: int | None = None,
        rotary_emb_base: float = 10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved: bool = False,
        use_fa2: bool = True,
        use_sdpa_attn_mask: bool = False,
        allow_embedding_resizing: bool = False,
        init_method: str = "default",
        init_std: float = 0.02,
        init_cutoff_factor: float = 2.0,
        init_small_embedding: bool = False,
        initial_attention_layer: str | None = None,
        initial_bert_layer: str | None = None,
        initial_mlp_layer: str | None = None,
        num_initial_layers: int = 1,
        skip_first_prenorm: bool = False,
        deterministic_fa2: bool = False,
        sliding_window: int = -1,
        global_attn_every_n_layers: int = -1,
        local_attn_rotary_emb_base: float = -1,
        local_attn_rotary_emb_dim: int | None = None,
        unpad_embeddings: bool = False,
        pad_logits: bool = False,
        compile_model: bool = False,
        masked_prediction: bool = False,
        casual_mask: bool = False,
        **kwargs,
    ):
        """
        Args:
            attention_layer (str): Attention layer type.
            attention_probs_dropout_prob (float): Dropout probability for attention probabilities.
            attn_out_bias (bool): Use bias in attention output projection.
            attn_out_dropout_prob (float): Dropout probability for attention output.
            attn_qkv_bias (bool): Use bias for query, key, value linear layer(s).
            bert_layer (str): BERT layer type.
            decoder_bias (bool): Use bias in decoder linear layer.
            embed_dropout_prob (float): Dropout probability for embeddings.
            embed_norm (bool): Normalize embedding output.
            final_norm (bool): Add normalization after the final encoder layer and before head.
            embedding_layer (str): Embedding layer type.
            encoder_layer (str): Encoder layer type.
            loss_function (str): Loss function to use.
            loss_kwargs (dict | None): Keyword arguments for loss function. None means an empty dict.
                The mapping is shallow-copied, so the caller's dict is never modified.
            mlp_dropout_prob (float): Dropout probability for MLP layers.
            mlp_in_bias (bool): Use bias in MLP input linear layer.
            mlp_layer (str): MLP layer type.
            mlp_out_bias (bool): Use bias in MLP output linear layer.
            norm_kwargs (dict | None): Keyword arguments for normalization layers. None means an
                empty dict. The mapping is shallow-copied.
            normalization (str): Normalization type.
            padding (str): Padding mode, "unpadded" or "padded". Unpadded inputs work best with `use_fa2=True`.
            head_class_act (str): Activation function for classification head.
            head_class_bias (bool): Use bias in classification head linear layer(s).
            head_class_dropout (float): Dropout probability for classification head.
            head_class_norm (str | bool): Normalization type for classification head; False to disable.
            head_pred_act (str): Activation function for prediction head.
            head_pred_bias (bool): Use bias in prediction head linear layer(s).
            head_pred_dropout (float): Dropout probability for prediction head.
            head_pred_norm (bool): Normalize prediction head output.
            pooling_type (str): Pooling type.
            rotary_emb_dim (int | None): Rotary embedding dimension.
            rotary_emb_base (float): Rotary embedding base.
            rotary_emb_scale_base (float): Rotary embedding scale base.
            rotary_emb_interleaved (bool): Use interleaved rotary embeddings.
            use_fa2 (bool): Use FlashAttention2. Requires flash_attn package.
            use_sdpa_attn_mask (bool): Pass a mask to SDPA. This will prevent SDPA from using the PyTorch FA2 kernel.
            allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size.
            init_method (str): Model layers initialization method.
            init_std (float): Standard deviation for initialization. Used for normal and full_megatron init.
            init_cutoff_factor (float): Cutoff factor for initialization. Used for normal and full_megatron init.
            init_small_embedding (bool): Initialize embeddings with RWKV small init.
            initial_attention_layer (str | None): Replace first `num_initial_layers` attention_layer instance with this layer.
            initial_bert_layer (str | None): Replace first `num_initial_layers` bert_layer instance with this layer.
            initial_mlp_layer (str | None): Replace first `num_initial_layers` mlp_layer instance with this layer.
            num_initial_layers (int): Number of initial layers to set via `initial_attention_layer`, `initial_bert_layer`, and `initial_mlp_layer`.
            skip_first_prenorm (bool): Skip pre-normalization for the first bert layer. Requires `embed_norm=True`.
            deterministic_fa2 (bool): Use Flash Attention 2 deterministic mode. This is slower than the default non-deterministic mode.
            sliding_window (int): Use sliding window attention with window size `n`. -1 to disable. Window size split
                between the left and right context. Must be a multiple of 64 when enabled. Only supports FA2.
            global_attn_every_n_layers (int): Use global attention every `n` layers and sliding window for the rest. -1 to disable.
            local_attn_rotary_emb_base (float): Rotary embedding base for local attention. -1 to disable and use `rotary_emb_base` for all layers.
            local_attn_rotary_emb_dim (int | None): Rotary embedding dimension for local attention. None to disable and use `rotary_emb_dim` for all layers.
            unpad_embeddings (bool): Unpad inputs before the embedding layer. Forces `padding="unpadded"`.
            pad_logits (bool): Pad logits after calculating the loss. Requires `unpad_embeddings=True`.
            compile_model (bool): Compile the subset of the model which can be compiled.
            masked_prediction (bool): Only pass the masked tokens through the final MLM layers.
            casual_mask (bool): Use a causal attention mask (the parameter keeps its historical
                `casual` spelling). Defaults to False.
            **kwargs: Additional keyword arguments forwarded to ``transformers.BertConfig``.

        Raises:
            ValueError: If `return_z_loss` is requested without `loss_function="fa_cross_entropy"` and a
                positive `lse_square_scale`; if `global_attn_every_n_layers` does not divide
                `num_hidden_layers - 1`; if sliding-window options are inconsistent (window set without
                FA2, window not an even multiple of 64, or local/global options set while the window is
                disabled); if `pad_logits=True` without `unpad_embeddings=True`; or if
                `unpad_embeddings=True` is combined with `embedding_layer="absolute_pos"`.
        """
        # Normalize and copy the dict options: avoids sharing a mutable default across
        # instances and keeps the `inplace_backward` override below from mutating the
        # caller's dict.
        loss_kwargs = {} if loss_kwargs is None else dict(loss_kwargs)
        norm_kwargs = {} if norm_kwargs is None else dict(norm_kwargs)

        super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
        self.attention_layer = attention_layer
        self.attn_out_bias = attn_out_bias
        self.attn_out_dropout_prob = attn_out_dropout_prob
        self.attn_qkv_bias = attn_qkv_bias
        self.bert_layer = bert_layer
        self.decoder_bias = decoder_bias
        self.embed_dropout_prob = embed_dropout_prob
        self.embed_norm = embed_norm
        self.final_norm = final_norm
        self.embedding_layer = embedding_layer
        self.encoder_layer = encoder_layer
        self.loss_function = loss_function
        self.loss_kwargs = loss_kwargs
        self.mlp_dropout_prob = mlp_dropout_prob
        self.mlp_in_bias = mlp_in_bias
        self.mlp_layer = mlp_layer
        self.mlp_out_bias = mlp_out_bias
        self.norm_kwargs = norm_kwargs
        self.normalization = normalization
        self.padding = padding
        self.head_class_act = head_class_act
        self.head_class_bias = head_class_bias
        self.head_class_dropout = head_class_dropout
        self.head_class_norm = head_class_norm
        self.head_pred_act = head_pred_act
        self.head_pred_bias = head_pred_bias
        self.head_pred_dropout = head_pred_dropout
        self.head_pred_norm = head_pred_norm
        self.pooling_type = pooling_type
        self.rotary_emb_dim = rotary_emb_dim
        self.rotary_emb_base = rotary_emb_base
        self.rotary_emb_scale_base = rotary_emb_scale_base
        self.rotary_emb_interleaved = rotary_emb_interleaved
        self.use_fa2 = use_fa2
        self.use_sdpa_attn_mask = use_sdpa_attn_mask
        self.allow_embedding_resizing = allow_embedding_resizing
        self.init_method = init_method
        self.init_std = init_std
        self.init_cutoff_factor = init_cutoff_factor
        self.init_small_embedding = init_small_embedding
        self.initial_attention_layer = initial_attention_layer
        self.initial_bert_layer = initial_bert_layer
        self.initial_mlp_layer = initial_mlp_layer
        self.num_initial_layers = num_initial_layers
        self.skip_first_prenorm = skip_first_prenorm
        self.deterministic_fa2 = deterministic_fa2
        self.sliding_window = sliding_window
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.local_attn_rotary_emb_base = local_attn_rotary_emb_base
        self.local_attn_rotary_emb_dim = local_attn_rotary_emb_dim
        self.unpad_embeddings = unpad_embeddings
        self.pad_logits = pad_logits
        self.compile_model = compile_model
        self.masked_prediction = masked_prediction
        self.casual_mask = casual_mask

        # --- Loss options ---
        if loss_kwargs.get("return_z_loss", False):
            if loss_function != "fa_cross_entropy":
                raise ValueError("loss_function must be 'fa_cross_entropy' when return_z_loss is True")
            if loss_kwargs.get("lse_square_scale", 0) <= 0:
                raise ValueError(
                    "lse_square_scale must be passed to `loss_kwargs` and must be greater than 0 for z_loss"
                )
        if loss_kwargs.get("inplace_backward", False):
            # self.loss_kwargs is this instance's private copy, so the caller's dict is untouched.
            self.loss_kwargs["inplace_backward"] = False
            warnings.warn(
                "`inplace_backward=True` will cause incorrect metrics. Automatically setting to False.",
                stacklevel=2,
            )

        # --- Sliding-window / global attention options ---
        if global_attn_every_n_layers > 0 and (self.num_hidden_layers - 1) % global_attn_every_n_layers != 0:
            raise ValueError(
                f"{global_attn_every_n_layers=} must be a divisor of one less than {self.num_hidden_layers=}"
            )
        if self.sliding_window != -1:
            if not self.use_fa2:
                raise ValueError("Sliding window attention is only supported with FlashAttention2")
            # The window is split between left and right context (so it must be even) and must
            # also be a multiple of 64. `or` rejects either violation; the previous `and` made
            # the 64 check dead because every multiple of 64 is already even.
            if self.sliding_window % 2 != 0 or self.sliding_window % 64 != 0:
                raise ValueError(
                    f"Sliding window must be an even number and divisible by 64: {self.sliding_window=} {self.sliding_window % 64} {self.sliding_window % 2}"
                )
        else:
            if self.global_attn_every_n_layers != -1:
                raise ValueError("global_attn_every_n_layers must be -1 when sliding_window is disabled")
            if self.local_attn_rotary_emb_base != -1:
                raise ValueError("local_attn_rotary_emb_base must be -1 when sliding_window is disabled")
            if self.local_attn_rotary_emb_dim is not None:
                raise ValueError("local_attn_rotary_emb_dim must be None when sliding_window is disabled")

        # --- Padding / unpadding options ---
        if self.unpad_embeddings and self.padding != "unpadded":
            warnings.warn(
                "`unpad_embeddings=True` requires `padding='unpadded'`. Automatically setting `padding='unpadded'`.",
                stacklevel=2,
            )
            self.padding = "unpadded"
        if self.pad_logits and not self.unpad_embeddings:
            raise ValueError("`pad_logits=True` requires `unpad_embeddings=True`")
        if self.unpad_embeddings and self.embedding_layer == "absolute_pos":
            raise ValueError(f"{self.unpad_embeddings=} is incompatible with {self.embedding_layer=}")
PADDING = ["unpadded", "padded"]


def maybe_add_padding(config: FlexBertConfig, config_option: str) -> str:
    """Prefix ``config_option`` with the config's padding mode unless it already carries one.

    Args:
        config: Config whose ``padding`` attribute supplies the prefix.
        config_option: Option name, either bare (e.g. ``"base"``) or already
            prefixed (e.g. ``"padded_base"``).

    Returns:
        ``config_option`` unchanged when it already starts with one of the
        ``PADDING`` modes followed by ``"_"``; otherwise
        ``f"{config.padding}_{config_option}"``.

    Raises:
        ValueError: If ``config.padding`` is not one of ``PADDING``.
    """
    padding_mode = config.padding
    if padding_mode not in PADDING:
        raise ValueError(f"Invalid padding type: {config.padding}, must be one of {PADDING}")

    # str.startswith accepts a tuple, testing every known prefix in one call.
    known_prefixes = tuple(f"{mode}_" for mode in PADDING)
    if config_option.startswith(known_prefixes):
        return config_option
    return f"{padding_mode}_{config_option}"
|