from typing import Any, Dict, List, Optional

import torch

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod


class ModelStockMergeTask(Task[torch.Tensor]): |
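    """Computes a Model Stock merge of two or more fine-tuned checkpoints.

    Interpolates between the base model's weights and the uniform average of
    the fine-tuned models' weights, with the interpolation factor derived
    from the angle between the models' task vectors
    ("Model Stock: All we need is just a few fine-tuned models",
    arXiv:2403.19522).
    """
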
    gather_tensors: GatherTensors
    base_model: ModelReference
    parameter_name: str
    filter_wise: bool = False

    def uses_accelerator(self) -> bool:
        # the merge is pure tensor arithmetic and benefits from GPU execution
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        if len(tensors) == 1 and self.base_model in tensors:
            # only the base model has this parameter; pass it through unchanged
            return tensors[self.base_model]
        if len(tensors) < 3:
            raise ValueError(
                "ModelStockMerge requires at least 3 models (base plus two+ others)"
            )

        w_0, ws = self.get_rectified_weights(tensors)
        out_shape = w_0.shape

        if self.filter_wise:
            if w_0.dim() == 1:
                # treat 1-D parameters (e.g. biases) as a single filter row
                w_0 = w_0.unsqueeze(0)
                ws = [w.unsqueeze(0) for w in ws]
        else:
            # flatten so the angle is computed over the entire parameter
            w_0 = w_0.view(-1)
            ws = [w.view(-1) for w in ws]

        offsets = [w - w_0 for w in ws]

        # The paper derives t for exactly two fine-tuned models from the angle
        # between their task vectors (w_i - w_0); for N models this
        # implementation averages the pairwise cosine similarities between all
        # task vectors.
        cos_thetas = []
        for i, w_0_offset in enumerate(offsets):
            for j in range(i + 1, len(offsets)):
                w_1_offset = offsets[j]

                norm_product = torch.norm(w_0_offset, dim=-1) * torch.norm(
                    w_1_offset, dim=-1
                )
                # clamp the denominator to avoid division by zero and keep the
                # result in the valid cosine range
                cos_theta = (
                    (w_0_offset * w_1_offset).sum(dim=-1) / norm_product.clamp(min=1e-6)
                ).clamp(-1, 1)
                cos_thetas.append(cos_theta)

        cos_theta = torch.stack(cos_thetas).mean(dim=0).unsqueeze(-1)
        N = len(ws)
        # interpolation factor from the Model Stock paper:
        # t = N * cos(theta) / (1 + (N - 1) * cos(theta))
        t = (N * cos_theta) / (1 + (N - 1) * cos_theta)

        w_avg = sum(ws) / len(ws)
        w_h = t * w_avg + (1 - t) * w_0

        return w_h.reshape(out_shape)

    def get_rectified_weights(self, tensors: Dict[ModelReference, torch.Tensor]):
        if self.base_model not in tensors:
            raise ValueError("Base model tensor not found")

        # base model first, then the fine-tuned models in gather order
        all_weights = [tensors[self.base_model]] + [
            tensors[k] for k in tensors if k != self.base_model
        ]
        # reconcile mismatched embedding/LM-head sizes across models
        rectify_embed_sizes(self.parameter_name, all_weights)
        w_0 = all_weights[0]
        ws = all_weights[1:]
        return w_0, ws
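

# A minimal example of a mergekit YAML configuration selecting this method
# (a sketch; the model names below are placeholders, not real repositories):
#
#   merge_method: model_stock
#   base_model: example-org/base-model
#   models:
#     - model: example-org/finetune-a
#     - model: example-org/finetune-b
#   parameters:
#     filter_wise: false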
class ModelStockMerge(MergeMethod):
    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="filter_wise", required=False, default_value=False)
        ]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        base_model: Optional[ModelReference],
        parameters: ImmutableMap[str, Any],
        **_kwargs,
    ) -> Task:
        return ModelStockMergeTask(
            gather_tensors=tensors,
            base_model=base_model,
            parameter_name=output_weight.name,
            filter_wise=parameters["filter_wise"],
        )
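

if __name__ == "__main__":
    # Minimal illustrative sketch, not part of mergekit's API: apply the
    # Model Stock interpolation to hand-made toy weights to show how the
    # angle between task vectors sets the factor t. All values are invented
    # for demonstration.
    base = torch.zeros(4)
    tuned = [torch.tensor([1.0, 0.0, 0.0, 0.0]), torch.tensor([0.8, 0.2, 0.0, 0.0])]
    deltas = [w - base for w in tuned]
    cos_theta = (
        (deltas[0] * deltas[1]).sum()
        / (deltas[0].norm() * deltas[1].norm()).clamp(min=1e-6)
    ).clamp(-1, 1)
    n = len(tuned)
    t = (n * cos_theta) / (1 + (n - 1) * cos_theta)
    merged = t * (sum(tuned) / n) + (1 - t) * base
    print(f"cos_theta={cos_theta:.4f} t={t:.4f} merged={merged}")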