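"""Tokenizer-permutation merge method.

Token-indexed tensors (embeddings, LM heads) are re-mapped into the vocabulary
order of the merged tokenizer and then combined, either as a weighted linear
average or via SLERP between exactly two models.
"""
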
from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.slerp import slerp
from mergekit.tokenizer import BuildTokenizer, TokenizerInfo
|
|
class TokenizerPermutationMergeTask(Task[torch.Tensor]): |
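    """Merge token-indexed tensors by permuting each model's rows into the
    merged tokenizer's vocabulary order, then combining them with a weighted
    linear average or, for exactly two models, SLERP."""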
|
    tokenizer_task: BuildTokenizer
    gather_tensors: GatherTensors
    base_model: Optional[ModelReference]
    use_slerp: bool
    slerp_t: Optional[float]
    tensor_parameters: ImmutableMap[ModelReference, Any]

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}
|
    def execute(
        self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
    ) -> torch.Tensor:
        if not tensors:
            return None
        if len(tensors) == 1:
            return list(tensors.values())[0]

        if self.use_slerp and self.slerp_t is None:
            raise RuntimeError("Must set t to use embed_slerp")

        # Re-index each model's tensor into the merged vocabulary order and
        # record which output rows that model actually provides.
        models = []
        expanded = []
        masks = []
        weights = []
        for model in tensors:
            models.append(model)

            x = tensors[model]
            p = tokenizer_info.permutations[model]

            xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device)
            mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device)
            for out_idx in p:
                in_idx = p[out_idx]
                if in_idx < 0:
                    # Token absent from this model's vocabulary; leave the row zeroed.
                    continue

                xp[out_idx, :] = x[in_idx, :]
                mask[out_idx] = 1

            expanded.append(xp)
            masks.append(mask)

            # Under SLERP the interpolation factor doubles as the linear weight
            # (used for rows that need the linear fallback); otherwise take the
            # per-model "weight" parameter.
            is_base = model == self.base_model
            if self.use_slerp:
                weight = (1.0 - self.slerp_t) if is_base else self.slerp_t
            else:
                weight = self.tensor_parameters[model]["weight"]

            weights.append(weight)
|
        expanded = torch.stack(expanded, dim=0)
        masks = torch.stack(masks, dim=0).unsqueeze(-1)
        weights = (
            torch.tensor(weights, dtype=expanded.dtype, device=expanded.device)
            .unsqueeze(-1)
            .unsqueeze(-1)
        )

        # Weighted average over the models that define each row; rows covered
        # by no model get a scale of zero instead of a division by zero.
        total_weight = (masks * weights).sum(dim=0)
        scale = 1 / total_weight
        scale[total_weight.abs() < 1e-8] = 0

        linear_merged = (expanded * weights * masks).sum(dim=0) * scale

        if self.use_slerp:
            if expanded.shape[0] != 2:
                raise RuntimeError("SLERP takes exactly two models")

            # Orient the interpolation so that v0 is always the base model.
            if models[0] == self.base_model:
                v0 = expanded[0, ...]
                v1 = expanded[1, ...]
            else:
                v0 = expanded[1, ...]
                v1 = expanded[0, ...]

            res = slerp(self.slerp_t, v0, v1)
            # Rows not covered by both models fall back to the linear merge.
            need_linear = (masks.sum(dim=0) != 2).squeeze(dim=-1)
            res[need_linear, :] = linear_merged[need_linear, :].to(
                device=res.device, dtype=res.dtype
            )
            return res

        return linear_merged
|
|
class TokenizerPermutationMerge(MergeMethod, BaseModel): |
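    """MergeMethod that wires tokenizer construction into the merge graph and
    builds TokenizerPermutationMergeTask instances from the configured parameters."""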
|
    tokenizer_task: BuildTokenizer

    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="t", required=False),
            ConfigParameterDef(name="embed_slerp", required=False, default_value=False),
        ]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="weight", required=False),
        ]

    def make_task(
        self,
        *,
        tensors: GatherTensors,
        parameters: Dict[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        base_model: Optional[ModelReference],
        **_kwargs,
    ) -> Task:
        return TokenizerPermutationMergeTask(
            base_model=base_model,
            tokenizer_task=self.tokenizer_task,
            gather_tensors=tensors,
            use_slerp=parameters["embed_slerp"],
            slerp_t=parameters["t"],
            tensor_parameters=tensor_parameters,
        )
|
|