Nanobit committed on
Commit
383f88d
1 Parent(s): b6ab8aa

Fix(cfg): Add validation for save_strategy and eval_strategy (#633)

Browse files

* Fix(cfg): Check save_strategy cfg conflict with save_steps

* Fix(cfg): Check evaluation_strategy cfg conflict with eval_steps

* chore: add extra check for steps only

src/axolotl/utils/config.py CHANGED
@@ -296,6 +296,24 @@ def validate_config(cfg):
296
  cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
297
  "sharegpt_simple", "sharegpt"
298
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  # TODO
301
  # MPT 7b
 
296
  cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
297
  "sharegpt_simple", "sharegpt"
298
  )
299
+ if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
300
+ raise ValueError(
301
+ "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
302
+ )
303
+
304
+ if (
305
+ cfg.evaluation_strategy
306
+ and cfg.eval_steps
307
+ and cfg.evaluation_strategy != "steps"
308
+ ):
309
+ raise ValueError(
310
+ "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
311
+ )
312
+
313
+ if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
314
+ raise ValueError(
315
+ "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
316
+ )
317
 
318
  # TODO
319
  # MPT 7b
src/axolotl/utils/trainer.py CHANGED
@@ -604,26 +604,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
604
  "sample_packing_efficiency"
605
  ] = cfg.sample_packing_eff_est
606
 
607
- if cfg.eval_steps and cfg.evaluation_strategy:
608
- # assume if the user set both, they know what they're doing
609
- training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
610
  training_arguments_kwargs["eval_steps"] = cfg.eval_steps
 
 
611
  elif cfg.val_set_size == 0:
612
  # no eval set, so don't eval
613
  training_arguments_kwargs["evaluation_strategy"] = "no"
614
- elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
615
- # if explicitly set for epoch, just set, and eval steps don't matter
616
- training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
617
- elif cfg.eval_steps:
618
- # steps isn't used w/ epochs
619
- training_arguments_kwargs["evaluation_strategy"] = "steps"
620
- training_arguments_kwargs["eval_steps"] = cfg.eval_steps
621
  else:
622
  # we have an eval set, but no steps defined, default to use epoch
623
  training_arguments_kwargs["evaluation_strategy"] = "epoch"
624
 
625
  if cfg.save_steps:
626
- # save_steps implies save_strategy of steps
627
  training_arguments_kwargs["save_strategy"] = "steps"
628
  training_arguments_kwargs["save_steps"] = cfg.save_steps
629
  elif cfg.save_strategy:
 
604
  "sample_packing_efficiency"
605
  ] = cfg.sample_packing_eff_est
606
 
607
+ if cfg.eval_steps:
608
+ training_arguments_kwargs["evaluation_strategy"] = "steps"
 
609
  training_arguments_kwargs["eval_steps"] = cfg.eval_steps
610
+ elif cfg.evaluation_strategy:
611
+ training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
612
  elif cfg.val_set_size == 0:
613
  # no eval set, so don't eval
614
  training_arguments_kwargs["evaluation_strategy"] = "no"
 
 
 
 
 
 
 
615
  else:
616
  # we have an eval set, but no steps defined, default to use epoch
617
  training_arguments_kwargs["evaluation_strategy"] = "epoch"
618
 
619
  if cfg.save_steps:
 
620
  training_arguments_kwargs["save_strategy"] = "steps"
621
  training_arguments_kwargs["save_steps"] = cfg.save_steps
622
  elif cfg.save_strategy:
tests/test_validation.py CHANGED
@@ -397,3 +397,171 @@ class ValidationTest(unittest.TestCase):
397
  for record in self._caplog.records
398
  )
399
  assert cfg.datasets[0].type == "sharegpt:load_role"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  for record in self._caplog.records
398
  )
399
  assert cfg.datasets[0].type == "sharegpt:load_role"
400
+
401
+ def test_no_conflict_save_strategy(self):
402
+ cfg = DictDefault(
403
+ {
404
+ "save_strategy": "epoch",
405
+ "save_steps": 10,
406
+ }
407
+ )
408
+
409
+ with pytest.raises(
410
+ ValueError, match=r".*save_strategy and save_steps mismatch.*"
411
+ ):
412
+ validate_config(cfg)
413
+
414
+ cfg = DictDefault(
415
+ {
416
+ "save_strategy": "no",
417
+ "save_steps": 10,
418
+ }
419
+ )
420
+
421
+ with pytest.raises(
422
+ ValueError, match=r".*save_strategy and save_steps mismatch.*"
423
+ ):
424
+ validate_config(cfg)
425
+
426
+ cfg = DictDefault(
427
+ {
428
+ "save_strategy": "steps",
429
+ }
430
+ )
431
+
432
+ validate_config(cfg)
433
+
434
+ cfg = DictDefault(
435
+ {
436
+ "save_strategy": "steps",
437
+ "save_steps": 10,
438
+ }
439
+ )
440
+
441
+ validate_config(cfg)
442
+
443
+ cfg = DictDefault(
444
+ {
445
+ "save_steps": 10,
446
+ }
447
+ )
448
+
449
+ validate_config(cfg)
450
+
451
+ cfg = DictDefault(
452
+ {
453
+ "save_strategy": "no",
454
+ }
455
+ )
456
+
457
+ validate_config(cfg)
458
+
459
+ def test_no_conflict_eval_strategy(self):
460
+ cfg = DictDefault(
461
+ {
462
+ "evaluation_strategy": "epoch",
463
+ "eval_steps": 10,
464
+ }
465
+ )
466
+
467
+ with pytest.raises(
468
+ ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
469
+ ):
470
+ validate_config(cfg)
471
+
472
+ cfg = DictDefault(
473
+ {
474
+ "evaluation_strategy": "no",
475
+ "eval_steps": 10,
476
+ }
477
+ )
478
+
479
+ with pytest.raises(
480
+ ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
481
+ ):
482
+ validate_config(cfg)
483
+
484
+ cfg = DictDefault(
485
+ {
486
+ "evaluation_strategy": "steps",
487
+ }
488
+ )
489
+
490
+ validate_config(cfg)
491
+
492
+ cfg = DictDefault(
493
+ {
494
+ "evaluation_strategy": "steps",
495
+ "eval_steps": 10,
496
+ }
497
+ )
498
+
499
+ validate_config(cfg)
500
+
501
+ cfg = DictDefault(
502
+ {
503
+ "eval_steps": 10,
504
+ }
505
+ )
506
+
507
+ validate_config(cfg)
508
+
509
+ cfg = DictDefault(
510
+ {
511
+ "evaluation_strategy": "no",
512
+ }
513
+ )
514
+
515
+ validate_config(cfg)
516
+
517
+ cfg = DictDefault(
518
+ {
519
+ "evaluation_strategy": "epoch",
520
+ "val_set_size": 0,
521
+ }
522
+ )
523
+
524
+ with pytest.raises(
525
+ ValueError,
526
+ match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
527
+ ):
528
+ validate_config(cfg)
529
+
530
+ cfg = DictDefault(
531
+ {
532
+ "eval_steps": 10,
533
+ "val_set_size": 0,
534
+ }
535
+ )
536
+
537
+ with pytest.raises(
538
+ ValueError,
539
+ match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
540
+ ):
541
+ validate_config(cfg)
542
+
543
+ cfg = DictDefault(
544
+ {
545
+ "val_set_size": 0,
546
+ }
547
+ )
548
+
549
+ validate_config(cfg)
550
+
551
+ cfg = DictDefault(
552
+ {
553
+ "eval_steps": 10,
554
+ "val_set_size": 0.01,
555
+ }
556
+ )
557
+
558
+ validate_config(cfg)
559
+
560
+ cfg = DictDefault(
561
+ {
562
+ "evaluation_strategy": "epoch",
563
+ "val_set_size": 0.01,
564
+ }
565
+ )
566
+
567
+ validate_config(cfg)