Philip May
committed on
Add more save strategies for DPO training. (#1255)
* Set save_strategy and save_steps in HFDPOTrainerBuilder
* Fix duplicate save_steps
src/axolotl/core/trainer_builder.py
CHANGED
@@ -1096,13 +1096,21 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
                     "use_reentrant": False
                 }
 
+        # set save_strategy and save_steps
+        if self.cfg.save_steps:
+            training_args_kwargs["save_strategy"] = "steps"
+            training_args_kwargs["save_steps"] = self.cfg.save_steps
+        elif self.cfg.save_strategy:
+            training_args_kwargs["save_strategy"] = self.cfg.save_strategy
+        else:
+            # default to saving each epoch if not defined
+            training_args_kwargs["save_strategy"] = "epoch"
+
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
-            save_strategy="steps",
-            save_steps=self.cfg.save_steps,
             output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
             logging_first_step=True,
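For reference, a minimal standalone sketch of the precedence logic this diff adds (resolve_save_kwargs and the SimpleNamespace config are illustrative stand-ins, not axolotl code): an explicit save_steps wins and forces the "steps" strategy, an explicit save_strategy is honored next, and per-epoch saving is the fallback.

from types import SimpleNamespace

def resolve_save_kwargs(cfg):
    """Mirror the branch added to HFDPOTrainerBuilder above."""
    kwargs = {}
    if cfg.save_steps:
        # an explicit step interval wins and forces the "steps" strategy
        kwargs["save_strategy"] = "steps"
        kwargs["save_steps"] = cfg.save_steps
    elif cfg.save_strategy:
        # otherwise honor an explicitly configured strategy
        kwargs["save_strategy"] = cfg.save_strategy
    else:
        # default to saving each epoch if neither option is defined
        kwargs["save_strategy"] = "epoch"
    return kwargs

# save_steps takes precedence even when save_strategy is also set
print(resolve_save_kwargs(SimpleNamespace(save_steps=500, save_strategy="epoch")))
# {'save_strategy': 'steps', 'save_steps': 500}
print(resolve_save_kwargs(SimpleNamespace(save_steps=None, save_strategy=None)))
# {'save_strategy': 'epoch'}

Routing both keys through training_args_kwargs, instead of hardcoding save_strategy="steps" and save_steps in the TrainingArguments(...) call, is presumably what resolves the duplicate save_steps error the commit message mentions: the keyword can no longer be supplied twice when the kwargs dict is expanded into the constructor.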