|
"""Caduceus config for Hugging Face. |
|
|
|
""" |
|
|
|
from typing import Optional, Union |
|
|
|
from transformers import PretrainedConfig |
|
|
|
|
|
class CaduceusConfig(PretrainedConfig):
    """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""

    model_type = "caduceus"

    def __init__(
        self,
        # Params carried over from MambaConfig
        d_model: int = 2560,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Epsilon for the normalization layers
        norm_epsilon: float = 1e-5,
        # Config dict passed to the weight initializer
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific: bi-directionality and reverse-complement (RC) equivariance
        bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map
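

# A sketch of optional Auto-class registration (an assumption, not part of the
# original file): this lets `AutoConfig.from_pretrained` resolve checkpoints
# whose config.json carries `"model_type": "caduceus"` without
# `trust_remote_code`, provided nothing else is registered under that type.
from transformers import AutoConfig

AutoConfig.register("caduceus", CaduceusConfig)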
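

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original file. The sizes and the
    # token ids in `complement_map` are hypothetical; the real mapping depends
    # on the tokenizer's vocabulary. Keys are strings because config.json
    # round-trips JSON object keys as strings.
    config = CaduceusConfig(
        d_model=256,
        n_layer=4,
        vocab_size=16,
        rcps=True,
        complement_map={"7": 10, "8": 9, "9": 8, "10": 7},  # e.g. A<->T, C<->G
    )
    config.save_pretrained("caduceus-tiny")  # writes caduceus-tiny/config.json
    reloaded = CaduceusConfig.from_pretrained("caduceus-tiny")
    assert reloaded.rcps and reloaded.bidirectional_strategy == "add"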
|
|