File size: 2,103 Bytes
a164e13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
from typing import Any, Dict, List
import torch
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
class PassthroughMergeTask(Task[torch.Tensor]):
    """Forward a single model's tensor, optionally multiplied by a per-model scale."""

    # Upstream task; arguments() binds it to the ``tensors`` input of execute().
    gather_tensors: GatherTensors
    # Per-model parameter maps; execute() only reads the "scale" entry.
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]

    def arguments(self) -> Dict[str, Task]:
        """Declare the one dependency of this task, exposed under the name ``tensors``."""
        dependencies: Dict[str, Task] = {"tensors": self.gather_tensors}
        return dependencies

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        """Return the only input tensor, scaled if its model has a ``scale`` set.

        Args:
            tensors: Mapping of model reference to tensor; must contain
                exactly one entry.

        Returns:
            The input tensor unchanged when the model's ``scale`` is missing
            or ``None``; otherwise ``tensor * scale``.

        Raises:
            RuntimeError: If ``tensors`` does not contain exactly one entry.
        """
        if len(tensors) != 1:
            raise RuntimeError("Passthrough merge expects exactly one tensor")

        # Safe to unpack a single pair: the length check above guarantees it.
        [(source_model, result)] = tensors.items()

        # NOTE(review): ``.data`` is assumed to be the dict-like payload of
        # ImmutableMap; only ``.get`` is relied upon here.
        factor = self.tensor_parameters[source_model].data.get("scale")
        return result if factor is None else result * factor
class PassthroughMerge(MergeMethod):
    """Merge method that passes a single model's tensor through, optionally scaled."""

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        """Declare the per-tensor parameters: an optional ``scale`` defaulting to ``None``."""
        scale_def = ConfigParameterDef(name="scale", required=False, default_value=None)
        return [scale_def]

    def make_task(
        self,
        *,
        tensors: GatherTensors,
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        **kwargs,
    ) -> Task:
        """Build the passthrough task for one output tensor.

        Args:
            tensors: Task gathering the input tensors; the built task requires
                it to yield exactly one tensor.
            tensor_parameters: Per-model parameter maps; the built task reads
                each model's ``scale`` entry.
            **kwargs: Any other keyword arguments are accepted and ignored.

        Returns:
            A ``PassthroughMergeTask`` wired to ``tensors`` and ``tensor_parameters``.
        """
        passthrough = PassthroughMergeTask(
            tensor_parameters=tensor_parameters,
            gather_tensors=tensors,
        )
        return passthrough
|