|
from typing import Dict, Union, Tuple, List, Any, Literal, Optional |
|
import torch |
|
import numpy as np |
|
|
|
from .utils import rescaled_random, magnitude, random_wo_rescaled |
|
|
|
|
|
class GTA: |
|
def __init__(self, sparsify_method=None, consensus_method=None, normalize=False): |
|
self.sparsify_method = sparsify_method |
|
self.consensus_method = consensus_method |
|
|
|
self.normalize = normalize |
|
|
|
def execute( |
|
self, |
|
weights, |
|
base, |
|
tensors, |
|
densities, |
|
**_kwargs, |
|
) -> torch.Tensor: |
|
|
|
densities = [densities for _ in range(len(tensors))] |
|
|
|
assert len(densities) == len(weights) == len(tensors) |
|
deltas, base = get_task_vectors(base, tensors) |
|
if not deltas: |
|
return base |
|
|
|
|
|
if self.sparsify_method: |
|
if self.sparsify_method == 'magnitude': |
|
sparsify = magnitude |
|
elif self.sparsify_method == 'rescaled_random': |
|
sparsify = rescaled_random |
|
elif self.sparsify_method == 'random': |
|
sparsify = random_wo_rescaled |
|
else: |
|
raise NotImplementedError |
|
for i, delta in enumerate(deltas): |
|
deltas[i] = sparsify( |
|
delta, |
|
density=densities[i] |
|
) |
|
|
|
deltas = torch.stack(deltas, dim=0) |
|
weights = torch.tensor( |
|
[_ for _ in weights], dtype=deltas.dtype, device=deltas.device |
|
) |
|
while len(deltas.shape) > len(weights.shape): |
|
weights.unsqueeze_(-1) |
|
|
|
weighted_deltas = deltas * weights |
|
|
|
|
|
if self.consensus_method: |
|
mask_dtype = base.dtype |
|
mask = get_mask( |
|
weighted_deltas, |
|
method=self.consensus_method, |
|
mask_dtype=mask_dtype, |
|
) |
|
mixed_delta = (weighted_deltas * mask).sum(dim=0) |
|
divisor = (weights * mask).sum(dim=0) |
|
divisor[divisor == 0] = 1 |
|
else: |
|
mixed_delta = weighted_deltas.sum(dim=0) |
|
divisor = weights.sum(dim=0) |
|
divisor[divisor.abs() < 1e-8] = 1 |
|
|
|
if self.normalize: |
|
mixed_delta /= divisor |
|
|
|
return (base + mixed_delta).to(base.dtype) |
|
|
|
def get_task_vectors( |
|
base: Union[np.ndarray, torch.Tensor], |
|
tensors: Union[List[np.ndarray], List[torch.Tensor]], |
|
) -> Tuple[List[Dict[str, Any]], torch.Tensor]: |
|
|
|
res = [] |
|
for x in tensors: |
|
delta = x - base |
|
del x |
|
res.append(delta) |
|
return res, base |
|
|
|
def get_mask( |
|
delta: torch.Tensor, |
|
method: Literal["sum", "count"] = "sum", |
|
mask_dtype: Optional[torch.dtype] = None, |
|
): |
|
"""Returns a mask determining which delta vectors should be merged |
|
into the final model. |
|
|
|
For the methodology described in the TIES paper use 'sum'. For a |
|
simpler naive count of signs, use 'count'.""" |
|
if mask_dtype is None: |
|
mask_dtype = delta.dtype |
|
|
|
sign = delta.sign().to(mask_dtype) |
|
|
|
if method == "sum": |
|
sign_weight = delta.sum(dim=0) |
|
majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1 |
|
del sign_weight |
|
elif method == "count": |
|
majority_sign = (sign.sum(dim=0) >= 0).to(mask_dtype) * 2 - 1 |
|
else: |
|
raise RuntimeError(f'Unimplemented mask method "{method}"') |
|
|
|
return sign == majority_sign |