from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import yaml
from pydantic import BaseModel, model_validator
from typing_extensions import TypeAlias

from mergekit.common import ModelReference

ScalarOrGradient: TypeAlias = Union[float, List[float]]


class ConditionalParameter(BaseModel):
    """A parameter value that only applies to tensors whose name matches `filter`."""

    value: ScalarOrGradient
    filter: Optional[str] = None


ParameterSetting: TypeAlias = Union[
    ConditionalParameter, List[ConditionalParameter], ScalarOrGradient
]
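

# In YAML configurations a ParameterSetting can take several shapes. The snippet
# below is an illustrative sketch (field names follow ConditionalParameter above;
# "weight" is just an example parameter name):
#
#   weight: 0.5                      # plain scalar
#   weight: [0.0, 0.3, 0.7, 1.0]     # gradient, interpolated over t
#   weight:
#     - filter: self_attn            # applies only to matching tensor names
#       value: 0.7
#     - value: 0.3                   # unfiltered fallback

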
def evaluate_setting(
    tensor_name: str, setting: ParameterSetting, t: float = 0
) -> Any:
    """Resolve a parameter setting for a given tensor at interpolation position t."""
    if isinstance(setting, (float, int, bool, str)):
        # Plain scalars pass through unchanged.
        return setting
    elif isinstance(setting, list):
        if all(isinstance(e, (int, float)) for e in setting):
            # Numeric gradient: linearly interpolate between adjacent entries at t.
            scaled = t * (len(setting) - 1)
            i0 = int(scaled)
            i1 = min(len(setting) - 1, i0 + 1)
            frac = scaled - i0

            return (1 - frac) * setting[i0] + frac * setting[i1]
        elif all(isinstance(e, (float, int, bool, str)) for e in setting):
            # List of plain scalars: pick the entry at position t, no interpolation.
            return setting[int(t * (len(setting) - 1))]
        else:
            # Conditional parameters: the first entry whose filter matches wins.
            for cond in setting:
                if (
                    (cond.filter is None)
                    or (cond.filter == "*")
                    or (tensor_name and cond.filter in tensor_name)
                ):
                    return evaluate_setting(tensor_name, cond.value, t)
    else:
        raise RuntimeError(f"Unexpected setting value: {setting}")
    return None
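

# Illustrative call sites, as a hedged sketch (the tensor names are hypothetical
# and nothing in this module invokes these directly):
#
#   evaluate_setting("model.layers.0.mlp.up_proj.weight", 0.5)                -> 0.5
#   evaluate_setting("model.layers.0.mlp.up_proj.weight", [0.0, 1.0], t=0.25) -> 0.25
#   evaluate_setting(
#       "model.layers.0.self_attn.q_proj.weight",
#       [ConditionalParameter(value=0.2, filter="self_attn"),
#        ConditionalParameter(value=0.8)],
#   )                                                                         -> 0.2

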
class InputSliceDefinition(BaseModel):
    """A contiguous range of layers taken from a single input model."""

    model: ModelReference
    layer_range: Tuple[int, int]
    parameters: Optional[Dict[str, ParameterSetting]] = None


class InputModelDefinition(BaseModel):
    """An input model to merge, with optional per-model parameters."""

    model: ModelReference
    parameters: Optional[Dict[str, ParameterSetting]] = None


class OutputSliceDefinition(BaseModel):
    """A slice of the output model, assembled from one or more input slices."""

    sources: List[InputSliceDefinition]
    base_model: Optional[ModelReference] = None
    residual_weight: Optional[float] = None
    parameters: Optional[Dict[str, ParameterSetting]] = None
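

# A slice-based configuration might look roughly like this in YAML (a hedged
# sketch with hypothetical model names):
#
#   slices:
#     - sources:
#         - model: example-org/model-a
#           layer_range: [0, 16]
#         - model: example-org/model-b
#           layer_range: [0, 16]

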
class MergeConfiguration(BaseModel):
    """Top-level merge configuration: either output slices or whole input models."""

    merge_method: str
    slices: Optional[List[OutputSliceDefinition]] = None
    models: Optional[List[InputModelDefinition]] = None
    parameters: Optional[Dict[str, ParameterSetting]] = None
    base_model: Optional[ModelReference] = None
    dtype: Optional[str] = None
    tokenizer_source: Optional[str] = None

    def referenced_models(self) -> List[ModelReference]:
        """Return every model mentioned anywhere in the configuration."""
        models = set()
        if self.base_model:
            models.add(self.base_model)
        if self.models:
            for model_in in self.models:
                models.add(model_in.model)
        if self.slices:
            for s in self.slices:
                for src in s.sources:
                    models.add(src.model)
        return list(models)

    @model_validator(mode="after")
    def validate_inputs(self):
        if ((not self.slices) and (not self.models)) or (self.slices and self.models):
            raise RuntimeError(
                "Must specify exactly one of 'slices' or 'models' to merge"
            )
        return self

    def to_yaml(self) -> str:
        return yaml.dump(
            self.model_dump(exclude_defaults=True, mode="json"),
            Dumper=ConfigYamlDumper,
        ).rstrip()
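

# A hedged usage sketch (yaml_text and the model names are hypothetical):
#
#   cfg = MergeConfiguration.model_validate(yaml.safe_load(yaml_text))
#   cfg.referenced_models()   # every ModelReference named in the config
#   print(cfg.to_yaml())      # round-trip back to YAML
#
# where yaml_text might be, for a models-style merge:
#
#   merge_method: linear
#   models:
#     - model: example-org/model-a
#       parameters:
#         weight: 0.5
#     - model: example-org/model-b
#       parameters:
#         weight: 0.5
#   dtype: float16

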
class ConfigReader(BaseModel):
    """Helper for resolving parameter values in the context of a slice and tensor."""

    config: MergeConfiguration
    t: float
    tensor_name: Optional[str] = None
    slice_out: Optional[OutputSliceDefinition] = None

    @property
    def base_model(self) -> Optional[ModelReference]:
        """The slice-level base model if set, otherwise the global one."""
        if self.slice_out and self.slice_out.base_model:
            res = self.slice_out.base_model
        else:
            res = self.config.base_model

        return res

    def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader":
        return ConfigReader(
            config=self.config,
            t=self.t,
            tensor_name=self.tensor_name,
            slice_out=slice,
        )

    def for_tensor(self, tensor_name: str) -> "ConfigReader":
        return ConfigReader(
            config=self.config,
            t=self.t,
            tensor_name=tensor_name,
            slice_out=self.slice_out,
        )

    def with_t(self, t: float) -> "ConfigReader":
        return ConfigReader(
            config=self.config,
            t=t,
            tensor_name=self.tensor_name,
            slice_out=self.slice_out,
        )

    def parameter(
        self,
        name: str,
        model: Optional[ModelReference] = None,
        default: Any = None,
        required: bool = False,
    ) -> Any:
        """Look up a parameter value, from the most specific scope to the least."""
        if self.slice_out:
            if model:
                # Per-source parameters within the current output slice.
                for s in self.slice_out.sources:
                    if s.model == model and s.parameters and name in s.parameters:
                        value = evaluate_setting(
                            self.tensor_name, s.parameters[name], self.t
                        )
                        if value is not None:
                            return value

            # Slice-level parameters.
            if self.slice_out.parameters and name in self.slice_out.parameters:
                value = evaluate_setting(
                    self.tensor_name, self.slice_out.parameters[name], self.t
                )
                if value is not None:
                    return value

        # Global parameters.
        if self.config.parameters and name in self.config.parameters:
            value = evaluate_setting(
                self.tensor_name,
                self.config.parameters[name],
                self.t,
            )
            if value is not None:
                return value

        if required:
            path_parts = [str(s) for s in [model, self.tensor_name] if s]
            p = ".".join(path_parts)
            suffix = f" for {p}" if p else ""
            raise RuntimeError(f"Missing required parameter {name}{suffix}")
        return default
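

# Resolution order, as a hedged sketch (cfg, some_model, and the tensor name are
# hypothetical):
#
#   reader = ConfigReader(config=cfg, t=0.5)
#   reader = reader.for_out_slice(cfg.slices[0]).for_tensor(
#       "model.layers.4.self_attn.q_proj.weight"
#   )
#   reader.parameter("weight", model=some_model, default=1.0)
#
# checks the matching source inside the slice first, then the slice's own
# parameters, then the top-level config, and finally falls back to the default
# (or raises if required=True).

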
class ConfigYamlDumper(yaml.Dumper):
    """Custom YAML dumper to format lists of numbers in flow style."""

    def represent_list(self, data: Iterable[Any]) -> yaml.SequenceNode:
        flow_style = all(isinstance(e, (int, float)) for e in data)
        return self.represent_sequence(
            "tag:yaml.org,2002:seq", data, flow_style=flow_style
        )


ConfigYamlDumper.add_representer(list, ConfigYamlDumper.represent_list)
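
# With this representer registered, to_yaml() emits numeric lists inline, e.g.
# (approximate output) "weight: [0.0, 0.5, 1.0]" rather than one "- ..." line per
# element, keeping gradient specifications compact.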