paired kto support (#1069)
Browse files- README.md +3 -0
- requirements.txt +1 -1
- src/axolotl/core/trainer_builder.py +2 -0
README.md
CHANGED
@@ -595,6 +595,9 @@ datasets:
|
|
595 |
# For `completion` datasets only, uses the provided field instead of `text` column
|
596 |
field:
|
597 |
|
|
|
|
|
|
|
598 |
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
599 |
# Currently supports chatml and inst (mistral/mixtral)
|
600 |
chat_template: chatml
|
|
|
595 |
# For `completion` datasets only, uses the provided field instead of `text` column
|
596 |
field:
|
597 |
|
598 |
+
# Selects the RL training method; supported values: dpo, ipo, kto_pair
|
599 |
+
rl:
|
600 |
+
|
601 |
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
602 |
# Currently supports chatml and inst (mistral/mixtral)
|
603 |
chat_template: chatml
|
requirements.txt
CHANGED
@@ -40,4 +40,4 @@ s3fs
|
|
40 |
gcsfs
|
41 |
# adlfs
|
42 |
|
43 |
-
trl
|
|
|
40 |
gcsfs
|
41 |
# adlfs
|
42 |
|
43 |
+
trl>=0.7.9
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -927,6 +927,8 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|
927 |
dpo_trainer_kwargs["loss_type"] = "ipo"
|
928 |
if self.cfg.dpo_label_smoothing:
|
929 |
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
|
|
|
|
930 |
|
931 |
dpo_trainer = DPOTrainer(
|
932 |
self.model,
|
|
|
927 |
dpo_trainer_kwargs["loss_type"] = "ipo"
|
928 |
if self.cfg.dpo_label_smoothing:
|
929 |
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
930 |
+
elif self.cfg.rl == "kto_pair":
|
931 |
+
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
932 |
|
933 |
dpo_trainer = DPOTrainer(
|
934 |
self.model,
|