|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from mergekit.merge_methods.base import MergeMethod |
|
from mergekit.merge_methods.generalized_task_arithmetic import ( |
|
ConsensusMethod, |
|
GeneralizedTaskArithmeticMerge, |
|
SparsificationMethod, |
|
) |
|
from mergekit.merge_methods.linear import LinearMerge |
|
from mergekit.merge_methods.model_stock import ModelStockMerge |
|
from mergekit.merge_methods.passthrough import PassthroughMerge |
|
from mergekit.merge_methods.slerp import SlerpMerge |
|
from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge |
|
|
|
|
|
def get(method: str) -> MergeMethod:
    """Build a fresh merge-method instance for the given method name.

    Recognised names are ``linear``, ``slerp``, ``passthrough``,
    ``task_arithmetic``, ``ties``, ``dare_ties``, ``dare_linear`` and
    ``model_stock``. The four task-arithmetic style names all construct a
    ``GeneralizedTaskArithmeticMerge`` and differ only in the consensus
    method, sparsification method and ``default_normalize`` flag passed to it.

    Args:
        method: Name of the merge method to instantiate.

    Returns:
        A newly constructed merge method object; every call builds a new one.

    Raises:
        RuntimeError: If ``method`` matches none of the recognised names.
    """

    def _task_arithmetic(consensus, sparsification, normalize):
        # Shared constructor call for all GeneralizedTaskArithmeticMerge variants.
        return GeneralizedTaskArithmeticMerge(
            consensus_method=consensus,
            sparsification_method=sparsification,
            default_normalize=normalize,
        )

    # Ordered (name, factory) pairs. The factories are lambdas on purpose:
    # nothing is looked up or constructed until the matching entry is chosen.
    builders = (
        ("linear", lambda: LinearMerge()),
        ("slerp", lambda: SlerpMerge()),
        ("passthrough", lambda: PassthroughMerge()),
        ("task_arithmetic", lambda: _task_arithmetic(None, None, False)),
        (
            "ties",
            lambda: _task_arithmetic(
                ConsensusMethod.sum, SparsificationMethod.magnitude, True
            ),
        ),
        (
            "dare_ties",
            lambda: _task_arithmetic(
                ConsensusMethod.sum, SparsificationMethod.rescaled_random, False
            ),
        ),
        (
            "dare_linear",
            lambda: _task_arithmetic(
                None, SparsificationMethod.rescaled_random, False
            ),
        ),
        ("model_stock", lambda: ModelStockMerge()),
    )

    # Names are checked with == in declaration order; first match wins.
    for name, build in builders:
        if method == name:
            return build()

    raise RuntimeError(f"Unimplemented merge method {method}")
|
|
|
|
|
# Public API of this module: the MergeMethod base class, the `get` factory,
# and the merge method classes imported above.
__all__ = [
    "MergeMethod",
    "get",
    "LinearMerge",
    "SlerpMerge",
    "PassthroughMerge",
    "GeneralizedTaskArithmeticMerge",
    "TokenizerPermutationMerge",
    # Imported above and returned by get("model_stock"); exported so it is
    # available the same way as the other merge method classes.
    "ModelStockMerge",
]
|
|