winglian commited on
Commit
cb9d3af
1 Parent(s): c969f0a

add validation and tests for adamw hyperparam

Browse files
src/axolotl/utils/validation.py CHANGED
@@ -87,6 +87,11 @@ def validate_config(cfg):
87
  "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
88
  )
89
 
 
 
 
 
 
90
  # TODO
91
  # MPT 7b
92
  # https://github.com/facebookresearch/bitsandbytes/issues/25
 
87
  "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
88
  )
89
 
90
+ if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
91
+ not cfg.optimizer or "adamw" not in cfg.optimizer
92
+ ):
93
+ logging.warning("adamw hyperparameters found, but no adamw optimizer set")
94
+
95
  # TODO
96
  # MPT 7b
97
  # https://github.com/facebookresearch/bitsandbytes/issues/25
tests/test_validation.py CHANGED
@@ -263,3 +263,45 @@ class ValidationTest(unittest.TestCase):
263
 
264
  with pytest.raises(ValueError, match=regex_exp):
265
  validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  with pytest.raises(ValueError, match=regex_exp):
265
  validate_config(cfg)
266
+
267
+ def test_adamw_hyperparams(self):
268
+ cfg = DictDefault(
269
+ {
270
+ "optimizer": None,
271
+ "adamw_epsilon": 0.0001,
272
+ }
273
+ )
274
+
275
+ with self._caplog.at_level(logging.WARNING):
276
+ validate_config(cfg)
277
+ assert any(
278
+ "adamw hyperparameters found, but no adamw optimizer set"
279
+ in record.message
280
+ for record in self._caplog.records
281
+ )
282
+
283
+ cfg = DictDefault(
284
+ {
285
+ "optimizer": "adafactor",
286
+ "adamw_beta1": 0.0001,
287
+ }
288
+ )
289
+
290
+ with self._caplog.at_level(logging.WARNING):
291
+ validate_config(cfg)
292
+ assert any(
293
+ "adamw hyperparameters found, but no adamw optimizer set"
294
+ in record.message
295
+ for record in self._caplog.records
296
+ )
297
+
298
+ cfg = DictDefault(
299
+ {
300
+ "optimizer": "adamw_bnb_8bit",
301
+ "adamw_beta1": 0.0001,
302
+ "adamw_beta2": 0.0001,
303
+ "adamw_epsilon": 0.0001,
304
+ }
305
+ )
306
+
307
+ validate_config(cfg)