|
import mamba_ssm |
|
from transformers import PretrainedConfig |
|
|
|
# Import the submodule explicitly: ``import mamba_ssm`` alone only exposes
# ``mamba_ssm.models.config_mamba`` if the package ``__init__`` (transitively)
# happens to import it, which is an implementation detail of mamba_ssm.
import mamba_ssm.models.config_mamba

# A default-constructed instance of mamba_ssm's own ``MambaConfig``. Its field
# values are used as the constructor defaults of the Hugging Face config below.
mamba_config_defaults = mamba_ssm.models.config_mamba.MambaConfig()
|
|
|
class MambaConfig(PretrainedConfig):
    """Hugging Face ``PretrainedConfig`` mirroring mamba_ssm's ``MambaConfig``.

    Each constructor argument is stored unchanged on the instance under the
    same attribute name. Its default is taken from the corresponding field of
    ``mamba_config_defaults``, a default-constructed mamba_ssm ``MambaConfig``.
    Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        d_model: Stored as ``self.d_model``.
        fused_add_norm: Stored as ``self.fused_add_norm``.
        n_layer: Stored as ``self.n_layer``.
        pad_vocab_size_multiple: Stored as ``self.pad_vocab_size_multiple``.
        residual_in_fp32: Stored as ``self.residual_in_fp32``.
        rms_norm: Stored as ``self.rms_norm``.
        ssm_cfg: Stored as ``self.ssm_cfg``. When omitted or ``None``, a fresh
            deep copy of mamba_ssm's default ``ssm_cfg`` is used, so instances
            never share (or mutate) a single default dict.
        vocab_size: Stored as ``self.vocab_size``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # Hugging Face config type identifier for this configuration class.
    model_type = "mamba"

    def __init__(
        self,
        d_model: int = mamba_config_defaults.d_model,
        fused_add_norm: bool = mamba_config_defaults.fused_add_norm,
        n_layer: int = mamba_config_defaults.n_layer,
        pad_vocab_size_multiple: int = mamba_config_defaults.pad_vocab_size_multiple,
        residual_in_fp32: bool = mamba_config_defaults.residual_in_fp32,
        rms_norm: bool = mamba_config_defaults.rms_norm,
        ssm_cfg: "dict | None" = None,
        vocab_size: int = mamba_config_defaults.vocab_size,
        **kwargs,
    ):
        import copy

        self.d_model = d_model
        self.fused_add_norm = fused_add_norm
        self.n_layer = n_layer
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.residual_in_fp32 = residual_in_fp32
        self.rms_norm = rms_norm

        # Default values are evaluated once, so using the dict from
        # ``mamba_config_defaults`` directly as the parameter default would
        # make every config (and the module-level defaults) share one object.
        # Mutating one config's ``ssm_cfg`` would then leak into all others.
        # Copy the default per instance instead.
        if ssm_cfg is None:
            ssm_cfg = copy.deepcopy(mamba_config_defaults.ssm_cfg)
        self.ssm_cfg = ssm_cfg

        self.vocab_size = vocab_size

        super().__init__(**kwargs)
|
|