add lion-pytorch optimizer (#1299) [skip ci]
* add lion-pytorch optimizer
* update pydantic to support lion optimizer
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
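
For context, lion-pytorch implements the Lion optimizer: the step direction is the sign of an interpolation between the current gradient and the momentum buffer, combined with decoupled weight decay. Below is a minimal single-tensor sketch of that update rule, for orientation only; `lion_step` and its default values are illustrative and not part of this change.

```python
import torch

def lion_step(param, grad, exp_avg, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
    """One Lion update for a single parameter tensor (simplified sketch)."""
    beta1, beta2 = betas
    param.data.mul_(1 - lr * weight_decay)                  # decoupled weight decay
    update = exp_avg.mul(beta1).add(grad, alpha=1 - beta1)  # interpolate momentum and gradient
    param.data.add_(update.sign_(), alpha=-lr)              # take a signed step
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)         # update the momentum buffer
```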
setup.py CHANGED

@@ -18,6 +18,7 @@ def parse_requirements():
                 or "flash-attention" in line
                 or "deepspeed" in line
                 or "mamba-ssm" in line
+                or "lion-pytorch" in line
             )
             if line.startswith("--extra-index-url"):
                 # Handle custom index URLs
@@ -85,5 +86,8 @@ setup(
         "mlflow": [
             "mlflow",
         ],
+        "lion-pytorch": [
+            "lion-pytorch==0.1.2",
+        ],
     },
 )
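
The setup.py change does two things: it keeps lion-pytorch out of the base install list built by parse_requirements, and it re-exposes the package as a pinned extra. A rough sketch of that filtering idea follows; the requirement lines and the OPTIONAL_MARKERS name are made up for illustration.

```python
# Illustration only: stand-in requirement lines and marker tuple (assumptions).
requirements = ["transformers==4.37.0", "lion-pytorch==0.1.2", "deepspeed"]
OPTIONAL_MARKERS = ("flash-attn", "flash-attention", "deepspeed", "mamba-ssm", "lion-pytorch")

# Lines matching a marker stay out of the base install and are only reachable
# through extras_require, e.g. the new "lion-pytorch" extra pinned to 0.1.2.
base_install = [
    line
    for line in requirements
    if not any(marker in line for marker in OPTIONAL_MARKERS)
]
print(base_install)  # ['transformers==4.37.0']
```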
src/axolotl/core/trainer_builder.py CHANGED

@@ -970,19 +970,43 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "neftune_noise_alpha"
             ] = self.cfg.neftune_noise_alpha
 
-        training_args = (
-            AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
-                **training_arguments_kwargs,
-            )
-        )
-        training_args = self.hook_post_create_training_args(training_args)
         trainer_kwargs = {}
 
+        if self.cfg.optimizer == "lion_pytorch":
+            from lion_pytorch import Lion
+
+            lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
+            if "weight_decay" in training_arguments_kwargs:
+                lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
+
+            if (
+                "adam_beta1" in training_arguments_kwargs
+                and "adam_beta2" in training_arguments_kwargs
+            ):
+                lion_kwargs["betas"] = (
+                    training_arguments_kwargs["adam_beta1"],
+                    training_arguments_kwargs["adam_beta2"],
+                )
+
+            trainer_kwargs["optimizers"] = (
+                Lion(params=self.model.parameters(), **lion_kwargs),
+                None,
+            )
+            # Set default so transformers doesn't throw
+            training_arguments_kwargs["optim"] = "adamw_hf"
+
         if self.cfg.optimizer == "adamw_anyprecision":
             if Path(self.cfg.torchdistx_path).exists():
                 sys.path.append(self.cfg.torchdistx_path)
                 importlib.import_module("torchdistx")
 
+        training_args = (
+            AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
+                **training_arguments_kwargs,
+            )
+        )
+        training_args = self.hook_post_create_training_args(training_args)
+
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
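
Two points in this hunk: the AxolotlTrainingArguments construction moves below the optimizer branches so the `optim="adamw_hf"` placeholder set for Lion is picked up, and the Lion instance reaches the trainer through the standard `optimizers=(optimizer, lr_scheduler)` tuple. A self-contained sketch of that hand-off against a plain transformers Trainer follows; `build_lion_trainer` and the hyperparameter values are illustrative, not taken from this diff.

```python
from lion_pytorch import Lion
from transformers import Trainer, TrainingArguments

def build_lion_trainer(model, train_dataset):
    # Hypothetical values; only the wiring mirrors what the builder does.
    args = TrainingArguments(
        output_dir="./out",
        learning_rate=1e-4,
        weight_decay=0.01,
        optim="adamw_hf",  # placeholder so TrainingArguments validation is satisfied
    )
    lion = Lion(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
        betas=(args.adam_beta1, args.adam_beta2),
    )
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        optimizers=(lion, None),  # None lets Trainer create the LR scheduler itself
    )
```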
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED

@@ -263,7 +263,7 @@ class HyperparametersConfig(BaseModel):
 
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = None
-    optimizer: Optional[OptimizerNames] = None
+    optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
     torchdistx_path: Optional[str] = None
    lr_scheduler: Optional[SchedulerType] = None
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
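
The pydantic change widens the optimizer field so that the string "lion_pytorch" validates even though it is not a member of transformers' OptimizerNames enum. A minimal sketch of that behaviour, assuming pydantic v2 semantics; OptimizerOnlyConfig is a made-up standalone model, not the real HyperparametersConfig.

```python
from typing import Literal, Optional, Union

from pydantic import BaseModel
from transformers.training_args import OptimizerNames

class OptimizerOnlyConfig(BaseModel):
    optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None

print(OptimizerOnlyConfig(optimizer="adamw_torch").optimizer)   # OptimizerNames.ADAMW_TORCH
print(OptimizerOnlyConfig(optimizer="lion_pytorch").optimizer)  # lion_pytorch
```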