JohanWork winglian commited on
Commit
af29d81
1 Parent(s): 1b18003

ADD: warning if hub_model_id ist set but not any save strategy (#1202)

Browse files

* warning if hub model id set but no save

* add warning

* move the warning

* add test

* allow more public methods for tests for now

* fix tests

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

src/axolotl/utils/config.py CHANGED
@@ -340,6 +340,11 @@ def validate_config(cfg):
340
  "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
341
  )
342
 
 
 
 
 
 
343
  if cfg.gptq and cfg.model_revision:
344
  raise ValueError(
345
  "model_revision is not supported for GPTQ models. "
 
340
  "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
341
  )
342
 
343
+ if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
344
+ LOG.warning(
345
+ "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
346
+ )
347
+
348
  if cfg.gptq and cfg.model_revision:
349
  raise ValueError(
350
  "model_revision is not supported for GPTQ models. "
tests/test_validation.py CHANGED
@@ -26,6 +26,7 @@ class BaseValidation(unittest.TestCase):
26
  self._caplog = caplog
27
 
28
 
 
29
  class ValidationTest(BaseValidation):
30
  """
31
  Test the validation module
@@ -698,6 +699,22 @@ class ValidationTest(BaseValidation):
698
  ):
699
  validate_config(cfg)
700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
 
702
  class ValidationCheckModelConfig(BaseValidation):
703
  """
 
26
  self._caplog = caplog
27
 
28
 
29
+ # pylint: disable=too-many-public-methods
30
  class ValidationTest(BaseValidation):
31
  """
32
  Test the validation module
 
699
  ):
700
  validate_config(cfg)
701
 
702
+ def test_hub_model_id_save_value_warns(self):
703
+ cfg = DictDefault({"hub_model_id": "test"})
704
+
705
+ with self._caplog.at_level(logging.WARNING):
706
+ validate_config(cfg)
707
+ assert (
708
+ "set without any models being saved" in self._caplog.records[0].message
709
+ )
710
+
711
+ def test_hub_model_id_save_value(self):
712
+ cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
713
+
714
+ with self._caplog.at_level(logging.WARNING):
715
+ validate_config(cfg)
716
+ assert len(self._caplog.records) == 0
717
+
718
 
719
  class ValidationCheckModelConfig(BaseValidation):
720
  """