Maxime winglian committed on
Commit 1648279 (1 parent: f30d062)

add lion-pytorch optimizer (#1299) [skip ci]


* add lion-pytorch optimizer

* update pydantic to support lion optimizer

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

setup.py CHANGED
@@ -18,6 +18,7 @@ def parse_requirements():
             or "flash-attention" in line
             or "deepspeed" in line
             or "mamba-ssm" in line
+            or "lion-pytorch" in line
         )
         if line.startswith("--extra-index-url"):
             # Handle custom index URLs
@@ -85,5 +86,8 @@ setup(
         "mlflow": [
             "mlflow",
         ],
+        "lion-pytorch": [
+            "lion-pytorch==0.1.2",
+        ],
     },
 )
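For context, the new extra follows the same extras_require pattern as the existing mlflow extra: the dependency line is filtered out of the base requirements and exposed as an optional install. A minimal sketch of that pattern follows; only the extra names and the lion-pytorch pin come from the diff, while the package metadata is a placeholder.

# Minimal sketch of the extras_require pattern used in the diff above.
# Only the extra names and the lion-pytorch pin are taken from the diff;
# name/version/packages are placeholders, not axolotl's real metadata.
from setuptools import find_packages, setup

setup(
    name="example-package",
    version="0.0.0",
    packages=find_packages(),
    extras_require={
        "mlflow": [
            "mlflow",
        ],
        "lion-pytorch": [
            "lion-pytorch==0.1.2",  # pinned version from the diff
        ],
    },
)

With an extra defined this way, the optimizer dependency is only pulled in on request, e.g. pip install "axolotl[lion-pytorch]" (extra name taken from the diff).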
src/axolotl/core/trainer_builder.py CHANGED
@@ -970,19 +970,43 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "neftune_noise_alpha"
             ] = self.cfg.neftune_noise_alpha
 
-        training_args = (
-            AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
-                **training_arguments_kwargs,
-            )
-        )
-        training_args = self.hook_post_create_training_args(training_args)
         trainer_kwargs = {}
 
+        if self.cfg.optimizer == "lion_pytorch":
+            from lion_pytorch import Lion
+
+            lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
+            if "weight_decay" in training_arguments_kwargs:
+                lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
+
+            if (
+                "adam_beta1" in training_arguments_kwargs
+                and "adam_beta2" in training_arguments_kwargs
+            ):
+                lion_kwargs["betas"] = (
+                    training_arguments_kwargs["adam_beta1"],
+                    training_arguments_kwargs["adam_beta2"],
+                )
+
+            trainer_kwargs["optimizers"] = (
+                Lion(params=self.model.parameters(), **lion_kwargs),
+                None,
+            )
+            # Set default so transformers doesn't throw
+            training_arguments_kwargs["optim"] = "adamw_hf"
+
         if self.cfg.optimizer == "adamw_anyprecision":
             if Path(self.cfg.torchdistx_path).exists():
                 sys.path.append(self.cfg.torchdistx_path)
                 importlib.import_module("torchdistx")
 
+        training_args = (
+            AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
+                **training_arguments_kwargs,
+            )
+        )
+        training_args = self.hook_post_create_training_args(training_args)
+
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
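The change builds the optimizer itself and passes it to the trainer through the optimizers=(optimizer, lr_scheduler) tuple, with None for the scheduler so the trainer creates one as usual; optim is then set to adamw_hf only so transformers' argument validation passes, since the pre-built optimizer is used instead. Outside the builder, Lion from lion-pytorch behaves like any torch.optim optimizer with an AdamW-style constructor (lr, betas, weight_decay). The following is a minimal, self-contained sketch; the tiny linear model and hyperparameter values are illustrative, not taken from any axolotl config.

# Minimal sketch: Lion from lion-pytorch used as a drop-in torch optimizer,
# mirroring the kwargs the builder assembles above. The model and the
# hyperparameter values are illustrative placeholders.
import torch
from lion_pytorch import Lion

model = torch.nn.Linear(16, 2)  # stand-in for the real causal LM

lion_kwargs = {"lr": 1e-4, "weight_decay": 0.01}
lion_kwargs["betas"] = (0.9, 0.99)  # reuses the Adam-style (beta1, beta2) pair

optimizer = Lion(params=model.parameters(), **lion_kwargs)

# Standard torch.optim training step
loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()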
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -263,7 +263,7 @@ class HyperparametersConfig(BaseModel):
 
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = None
-    optimizer: Optional[OptimizerNames] = None
+    optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
     torchdistx_path: Optional[str] = None
     lr_scheduler: Optional[SchedulerType] = None
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
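The widened annotation lets pydantic accept either one of the OptimizerNames enum values or the extra literal string "lion_pytorch", which is not part of that enum. Below is a reduced, self-contained sketch of the same Union[..., Literal[...]] pattern; ToyOptimizerNames stands in for the real OptimizerNames enum used in the config module, and every other field is omitted.

# Reduced sketch of the pydantic field above: the optimizer may be a member
# of an optimizer-name enum or the literal "lion_pytorch". ToyOptimizerNames
# is a stand-in for the much larger real enum.
from enum import Enum
from typing import Literal, Optional, Union

from pydantic import BaseModel


class ToyOptimizerNames(str, Enum):
    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"


class HyperparametersSketch(BaseModel):
    optimizer: Optional[Union[ToyOptimizerNames, Literal["lion_pytorch"]]] = None


print(HyperparametersSketch(optimizer="adamw_hf").optimizer)      # enum member
print(HyperparametersSketch(optimizer="lion_pytorch").optimizer)  # literal string
print(HyperparametersSketch().optimizer)                          # None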