|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List |
|
|
|
import torch |
|
|
|
from mergekit.architecture import WeightInfo |
|
from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes |
|
from mergekit.graph import Task |
|
from mergekit.io.tasks import GatherTensors |
|
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod |
|
|
|
|
|
class LinearMergeTask(Task[torch.Tensor]):
    """Compute a weighted sum (optionally a weighted mean) of one parameter across models.

    Each input model's tensor for ``parameter_name`` is scaled by that model's
    ``"weight"`` entry in ``tensor_parameters``. The scaled tensors are summed.
    When ``normalize`` is True, the sum is divided by the total weight.
    """

    # Upstream task whose result (a mapping model -> tensor) is fed to execute().
    gather_tensors: GatherTensors
    # Per-model parameters; each entry is read for a "weight" value (expected to be a number).
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
    # If True, divide the weighted sum by the sum of the weights.
    normalize: bool
    # Name of the parameter being merged; passed to rectify_embed_sizes and used in error messages.
    parameter_name: str

    def uses_accelerator(self) -> bool:
        """Opt this task in to accelerator execution (always returns True)."""
        return True

    def arguments(self) -> Dict[str, Task]:
        """Declare the dependency on ``gather_tensors`` under the ``"tensors"`` key.

        The key matches execute()'s ``tensors`` parameter; presumably the
        scheduler passes the dependency's result through by that name.
        """
        return {"tensors": self.gather_tensors}

    def execute(
        self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs
    ) -> torch.Tensor:
        """Return the weighted sum (or weighted mean) of the input tensors.

        Args:
            tensors: Mapping from each model to its tensor for ``parameter_name``.

        Returns:
            A tensor with the common input shape, in the inputs' dtype and on
            their device.

        Raises:
            RuntimeError: If ``tensors`` is empty.
            RuntimeError: If the shapes still differ after embedding rectification.
            RuntimeError: If ``normalize`` is set and the weights sum to zero,
                since dividing would produce inf/NaN values.
        """
        if not tensors:
            # Without this guard the shape check below would report a confusing
            # "size mismatch" with an empty list of sizes.
            raise RuntimeError(f"No input tensors to merge for {self.parameter_name}")

        keys = list(tensors.keys())
        tensor_list = [tensors[key] for key in keys]
        weight_values = [self.tensor_parameters[key]["weight"] for key in keys]

        # Called only for its effect on tensor_list (the return value is unused).
        # Presumably it reconciles embedding/vocab sizes in place; confirm in
        # mergekit.common.
        rectify_embed_sizes(self.parameter_name, tensor_list)

        unique_shapes = {t.shape for t in tensor_list}
        if len(unique_shapes) != 1:
            raise RuntimeError(
                f"Tensor size mismatch for {self.parameter_name}, sizes: {list(unique_shapes)}"
            )

        stacked = torch.stack(tensor_list, dim=0)
        weights = torch.tensor(
            weight_values, dtype=stacked.dtype, device=stacked.device
        )
        # Append trailing singleton dims so the weights (shape [N], assuming
        # scalar weights) broadcast against the stacked tensors [N, *shape].
        while weights.dim() < stacked.dim():
            weights.unsqueeze_(-1)

        res = (weights * stacked).sum(dim=0)
        if self.normalize:
            weight_sum = weights.sum(dim=0)
            if bool((weight_sum == 0).all()):
                # Dividing by zero would silently fill the merged tensor with inf/NaN.
                raise RuntimeError(
                    f"Cannot normalize {self.parameter_name}: weights sum to zero"
                )
            res = res / weight_sum

        return res
|
|
|
|
|
class LinearMerge(MergeMethod):
    """Merge method that combines each parameter as a weighted sum across models.

    Global option ``normalize`` (optional, default True) is forwarded to the task.
    Each model must supply a per-tensor ``weight``.
    """

    def parameters(self) -> List[ConfigParameterDef]:
        """Declare the method-level configuration options."""
        normalize_option = ConfigParameterDef(
            name="normalize", required=False, default_value=True
        )
        return [normalize_option]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        """Declare the per-model, per-tensor configuration options."""
        weight_option = ConfigParameterDef(name="weight", required=True)
        return [weight_option]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        parameters: Dict[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        **_kwargs,
    ) -> Task:
        """Build the LinearMergeTask that produces ``output_weight``.

        Args:
            output_weight: Descriptor of the output parameter; its ``name`` is
                forwarded to the task.
            tensors: Task gathering the input tensors to merge.
            parameters: Method-level options; ``"normalize"`` is read.
            tensor_parameters: Per-model options, forwarded unchanged.
        """
        task_fields = dict(
            gather_tensors=tensors,
            tensor_parameters=tensor_parameters,
            normalize=parameters["normalize"],
            parameter_name=output_weight.name,
        )
        return LinearMergeTask(**task_fields)
|
|