Philip May committed
Commit 13eea21 · unverified · 1 Parent(s): 1072f28

Add more save strategies for DPO training. (#1255)


* Set save_strategy and save_steps in HFDPOTrainerBuilder

* fix duplicate save_steps

Files changed (1): src/axolotl/core/trainer_builder.py (+10 −2)
src/axolotl/core/trainer_builder.py:

```diff
@@ -1096,13 +1096,21 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
                     "use_reentrant": False
                 }
 
+        # set save_strategy and save_steps
+        if self.cfg.save_steps:
+            training_args_kwargs["save_strategy"] = "steps"
+            training_args_kwargs["save_steps"] = self.cfg.save_steps
+        elif self.cfg.save_strategy:
+            training_args_kwargs["save_strategy"] = self.cfg.save_strategy
+        else:
+            # default to saving each epoch if not defined
+            training_args_kwargs["save_strategy"] = "epoch"
+
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
-            save_strategy="steps",
-            save_steps=self.cfg.save_steps,
             output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
             logging_first_step=True,
```
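
For readers skimming the change: the new block resolves the checkpointing arguments with a simple precedence before they reach `TrainingArguments`. Below is a minimal, runnable sketch of that precedence; `Cfg` and `resolve_save_kwargs` are hypothetical stand-ins for illustration, not axolotl's actual API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Cfg:
    """Hypothetical stand-in for axolotl's config object."""
    save_steps: Optional[int] = None
    save_strategy: Optional[str] = None


def resolve_save_kwargs(cfg: Cfg) -> dict:
    """Mirror the precedence introduced by this commit."""
    kwargs = {}
    if cfg.save_steps:
        # an explicit save_steps implies step-based checkpointing
        kwargs["save_strategy"] = "steps"
        kwargs["save_steps"] = cfg.save_steps
    elif cfg.save_strategy:
        # otherwise honor an explicitly configured strategy (e.g. "no", "epoch")
        kwargs["save_strategy"] = cfg.save_strategy
    else:
        # default to saving each epoch if neither is defined
        kwargs["save_strategy"] = "epoch"
    return kwargs


# The three cases the diff distinguishes:
assert resolve_save_kwargs(Cfg(save_steps=500)) == {"save_strategy": "steps", "save_steps": 500}
assert resolve_save_kwargs(Cfg(save_strategy="no")) == {"save_strategy": "no"}
assert resolve_save_kwargs(Cfg()) == {"save_strategy": "epoch"}
```

In short: `save_steps` takes priority and forces step-based saving, an explicit `save_strategy` is honored otherwise, and per-epoch saving is the fallback, replacing the old behavior that hard-coded `save_strategy="steps"` regardless of configuration.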