Feat: Add warmup_ratio (#893)
Browse files
* Feat: Add warmup_ratio
* fix: update readme with more details on conflict
- README.md +2 -1
- src/axolotl/core/trainer_builder.py +8 -5
- src/axolotl/utils/config.py +3 -0
- tests/test_validation.py +30 -0
README.md
CHANGED
@@ -675,7 +675,8 @@ gradient_accumulation_steps: 1
|
|
675 |
micro_batch_size: 2
|
676 |
eval_batch_size:
|
677 |
num_epochs: 4
|
678 |
-
warmup_steps: 100
|
|
|
679 |
learning_rate: 0.00003
|
680 |
lr_quadratic_warmup:
|
681 |
logging_steps:
|
|
|
675 |
micro_batch_size: 2
|
676 |
eval_batch_size:
|
677 |
num_epochs: 4
|
678 |
+
warmup_steps: 100 # cannot use with warmup_ratio
|
679 |
+
warmup_ratio: 0.05 # cannot use with warmup_steps
|
680 |
learning_rate: 0.00003
|
681 |
lr_quadratic_warmup:
|
682 |
logging_steps:
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -461,11 +461,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
461 |
return AxolotlTrainer
|
462 |
|
463 |
def build(self, total_num_steps):
|
464 |
-
warmup_steps =
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
|
|
|
|
|
|
469 |
logging_steps = (
|
470 |
self.cfg.logging_steps
|
471 |
if self.cfg.logging_steps is not None
|
|
|
461 |
return AxolotlTrainer
|
462 |
|
463 |
def build(self, total_num_steps):
|
464 |
+
warmup_steps = None
|
465 |
+
if self.cfg.warmup_steps is not None:
|
466 |
+
warmup_steps = self.cfg.warmup_steps
|
467 |
+
elif self.cfg.warmup_ratio is not None:
|
468 |
+
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
469 |
+
else:
|
470 |
+
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
471 |
+
|
472 |
logging_steps = (
|
473 |
self.cfg.logging_steps
|
474 |
if self.cfg.logging_steps is not None
|
src/axolotl/utils/config.py
CHANGED
@@ -372,6 +372,9 @@ def validate_config(cfg):
|
|
372 |
if cfg.rope_scaling:
|
373 |
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
374 |
|
|
|
|
|
|
|
375 |
# TODO
|
376 |
# MPT 7b
|
377 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
372 |
if cfg.rope_scaling:
|
373 |
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
374 |
|
375 |
+
if cfg.warmup_steps and cfg.warmup_ratio:
|
376 |
+
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
377 |
+
|
378 |
# TODO
|
379 |
# MPT 7b
|
380 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
tests/test_validation.py
CHANGED
@@ -649,3 +649,33 @@ class ValidationTest(unittest.TestCase):
|
|
649 |
)
|
650 |
|
651 |
validate_config(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
649 |
)
|
650 |
|
651 |
validate_config(cfg)
|
652 |
+
|
653 |
+
def test_warmup_step_no_conflict(self):
|
654 |
+
cfg = DictDefault(
|
655 |
+
{
|
656 |
+
"warmup_steps": 10,
|
657 |
+
"warmup_ratio": 0.1,
|
658 |
+
}
|
659 |
+
)
|
660 |
+
|
661 |
+
with pytest.raises(
|
662 |
+
ValueError,
|
663 |
+
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
|
664 |
+
):
|
665 |
+
validate_config(cfg)
|
666 |
+
|
667 |
+
cfg = DictDefault(
|
668 |
+
{
|
669 |
+
"warmup_steps": 10,
|
670 |
+
}
|
671 |
+
)
|
672 |
+
|
673 |
+
validate_config(cfg)
|
674 |
+
|
675 |
+
cfg = DictDefault(
|
676 |
+
{
|
677 |
+
"warmup_ratio": 0.1,
|
678 |
+
}
|
679 |
+
)
|
680 |
+
|
681 |
+
validate_config(cfg)
|