# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
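"""Linear (weighted-average) merge method.

LinearMergeTask combines the corresponding tensor from each input model as a
weighted sum, optionally normalized by the total weight; LinearMerge exposes
the method's configuration parameters and constructs the task.
"""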
from typing import Any, Dict, List

import torch

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod


class LinearMergeTask(Task[torch.Tensor]):
    gather_tensors: GatherTensors
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
    normalize: bool
    parameter_name: str

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(
        self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs
    ) -> torch.Tensor:
        keys = list(tensors.keys())

        tensors = [tensors[key] for key in keys]
        weights = [self.tensor_parameters[key]["weight"] for key in keys]

        # Reconcile size mismatches in embedding / LM head tensors.
        rectify_embed_sizes(self.parameter_name, tensors)

        unique_shapes = set(t.shape for t in tensors)
        if len(unique_shapes) != 1:
            raise RuntimeError(
                f"Tensor size mismatch for {self.parameter_name}, sizes: {list(unique_shapes)}"
            )

        # Weighted sum over the stacked inputs; per-model weights broadcast
        # over all trailing tensor dimensions.
        tensors = torch.stack(tensors, dim=0)
        weights = torch.tensor(weights, dtype=tensors.dtype, device=tensors.device)
        while len(weights.shape) < len(tensors.shape):
            weights.unsqueeze_(-1)

        res = (weights * tensors).sum(dim=0)
        if self.normalize:
            # Divide by the total weight to produce a weighted average.
            res /= weights.sum(dim=0)

        return res


class LinearMerge(MergeMethod):
    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="normalize", required=False, default_value=True),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [ConfigParameterDef(name="weight", required=True)]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        parameters: Dict[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        **_kwargs,
    ) -> Task:
        return LinearMergeTask(
            gather_tensors=tensors,
            tensor_parameters=tensor_parameters,
            normalize=parameters["normalize"],
            parameter_name=output_weight.name,
        )
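# A minimal sketch (illustrative only, not part of mergekit) of the arithmetic
# that LinearMergeTask.execute performs: stack the input tensors, scale each by
# its weight, sum, and (when normalize is set) divide by the total weight. The
# toy tensors and weights below are assumptions chosen purely for demonstration.
#
#     import torch
#
#     a, b = torch.ones(2, 2), torch.zeros(2, 2)
#     weights = torch.tensor([0.75, 0.25]).view(-1, 1, 1)
#     merged = (weights * torch.stack([a, b], dim=0)).sum(dim=0) / weights.sum(dim=0)
#     # merged == 0.75 * a + 0.25 * b  (a uniform 0.75 matrix here)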