thepowerfuldeez committed
Commit: a27d5e1
Parent: 6299eb5
enable loraplus setting for dpo trainer (#1646)
src/axolotl/core/trainer_builder.py CHANGED
@@ -798,6 +798,40 @@ class AxolotlDPOTrainer(DPOTrainer):
 
     tag_names = ["axolotl", "dpo"]
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.optimizer = None
+
+    def create_optimizer(self):
+        if self.args.loraplus_lr_ratio is None:
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+                opt_model,
+            )
+
+            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+            if loraplus_lr_ratio:
+                print("Using lora+")
+            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
+            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                opt_model,
+                optimizer_cls,
+                optimizer_kwargs,
+                loraplus_lr_ratio,
+                loraplus_lr_embedding,
+            )
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
+                self.optimizer
+            )
+
+        return self.optimizer
+
     @wraps(DPOTrainer.push_to_hub)
     def push_to_hub(self, *args, **kwargs) -> str:
         """
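Note on the optimizer override above: when `loraplus_lr_ratio` is set, the trainer swaps the stock optimizer for one built by `create_loraplus_optimizer`, which applies the LoRA+ scheme of giving the LoRA "B" matrices a larger learning rate than the "A" matrices. That helper is not part of this diff; the sketch below only illustrates the parameter-grouping idea it relies on, and the function name `build_loraplus_param_groups` is an invention for illustration, not axolotl's API.

# Illustrative sketch of LoRA+ parameter grouping, assuming a PEFT-style model
# whose adapter parameters are named "...lora_A...", "...lora_B...", and
# "...lora_embedding...". This is NOT the upstream create_loraplus_optimizer.
from typing import Optional

import torch


def build_loraplus_param_groups(
    model: torch.nn.Module,
    lr: float,
    loraplus_lr_ratio: float,
    loraplus_lr_embedding: Optional[float] = None,
) -> list:
    group_a, group_b, group_emb = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "lora_embedding" in name:
            group_emb.append(param)   # embedding adapters get their own LR
        elif "lora_B" in name:
            group_b.append(param)     # B matrices get the boosted LR
        else:
            group_a.append(param)     # A matrices and any other trainable params
    groups = [
        {"params": group_a, "lr": lr},
        {"params": group_b, "lr": lr * loraplus_lr_ratio},
    ]
    if group_emb:
        groups.append({"params": group_emb, "lr": loraplus_lr_embedding or lr})
    return groups


# The optimizer_cls / optimizer_kwargs pair would come from
# Trainer.get_optimizer_cls_and_kwargs(...) exactly as in the hunk above, e.g.:
#   optimizer = optimizer_cls(build_loraplus_param_groups(model, 1e-4, 16), **optimizer_kwargs)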
@@ -1483,6 +1517,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.bf16 or self.cfg.bfloat16:
             training_args_kwargs["bf16"] = True
 
+        training_args_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_args_kwargs["loraplus_lr_embedding"] = self.cfg.loraplus_lr_embedding
         training_args_kwargs["lr_scheduler_type"] = (
             self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
         )
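The two kwargs added here only take effect if the training-arguments class used on the RL path defines matching fields; the next hunk switches that default to `AxolotlTrainingArguments`, which is presumably where `loraplus_lr_ratio` and `loraplus_lr_embedding` are declared (that part of the change is not in this excerpt). A minimal stand-in showing the pattern, with an illustrative class name and defaults rather than axolotl's actual definition:

# Minimal stand-in, assuming transformers' TrainingArguments as the base; the
# field names match the kwargs above, but the class name and help text are
# illustrative, not axolotl's actual AxolotlTrainingArguments definition.
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class LoraPlusTrainingArguments(TrainingArguments):
    loraplus_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "LR ratio applied to LoRA B matrices (LoRA+); None disables it."},
    )
    loraplus_lr_embedding: Optional[float] = field(
        default=None,
        metadata={"help": "Optional separate LR for lora_embedding_* parameters."},
    )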
@@ -1535,7 +1571,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha
 
-        training_args_cls =
+        training_args_cls = AxolotlTrainingArguments
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
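Taken together, the user-facing effect is two new optional settings that flow from the axolotl config through `HFRLTrainerBuilder` into the DPO trainer. A hedged example of the knobs, written as a plain dict; the keys mirror the `self.cfg` attributes read above, and the values are illustrative only:

# Leaving both keys unset keeps the previous behavior: the kwargs stay None and
# AxolotlDPOTrainer.create_optimizer() falls back to super().create_optimizer().
example_cfg = {
    "rl": "dpo",
    "loraplus_lr_ratio": 16,        # multiplier for the LoRA B-matrix learning rate
    "loraplus_lr_embedding": 1e-6,  # separate LR for lora_embedding_* params (optional)
}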