from typing import Any, Dict, List

import torch

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod


class PassthroughMergeTask(Task[torch.Tensor]):
    """Forward a single model's tensor unchanged, optionally multiplied by `scale`."""

    gather_tensors: GatherTensors
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        # Passthrough is only defined for a single input model; zero or
        # multiple tensors indicate a misconfigured merge.
        if len(tensors) != 1:
            raise RuntimeError("Passthrough merge expects exactly one tensor")

        model, tensor = list(tensors.items())[0]
        scale = self.tensor_parameters[model].data.get("scale", None)
        if scale is not None:
            tensor = tensor * scale

        return tensor


class PassthroughMerge(MergeMethod):
    """Merge method that copies tensors through, with an optional per-tensor scale."""

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [ConfigParameterDef(name="scale", required=False, default_value=None)]

    def make_task(
        self,
        *,
        tensors: GatherTensors,
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        **kwargs,
    ) -> Task:
        return PassthroughMergeTask(
            gather_tensors=tensors, tensor_parameters=tensor_parameters
        )