# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.
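
"""Merge method that aligns token embeddings across models with different
vocabularies before combining them.

`TokenizerPermutationMergeTask` permutes each model's embedding rows into a
shared output vocabulary (built by `BuildTokenizer`), then combines them by a
masked weighted average or, for exactly two models, SLERP with a linear
fallback for tokens missing from either vocabulary.
"""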
from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.slerp import slerp
from mergekit.tokenizer import BuildTokenizer, TokenizerInfo

class TokenizerPermutationMergeTask(Task[torch.Tensor]):
    """Merge embedding-style tensors whose rows must first be permuted into a
    common output vocabulary via per-model token index mappings."""

    tokenizer_task: BuildTokenizer
    gather_tensors: GatherTensors
    base_model: Optional[ModelReference]
    use_slerp: bool
    slerp_t: Optional[float]
    tensor_parameters: ImmutableMap[ModelReference, Any]

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}

    def execute(
        self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
    ) -> torch.Tensor:
        if not tensors:
            return None
        if len(tensors) == 1:
            return list(tensors.values())[0]

        if self.use_slerp and self.slerp_t is None:
            raise RuntimeError("Must set t to use embed_slerp")

        models = []
        expanded = []
        masks = []
        weights = []
        for model in tensors:
            models.append(model)
            x = tensors[model]
            p = tokenizer_info.permutations[model]

            # Re-index this model's rows into the output vocabulary order.
            # Rows whose token is absent from this model stay zero, and the
            # mask records which rows actually received data.
            xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device)
            mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device)
            for out_idx in p:
                in_idx = p[out_idx]
                if in_idx < 0:
                    # Token not present in this model's vocabulary.
                    continue
                xp[out_idx, :] = x[in_idx, :]
                mask[out_idx] = 1

            expanded.append(xp)
            masks.append(mask)

            # SLERP interpolates between the base model (weight 1 - t) and
            # the other model (weight t); otherwise use per-model weights.
            is_base = model == self.base_model
            if self.use_slerp:
                weight = (1.0 - self.slerp_t) if is_base else self.slerp_t
            else:
                weight = self.tensor_parameters[model]["weight"]
            weights.append(weight)

        expanded = torch.stack(expanded, dim=0)
        masks = torch.stack(masks, dim=0).unsqueeze(-1)
        weights = (
            torch.tensor(weights, dtype=expanded.dtype, device=expanded.device)
            .unsqueeze(-1)
            .unsqueeze(-1)
        )

        # Masked weighted average: rows missing from some models are
        # renormalized by the total weight actually present; rows with no
        # contributors at all are left as zeros (scale forced to 0).
        total_weight = (masks * weights).sum(dim=0)
        scale = 1 / total_weight
        scale[total_weight.abs() < 1e-8] = 0

        linear_merged = (expanded * weights * masks).sum(dim=0) * scale
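        # Comment-only worked example: with two models, weights [0.5, 0.5],
        # and a token row present only in model 0, total_weight is 0.5 for
        # that row, scale is 2.0, and the merged row is model 0's row as-is.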

        if self.use_slerp:
            if expanded.shape[0] != 2:
                raise RuntimeError("SLERP takes exactly two models")

            # Orient the interpolation so v0 is always the base model.
            if models[0] == self.base_model:
                v0 = expanded[0, ...]
                v1 = expanded[1, ...]
            else:
                v0 = expanded[1, ...]
                v1 = expanded[0, ...]

            res = slerp(self.slerp_t, v0, v1)
            # Rows not present in both vocabularies cannot be slerped; fall
            # back to the masked linear merge for those rows.
            need_linear = (masks.sum(dim=0) != 2).squeeze(dim=-1)
            res[need_linear, :] = linear_merged[need_linear, :].to(
                device=res.device, dtype=res.dtype
            )
            return res

        return linear_merged
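
# Illustrative sketch (not part of mergekit's public API): the loop in
# `execute` implies that `tokenizer_info.permutations[model]` maps output
# vocabulary indices to input indices, with negative values marking tokens
# the model lacks. For example (hypothetical values):
#
#     p = {0: 1, 1: 0, 2: -1}
#     x = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
#     # -> expanded row 0 = x[1], row 1 = x[0], row 2 stays zero (mask False)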

class TokenizerPermutationMerge(MergeMethod, BaseModel):
    tokenizer_task: BuildTokenizer

    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="t", required=False),
            ConfigParameterDef(name="embed_slerp", required=False, default_value=False),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="weight", required=False),
        ]

    def make_task(
        self,
        *,
        tensors: GatherTensors,
        parameters: Dict[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        base_model: Optional[ModelReference],
        **_kwargs,
    ) -> Task:
        return TokenizerPermutationMergeTask(
            base_model=base_model,
            tokenizer_task=self.tokenizer_task,
            gather_tensors=tensors,
            use_slerp=parameters["embed_slerp"],
            slerp_t=parameters["t"],
            tensor_parameters=tensor_parameters,
        )
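
# Hedged usage sketch: this method is typically reached when a mergekit
# config builds a combined tokenizer (e.g. via `tokenizer_source`); the
# exact dispatch lives elsewhere in mergekit, so the config below is an
# assumption, not the canonical invocation. Parameter names `t`,
# `embed_slerp`, and `weight` come from the ConfigParameterDefs above.
#
#     merge_method: linear
#     tokenizer_source: union
#     models:
#       - model: org/model-a        # hypothetical model names
#         parameters:
#           weight: 0.6
#       - model: org/model-b
#         parameters:
#           weight: 0.4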