Fix(cfg): Add validation for save_strategy and eval_strategy (#633)
* Fix(cfg): Check save_strategy cfg conflict with save_steps
* Fix(cfg): Check evaluation_strategy cfg conflict with eval_steps
* chore: add extra check for steps only
- src/axolotl/utils/config.py +18 -0
- src/axolotl/utils/trainer.py +4 -11
- tests/test_validation.py +168 -0
src/axolotl/utils/config.py
CHANGED
@@ -296,6 +296,24 @@ def validate_config(cfg):
                 cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
                     "sharegpt_simple", "sharegpt"
                 )
+    if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
+        raise ValueError(
+            "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
+        )
+
+    if (
+        cfg.evaluation_strategy
+        and cfg.eval_steps
+        and cfg.evaluation_strategy != "steps"
+    ):
+        raise ValueError(
+            "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
+        )
+
+    if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
+        raise ValueError(
+            "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
+        )

     # TODO
     # MPT 7b
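
For reference, a minimal sketch of how the new checks surface to callers. The behavior mirrors the tests added below; the import path for DictDefault is assumed from tests/test_validation.py and the config values are illustrative only.

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault  # assumed import path, as used by the tests

# A save_strategy other than "steps" combined with save_steps now fails fast.
cfg = DictDefault({"save_strategy": "epoch", "save_steps": 10})
try:
    validate_config(cfg)
except ValueError as err:
    print(err)  # "save_strategy and save_steps mismatch. ..."

# Setting save_steps alone (or save_strategy to "steps") still passes validation.
validate_config(DictDefault({"save_steps": 10}))
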
src/axolotl/utils/trainer.py
CHANGED
@@ -604,26 +604,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             "sample_packing_efficiency"
         ] = cfg.sample_packing_eff_est

-    if cfg.eval_steps
-
-        training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
+    if cfg.eval_steps:
+        training_arguments_kwargs["evaluation_strategy"] = "steps"
         training_arguments_kwargs["eval_steps"] = cfg.eval_steps
+    elif cfg.evaluation_strategy:
+        training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
     elif cfg.val_set_size == 0:
         # no eval set, so don't eval
         training_arguments_kwargs["evaluation_strategy"] = "no"
-    elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
-        # if explicitly set for epoch, just set, and eval steps don't matter
-        training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
-    elif cfg.eval_steps:
-        # steps isn't used w/ epochs
-        training_arguments_kwargs["evaluation_strategy"] = "steps"
-        training_arguments_kwargs["eval_steps"] = cfg.eval_steps
     else:
         # we have an eval set, but no steps defined, default to use epoch
         training_arguments_kwargs["evaluation_strategy"] = "epoch"

     if cfg.save_steps:
-        # save_steps implies save_strategy of steps
         training_arguments_kwargs["save_strategy"] = "steps"
         training_arguments_kwargs["save_steps"] = cfg.save_steps
     elif cfg.save_strategy:
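
The evaluation branch in setup_trainer now reduces to a simple precedence: explicit eval_steps force the "steps" strategy, an explicit evaluation_strategy is passed through, no validation split disables evaluation, and the default is per-epoch evaluation. A standalone sketch of that precedence (the helper name resolve_evaluation_strategy is hypothetical, not part of the PR):

def resolve_evaluation_strategy(cfg, training_arguments_kwargs):
    # eval_steps wins and forces the "steps" strategy
    if cfg.eval_steps:
        training_arguments_kwargs["evaluation_strategy"] = "steps"
        training_arguments_kwargs["eval_steps"] = cfg.eval_steps
    # otherwise honor an explicitly configured strategy (e.g. "epoch" or "no")
    elif cfg.evaluation_strategy:
        training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
    # no validation split at all: disable evaluation
    elif cfg.val_set_size == 0:
        training_arguments_kwargs["evaluation_strategy"] = "no"
    # a validation split exists but nothing was configured: default to per-epoch evaluation
    else:
        training_arguments_kwargs["evaluation_strategy"] = "epoch"
    return training_arguments_kwargs

Because validate_config now rejects contradictory combinations up front, these branches no longer need to reconcile eval_steps with a non-"steps" strategy.
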
tests/test_validation.py
CHANGED
@@ -397,3 +397,171 @@ class ValidationTest(unittest.TestCase):
                 for record in self._caplog.records
             )
         assert cfg.datasets[0].type == "sharegpt:load_role"
+
+    def test_no_conflict_save_strategy(self):
+        cfg = DictDefault(
+            {
+                "save_strategy": "epoch",
+                "save_steps": 10,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*save_strategy and save_steps mismatch.*"
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_strategy": "no",
+                "save_steps": 10,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*save_strategy and save_steps mismatch.*"
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_strategy": "steps",
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_strategy": "steps",
+                "save_steps": 10,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_steps": 10,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "save_strategy": "no",
+            }
+        )
+
+        validate_config(cfg)
+
+    def test_no_conflict_eval_strategy(self):
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "epoch",
+                "eval_steps": 10,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "no",
+                "eval_steps": 10,
+            }
+        )
+
+        with pytest.raises(
+            ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "steps",
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "steps",
+                "eval_steps": 10,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "eval_steps": 10,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "no",
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "epoch",
+                "val_set_size": 0,
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "eval_steps": 10,
+                "val_set_size": 0,
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
+        ):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "val_set_size": 0,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "eval_steps": 10,
+                "val_set_size": 0.01,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "evaluation_strategy": "epoch",
+                "val_set_size": 0.01,
+            }
+        )
+
+        validate_config(cfg)