winglian committed on
Commit
da10af0
1 Parent(s): 85cf4f8

fix eval steps and strategy (#403)

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/trainer.py +6 -8
src/axolotl/utils/trainer.py CHANGED
@@ -452,13 +452,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
452
  ] = cfg.sample_packing_eff_est
453
 
454
  if cfg.val_set_size == 0:
455
- evaluation_strategy = "no"
456
- elif cfg.eval_steps < 1:
457
- # eval every epoch
458
- evaluation_strategy = "epoch"
459
  else:
460
- # eval every eval_steps steps
461
- evaluation_strategy = "steps"
462
 
463
  training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
464
  max_steps=total_num_steps if cfg.max_steps else -1,
@@ -471,9 +471,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
471
  eval_accumulation_steps=cfg.gradient_accumulation_steps,
472
  num_train_epochs=cfg.num_epochs,
473
  learning_rate=cfg.learning_rate,
474
- evaluation_strategy=evaluation_strategy,
475
  save_strategy="steps" if cfg.save_steps else "epoch",
476
- eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
477
  save_steps=cfg.save_steps,
478
  output_dir=cfg.output_dir,
479
  save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
 
452
  ] = cfg.sample_packing_eff_est
453
 
454
  if cfg.val_set_size == 0:
455
+ training_arguments_kwargs["evaluation_strategy"] = "no"
456
+ elif cfg.eval_steps:
457
+ training_arguments_kwargs["evaluation_strategy"] = "steps"
458
+ training_arguments_kwargs["eval_steps"] = cfg.eval_steps
459
  else:
460
+ # we have an eval set, but no steps defined, use epoch
461
+ training_arguments_kwargs["evaluation_strategy"] = "epoch"
462
 
463
  training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
464
  max_steps=total_num_steps if cfg.max_steps else -1,
 
471
  eval_accumulation_steps=cfg.gradient_accumulation_steps,
472
  num_train_epochs=cfg.num_epochs,
473
  learning_rate=cfg.learning_rate,
 
474
  save_strategy="steps" if cfg.save_steps else "epoch",
 
475
  save_steps=cfg.save_steps,
476
  output_dir=cfg.output_dir,
477
  save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,