thepowerfuldeez committed on
Commit
a27d5e1
1 Parent(s): 6299eb5

enable loraplus setting for dpo trainer (#1646)

Browse files
Files changed (1) hide show
  1. src/axolotl/core/trainer_builder.py +37 -1
src/axolotl/core/trainer_builder.py CHANGED
@@ -798,6 +798,40 @@ class AxolotlDPOTrainer(DPOTrainer):
798
 
799
  tag_names = ["axolotl", "dpo"]
800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
801
  @wraps(DPOTrainer.push_to_hub)
802
  def push_to_hub(self, *args, **kwargs) -> str:
803
  """
@@ -1483,6 +1517,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1483
  if self.cfg.bf16 or self.cfg.bfloat16:
1484
  training_args_kwargs["bf16"] = True
1485
 
 
 
1486
  training_args_kwargs["lr_scheduler_type"] = (
1487
  self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
1488
  )
@@ -1535,7 +1571,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1535
  # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
1536
  training_args_kwargs["beta"] = self.cfg.orpo_alpha
1537
 
1538
- training_args_cls = TrainingArguments
1539
  if self.cfg.rl == "orpo":
1540
  training_args_cls = ORPOConfig
1541
  training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
 
798
 
799
  tag_names = ["axolotl", "dpo"]
800
 
801
+ def __init__(self, *args, **kwargs):
802
+ super().__init__(*args, **kwargs)
803
+ self.optimizer = None
804
+
805
+ def create_optimizer(self):
806
+ if self.args.loraplus_lr_ratio is None:
807
+ return super().create_optimizer()
808
+
809
+ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
810
+ if self.optimizer is None: # pylint: disable=access-member-before-definition
811
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
812
+ self.args,
813
+ opt_model,
814
+ )
815
+
816
+ loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
817
+ if loraplus_lr_ratio:
818
+ print("Using lora+")
819
+ loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
820
+ self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
821
+ opt_model,
822
+ optimizer_cls,
823
+ optimizer_kwargs,
824
+ loraplus_lr_ratio,
825
+ loraplus_lr_embedding,
826
+ )
827
+
828
+ if is_sagemaker_mp_enabled():
829
+ self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
830
+ self.optimizer
831
+ )
832
+
833
+ return self.optimizer
834
+
835
  @wraps(DPOTrainer.push_to_hub)
836
  def push_to_hub(self, *args, **kwargs) -> str:
837
  """
 
1517
  if self.cfg.bf16 or self.cfg.bfloat16:
1518
  training_args_kwargs["bf16"] = True
1519
 
1520
+ training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
1521
+ training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
1522
  training_args_kwargs["lr_scheduler_type"] = (
1523
  self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
1524
  )
 
1571
  # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
1572
  training_args_kwargs["beta"] = self.cfg.orpo_alpha
1573
 
1574
+ training_args_cls = AxolotlTrainingArguments
1575
  if self.cfg.rl == "orpo":
1576
  training_args_cls = ORPOConfig
1577
  training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes