# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import importlib.resources
import string
from abc import ABC, abstractmethod
from typing import ClassVar, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, Field
from transformers import PretrainedConfig
from typing_extensions import Literal

import mergekit._data.architectures


class WeightInfo(BaseModel, frozen=True):
    """Information about an individual weight tensor in a model.

    Attributes:
        name (str):
            The name of the tensor representing the weight.
        is_embed (bool):
            Indicates whether the weight is for an embedding or language model head.
        input_space (Optional[str]):
            The name of the input space associated with the weight, if applicable.
        output_space (Optional[str]):
            The name of the output space associated with the weight, if applicable.
        optional (bool):
            Indicates whether the weight can be omitted from a model.
        aliases (Optional[List[str]]):
            List of alternative names for the weight, if applicable.
    """

    name: str
    is_embed: bool = False
    input_space: Optional[str] = None
    output_space: Optional[str] = None
    optional: bool = False
    aliases: Optional[List[str]] = None


class ProceduralSpaceInfo(BaseModel, frozen=True):
    """Defines a procedural space computed from one or more other spaces.

    Currently only supports residual connections.

    Attributes:
        name (str): The name of the space defined.
        type (str): The type of procedural space.
        inputs (List[str]): List of names of spaces used to define this space."""

    name: str
    type: Literal["residual"]
    inputs: List[str]


class ArchitectureInfo(ABC):
    @abstractmethod
    def name(self) -> str:
        """Return the name of the architecture."""
        ...

    @abstractmethod
    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return a list of all weights preceding the first layer."""
        ...

    @abstractmethod
    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return a list of all weights following the final layer."""
        ...

    @abstractmethod
    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Return a list of all weights associated with a given layer."""
        ...

    @abstractmethod
    def sliceable(self) -> bool:
        """
        Return True if the layers of this architecture can be meaningfully sliced.
        """
        ...

    def num_layers_config_key(self) -> str:
        """Key in config that represents number of layers"""
        return "num_hidden_layers"

    def num_layers(self, config: PretrainedConfig) -> int:
        """Return the number of layers in a model."""
        return getattr(config, self.num_layers_config_key())

    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return all weights associated with a model."""
        num_layers = self.num_layers(config)
        res = list(self.pre_weights(config))
        for layer_idx in range(num_layers):
            res.extend(self.layer_weights(layer_idx, config))
        res.extend(self.post_weights(config))
        return res

    def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]:
        """Return a list of all procedurally defined spaces in a model."""
        return []

    def has_defined_spaces(self) -> bool:
        """
        Return True if this architecture defines space information needed for
        matching-based merge methods.
        """
        return False


class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True):
    info: ArchitectureInfo
    config: PretrainedConfig

    def name(self) -> str:
        return self.info.name()

    def num_layers(self) -> int:
        return self.info.num_layers(self.config)

    def pre_weights(self) -> List[WeightInfo]:
        return self.info.pre_weights(self.config)

    def post_weights(self) -> List[WeightInfo]:
        return self.info.post_weights(self.config)

    def layer_weights(self, index: int) -> List[WeightInfo]:
        return self.info.layer_weights(index, self.config)

    def procedural_spaces(self) -> List[ProceduralSpaceInfo]:
        return self.info.procedural_spaces(self.config)

    def all_weights(self) -> List[WeightInfo]:
        return self.info.all_weights(self.config)


class JSONLayerTemplates(BaseModel, frozen=True):
    weights: List[WeightInfo]
    procedural_spaces: Optional[List[ProceduralSpaceInfo]] = None


class JSONArchitectureDefinition(BaseModel, frozen=True):
    expected_model_type: str = Field(alias="model_type")
    architectures: List[str]
    pre_weights: List[WeightInfo]
    layer_templates: JSONLayerTemplates
    post_weights: List[WeightInfo]
    procedural_spaces: Optional[List[ProceduralSpaceInfo]] = None
    num_layers_config_key: Optional[str] = None


class TemplateWithArithmetic(string.Template):
    # Extend the identifier pattern so keys like "num_layers-1" or
    # "layer_index+1" are treated as single substitution names.
    idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)"


def _template_substitution(
    template: str, num_layers: int, layer_idx: Optional[int] = None
) -> str:
    if "{" not in template:
        return template

    substitutions = {
        "num_layers": num_layers,
        "num_layers+1": num_layers + 1,
        "num_layers-1": num_layers - 1,
    }

    if layer_idx is not None:
        substitutions.update(
            {
                "layer_index": layer_idx,
                "layer_index+1": layer_idx + 1,
                "layer_index-1": layer_idx - 1,
            }
        )

    return TemplateWithArithmetic(template).substitute(substitutions)


class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True):
    definition: JSONArchitectureDefinition

    def _substitute(
        self,
        item: Union[WeightInfo, ProceduralSpaceInfo],
        config: PretrainedConfig,
        layer_idx: Optional[int] = None,
    ) -> Union[WeightInfo, ProceduralSpaceInfo]:
        num_layers = self.num_layers(config)

        obj_dict = item.model_dump(mode="json", exclude_unset=True)
        for key in obj_dict:
            if isinstance(obj_dict[key], str):
                obj_dict[key] = _template_substitution(
                    obj_dict[key], num_layers, layer_idx
                )
            elif isinstance(obj_dict[key], list):
                obj_dict[key] = [
                    (
                        _template_substitution(s, num_layers, layer_idx)
                        if isinstance(s, str)
                        else s
                    )
                    for s in obj_dict[key]
                ]
        return type(item).model_validate(obj_dict)

    def name(self) -> str:
        return self.definition.expected_model_type

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return [
            self._substitute(wi, config=config) for wi in self.definition.pre_weights
        ]

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        return [
            self._substitute(wi, config=config, layer_idx=index)
            for wi in self.definition.layer_templates.weights
        ]

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return [
            self._substitute(wi, config=config) for wi in self.definition.post_weights
        ]

    def sliceable(self) -> bool:
        return True

    def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]:
        res = []
        for s in self.definition.procedural_spaces or []:
            res.append(self._substitute(s, config=config))
        for idx in range(self.num_layers(config)):
            for s in self.definition.layer_templates.procedural_spaces or []:
                res.append(self._substitute(s, config=config, layer_idx=idx))
        return res

    def has_defined_spaces(self) -> bool:
        if (
            self.definition.procedural_spaces
            or self.definition.layer_templates.procedural_spaces
        ):
            return True
        for wi in (
            self.definition.layer_templates.weights
            + self.definition.pre_weights
            + self.definition.post_weights
        ):
            if wi.input_space or wi.output_space:
                return True
        return False

    def num_layers_config_key(self) -> str:
        return self.definition.num_layers_config_key


class MixtralTensorNames(ArchitectureInfo, BaseModel):
    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
    num_local_experts: int

    def name(self) -> str:
        return "mixtral"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        return MixtralTensorNames(num_local_experts=config.num_local_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return MISTRAL_INFO.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return MISTRAL_INFO.post_weights(config)

    def num_layers_config_key(self) -> str:
        return MISTRAL_INFO.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        num_experts = self.num_local_experts
        prefix = f"model.layers.{index}"
        tensor_names = []
        for expert_idx in range(num_experts):
            for param in ("w1", "w2", "w3"):
                tensor_names.append(
                    prefix + f".block_sparse_moe.experts.{expert_idx}.{param}.weight"
                )
        tensor_names.append(prefix + ".block_sparse_moe.gate.weight")
        res = []
        for name in tensor_names:
            res.append(WeightInfo(name=name))
        for weight_info in MISTRAL_INFO.layer_weights(index, config):
            if ".mlp." in weight_info.name:
                continue
            res.append(weight_info)
        return res

    def sliceable(self) -> bool:
        return True

    def has_defined_spaces(self) -> bool:
        return False


def _load_json_arch(name: str) -> JsonArchitectureInfo:
    text = importlib.resources.read_text(mergekit._data.architectures, name)
    return JsonArchitectureInfo(
        definition=JSONArchitectureDefinition.model_validate_json(text)
    )


def _load_all_architectures() -> (
    Tuple[List[JsonArchitectureInfo], Dict[str, List[JsonArchitectureInfo]]]
):
    architectures: List[JsonArchitectureInfo] = []
    for f in importlib.resources.contents(mergekit._data.architectures):
        if f.lower().endswith(".json"):
            architectures.append(_load_json_arch(f))

    name_to_arch: Dict[str, List[JsonArchitectureInfo]] = {}
    for arch_info in architectures:
        for name in arch_info.definition.architectures:
            name_to_arch[name] = name_to_arch.get(name, [])
            name_to_arch[name].append(arch_info)
    return architectures, name_to_arch


JSON_ARCHITECTURES, NAME_TO_ARCH = _load_all_architectures()
MISTRAL_INFO = _load_json_arch("mistral.json")


def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
    if len(config.architectures) != 1:
        raise RuntimeError("More than one architecture in config?")

    arch_name = config.architectures[0]
    if arch_name == MixtralTensorNames.ARCHITECTURE_NAME:
        return MixtralTensorNames.from_config(config)

    if arch_name not in NAME_TO_ARCH:
        raise RuntimeError(f"Unsupported architecture {arch_name}")

    candidates = list(NAME_TO_ARCH[arch_name])
    if len(candidates) == 1:
        return candidates[0]

    for c in candidates:
        if c.definition.expected_model_type == config.model_type:
            return c

    raise RuntimeError(
        f"Unsupported model_type {config.model_type} for architecture {arch_name}"
    )
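

# Illustrative usage sketch (not part of the original module): resolve the
# architecture definition for a Hugging Face config and list the weight tensors
# mergekit expects. The model id below is only an example of a supported
# architecture; any config whose `architectures` entry is known would work.
if __name__ == "__main__":
    from transformers import AutoConfig

    example_config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
    arch_info = get_architecture_info(example_config)

    # Pre-layer, per-layer, and post-layer weights, with templates substituted.
    for weight in arch_info.all_weights(example_config):
        print(weight.name, "(embed)" if weight.is_embed else "")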