|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import dataclasses |
|
import warnings |
|
from dataclasses import dataclass, MISSING |
|
from functools import partial |
|
from typing import Optional, Dict, Any |
|
|
|
from .transformers_4_44_2__configuration_llama import LlamaConfig |
|
from .transformers_4_44_2__modeling_rope_utils import \ |
|
rope_config_validation |
|
|
|
|
|
class DeciLMConfig(LlamaConfig):
    """Llama-derived configuration that carries one ``BlockConfig`` per layer.

    After the parent initializer runs, ``intermediate_size`` and
    ``num_key_value_heads`` are reset to ``None`` on the instance
    (presumably because those values are described per block through
    ``block_configs`` instead -- confirm against the modeling code).
    """

    model_type = "nemotron-nas"

    def __init__(
        self,
        block_configs: list[dict] | list["BlockConfig"] | None = None,
        **kwargs,
    ):
        """Build the configuration.

        Args:
            block_configs: Optional sequence with one entry per hidden layer.
                Each entry may be a ``BlockConfig`` or a plain dict of
                ``BlockConfig`` keyword arguments; dict entries are converted
                to ``BlockConfig`` instances. ``None`` is stored as-is.
            **kwargs: Forwarded unchanged to ``LlamaConfig.__init__``.

        Raises:
            AssertionError: If ``block_configs`` is given and its length
                differs from ``self.num_hidden_layers``.
        """
        super().__init__(**kwargs)
        self.intermediate_size = None
        self.num_key_value_heads = None

        if block_configs is not None:
            # AssertionError is kept as the failure type so existing callers
            # that rely on it keep working.
            assert len(block_configs) == self.num_hidden_layers, (
                f"Expected {self.num_hidden_layers} block configs "
                f"(one per hidden layer), got {len(block_configs)}"
            )
            # Decide per element rather than by inspecting only the first
            # entry: an empty list no longer raises IndexError and lists that
            # mix dicts with BlockConfig instances are normalised correctly.
            # When nothing needs converting, the caller's list object is
            # stored untouched.
            if any(isinstance(conf, dict) for conf in block_configs):
                block_configs = [
                    BlockConfig(**conf) if isinstance(conf, dict) else conf
                    for conf in block_configs
                ]
        self.block_configs: list[BlockConfig] | None = block_configs

    def to_dict(self) -> Dict[str, Any]:
        """Return the parent's dict with ``block_configs`` made plain.

        When ``block_configs`` is set, each entry is expanded with
        ``dataclasses.asdict`` so that the result holds nested dicts instead
        of dataclass instances.
        """
        self_dict = super().to_dict()
        if self.block_configs is not None:
            self_dict["block_configs"] = [
                dataclasses.asdict(conf) for conf in self.block_configs
            ]
        return self_dict
|
|
|
|
|
@dataclass(frozen=True, eq=True, unsafe_hash=True, order=True)
class AttentionConfig:
    """Immutable attention sub-block settings.

    ``no_op`` and ``replace_with_linear`` are mutually exclusive flags. When
    either one is set, ``n_heads_in_group`` is forced to ``None``; when
    neither is set, ``n_heads_in_group`` must be provided.
    """

    no_op: bool = False
    replace_with_linear: bool = False
    n_heads_in_group: Optional[int] = None

    def __post_init__(self):
        """Check the flag combination and normalise ``n_heads_in_group``."""
        # At most one of the two flags may be set.
        assert not self.no_op or not self.replace_with_linear

        if not (self.no_op or self.replace_with_linear):
            # Neither flag is set, so the head grouping is mandatory.
            assert self.n_heads_in_group is not None
            return

        # The dataclass is frozen, so go through object.__setattr__ to
        # discard any value that was passed for n_heads_in_group.
        object.__setattr__(self, "n_heads_in_group", None)
|
|
|
|
|
@dataclass(frozen=True, eq=True, unsafe_hash=True, order=True)
class FFNConfig:
    """Immutable feed-forward sub-block settings.

    ``no_op`` and ``replace_with_linear`` are mutually exclusive flags. When
    either one is set, ``ffn_mult`` is forced to ``None``; when neither is
    set, ``ffn_mult`` must be provided.
    """

    no_op: bool = False
    replace_with_linear: bool = False
    ffn_mult: Optional[float] = None

    def __post_init__(self):
        """Check the flag combination and normalise ``ffn_mult``."""
        # At most one of the two flags may be set.
        assert not self.no_op or not self.replace_with_linear

        if not (self.no_op or self.replace_with_linear):
            # Neither flag is set, so the multiplier is mandatory.
            assert self.ffn_mult is not None
            return

        # The dataclass is frozen, so go through object.__setattr__ to
        # discard any value that was passed for ffn_mult.
        object.__setattr__(self, "ffn_mult", None)
|
|
|
|
|
@dataclass(frozen=True, eq=True, unsafe_hash=True, order=True)
class BlockConfig:
    """Immutable pairing of one attention config and one FFN config.

    Both fields are required. Either may be passed as a dict, in which case
    it is converted to the field's annotated dataclass type after
    construction.
    """

    attention: AttentionConfig = MISSING
    ffn: FFNConfig = MISSING

    def __post_init__(self):
        """Replace dict-valued sub-blocks with their dataclass instances.

        Keys that the target dataclass does not declare are dropped, and a
        warning listing them is emitted.
        """
        for spec in dataclasses.fields(self):
            raw = getattr(self, spec.name)
            if not isinstance(raw, dict):
                # Already a config object (or anything else): leave it alone.
                continue

            target_cls = spec.type
            known = [declared.name for declared in dataclasses.fields(target_cls)]

            # Single pass: split the incoming keys into accepted and rejected.
            kept = {}
            dropped = []
            for key, value in raw.items():
                if key in known:
                    kept[key] = value
                else:
                    dropped.append(key)

            if dropped:
                warnings.warn(f"Removed unsupported fields {dropped} from {target_cls.__name__}")

            # The dataclass is frozen, so assignment goes through
            # object.__setattr__.
            object.__setattr__(self, spec.name, target_cls(**kept))
|
|