winglian commited on
Commit
d7057cc
·
unverified ·
1 Parent(s): 768d348

paired kto support (#1069)

Browse files
README.md CHANGED
@@ -595,6 +595,9 @@ datasets:
595
  # For `completion` datsets only, uses the provided field instead of `text` column
596
  field:
597
 
 
 
 
598
  # Saves the desired chat template to the tokenizer_config.json for easier inferencing
599
  # Currently supports chatml and inst (mistral/mixtral)
600
  chat_template: chatml
 
595
  # For `completion` datsets only, uses the provided field instead of `text` column
596
  field:
597
 
598
+ # use RL training: dpo, ipo, kto_pair
599
+ rl:
600
+
601
  # Saves the desired chat template to the tokenizer_config.json for easier inferencing
602
  # Currently supports chatml and inst (mistral/mixtral)
603
  chat_template: chatml
requirements.txt CHANGED
@@ -40,4 +40,4 @@ s3fs
40
  gcsfs
41
  # adlfs
42
 
43
- trl @ git+https://github.com/huggingface/trl.git@main
 
40
  gcsfs
41
  # adlfs
42
 
43
+ trl>=0.7.9
src/axolotl/core/trainer_builder.py CHANGED
@@ -927,6 +927,8 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
927
  dpo_trainer_kwargs["loss_type"] = "ipo"
928
  if self.cfg.dpo_label_smoothing:
929
  dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
 
 
930
 
931
  dpo_trainer = DPOTrainer(
932
  self.model,
 
927
  dpo_trainer_kwargs["loss_type"] = "ipo"
928
  if self.cfg.dpo_label_smoothing:
929
  dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
930
+ elif self.cfg.rl == "kto_pair":
931
+ dpo_trainer_kwargs["loss_type"] = "kto_pair"
932
 
933
  dpo_trainer = DPOTrainer(
934
  self.model,