ADD: warning if hub_model_id ist set but not any save strategy (#1202)
Browse files* warning if hub model id set but no save
* add warning
* move the warning
* add test
* allow more public methods for tests for now
* fix tests
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
- src/axolotl/utils/config.py +5 -0
- tests/test_validation.py +17 -0
src/axolotl/utils/config.py
CHANGED
@@ -340,6 +340,11 @@ def validate_config(cfg):
|
|
340 |
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
341 |
)
|
342 |
|
|
|
|
|
|
|
|
|
|
|
343 |
if cfg.gptq and cfg.model_revision:
|
344 |
raise ValueError(
|
345 |
"model_revision is not supported for GPTQ models. "
|
|
|
340 |
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
341 |
)
|
342 |
|
343 |
+
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
|
344 |
+
LOG.warning(
|
345 |
+
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
346 |
+
)
|
347 |
+
|
348 |
if cfg.gptq and cfg.model_revision:
|
349 |
raise ValueError(
|
350 |
"model_revision is not supported for GPTQ models. "
|
tests/test_validation.py
CHANGED
@@ -26,6 +26,7 @@ class BaseValidation(unittest.TestCase):
|
|
26 |
self._caplog = caplog
|
27 |
|
28 |
|
|
|
29 |
class ValidationTest(BaseValidation):
|
30 |
"""
|
31 |
Test the validation module
|
@@ -698,6 +699,22 @@ class ValidationTest(BaseValidation):
|
|
698 |
):
|
699 |
validate_config(cfg)
|
700 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
701 |
|
702 |
class ValidationCheckModelConfig(BaseValidation):
|
703 |
"""
|
|
|
26 |
self._caplog = caplog
|
27 |
|
28 |
|
29 |
+
# pylint: disable=too-many-public-methods
|
30 |
class ValidationTest(BaseValidation):
|
31 |
"""
|
32 |
Test the validation module
|
|
|
699 |
):
|
700 |
validate_config(cfg)
|
701 |
|
702 |
+
def test_hub_model_id_save_value_warns(self):
|
703 |
+
cfg = DictDefault({"hub_model_id": "test"})
|
704 |
+
|
705 |
+
with self._caplog.at_level(logging.WARNING):
|
706 |
+
validate_config(cfg)
|
707 |
+
assert (
|
708 |
+
"set without any models being saved" in self._caplog.records[0].message
|
709 |
+
)
|
710 |
+
|
711 |
+
def test_hub_model_id_save_value(self):
|
712 |
+
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
|
713 |
+
|
714 |
+
with self._caplog.at_level(logging.WARNING):
|
715 |
+
validate_config(cfg)
|
716 |
+
assert len(self._caplog.records) == 0
|
717 |
+
|
718 |
|
719 |
class ValidationCheckModelConfig(BaseValidation):
|
720 |
"""
|