fix eval steps and strategy (#403)
Browse files
src/axolotl/utils/trainer.py
CHANGED
@@ -452,13 +452,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
452 |
] = cfg.sample_packing_eff_est
|
453 |
|
454 |
if cfg.val_set_size == 0:
|
455 |
-
evaluation_strategy = "no"
|
456 |
-
elif cfg.eval_steps
|
457 |
-
|
458 |
-
|
459 |
else:
|
460 |
-
# eval
|
461 |
-
evaluation_strategy = "
|
462 |
|
463 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
464 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
@@ -471,9 +471,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
471 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
472 |
num_train_epochs=cfg.num_epochs,
|
473 |
learning_rate=cfg.learning_rate,
|
474 |
-
evaluation_strategy=evaluation_strategy,
|
475 |
save_strategy="steps" if cfg.save_steps else "epoch",
|
476 |
-
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
477 |
save_steps=cfg.save_steps,
|
478 |
output_dir=cfg.output_dir,
|
479 |
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
|
|
452 |
] = cfg.sample_packing_eff_est
|
453 |
|
454 |
if cfg.val_set_size == 0:
|
455 |
+
training_arguments_kwargs["evaluation_strategy"] = "no"
|
456 |
+
elif cfg.eval_steps:
|
457 |
+
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
458 |
+
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
459 |
else:
|
460 |
+
# we have an eval set, but no steps defined, use epoch
|
461 |
+
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
462 |
|
463 |
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
464 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
|
|
471 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
472 |
num_train_epochs=cfg.num_epochs,
|
473 |
learning_rate=cfg.learning_rate,
|
|
|
474 |
save_strategy="steps" if cfg.save_steps else "epoch",
|
|
|
475 |
save_steps=cfg.save_steps,
|
476 |
output_dir=cfg.output_dir,
|
477 |
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|