winglian committed
Commit 9492d4e · unverified · 2 Parent(s): a81f52d ad5ca4f

Merge pull request #215 from OpenAccess-AI-Collective/adamw-hyperparams-cfg

README.md CHANGED
@@ -422,6 +422,12 @@ log_sweep_max_lr:
 optimizer:
 # specify weight decay
 weight_decay:
+# adamw hyperparams
+adam_beta1:
+adam_beta2:
+adam_epsilon:
+# Gradient clipping max norm
+max_grad_norm:
 
 # whether to bettertransformers
 flash_optimum:
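
The README hunk documents four new optional optimizer settings. A minimal sketch of how they might be set, expressed as the DictDefault config dict style used in the tests below; the values and the DictDefault import path are illustrative assumptions, not recommended defaults:

# Minimal sketch of the new options; values are illustrative placeholders,
# and the DictDefault import path is assumed from axolotl's utils package.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "optimizer": "adamw_bnb_8bit",
        "weight_decay": 0.01,
        # adamw hyperparams
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "adam_epsilon": 1e-8,
        # Gradient clipping max norm
        "max_grad_norm": 1.0,
    }
)
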
src/axolotl/utils/trainer.py CHANGED
@@ -115,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         # TODO search Path("./") for one
         training_arguments_kwargs["deepspeed"] = "./ds_config.json"
 
+    if cfg.adam_beta1:
+        training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
+    if cfg.adam_beta2:
+        training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
+    if cfg.adam_epsilon:
+        training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
+    if cfg.max_grad_norm:
+        training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
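
The keys forwarded here (adam_beta1, adam_beta2, adam_epsilon, max_grad_norm) are standard transformers.TrainingArguments fields, so any value left unset simply falls back to the Hugging Face default. A standalone sketch of the same pass-through, with a placeholder output_dir and illustrative values:

# Standalone sketch of the kwargs pass-through above; output_dir and the cfg
# values are placeholders for illustration only.
import transformers

cfg = {"adam_beta1": 0.9, "adam_beta2": 0.95, "adam_epsilon": 1e-8, "max_grad_norm": 1.0}

training_arguments_kwargs = {}
for key in ("adam_beta1", "adam_beta2", "adam_epsilon", "max_grad_norm"):
    if cfg.get(key):  # only forward keys the user actually set
        training_arguments_kwargs[key] = cfg[key]

training_args = transformers.TrainingArguments(
    output_dir="./out",  # placeholder
    **training_arguments_kwargs,
)
print(training_args.adam_beta2, training_args.max_grad_norm)
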
src/axolotl/utils/validation.py CHANGED
@@ -87,6 +87,11 @@ def validate_config(cfg):
             "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         )
 
+    if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
+        not cfg.optimizer or "adamw" not in cfg.optimizer
+    ):
+        logging.warning("adamw hyperparameters found, but no adamw optimizer set")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
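
Note that this guard checks the adamw_-prefixed keys (adamw_beta1, adamw_beta2, adamw_epsilon) and only logs a warning rather than raising. A quick sketch of the warning path, assuming the import paths implied by the repo layout and the tests below:

# Sketch of the warning path; import paths are assumed from the repo layout
# (src/axolotl/utils/validation.py) and tests/test_validation.py.
import logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config

logging.basicConfig(level=logging.WARNING)

# adamw_* hyperparams set, but a non-AdamW optimizer selected
cfg = DictDefault({"optimizer": "adafactor", "adamw_beta1": 0.9})
validate_config(cfg)  # logs: adamw hyperparameters found, but no adamw optimizer set
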
tests/test_validation.py CHANGED
@@ -263,3 +263,53 @@ class ValidationTest(unittest.TestCase):
 
         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
+
+    def test_adamw_hyperparams(self):
+        cfg = DictDefault(
+            {
+                "optimizer": None,
+                "adamw_epsilon": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "adamw hyperparameters found, but no adamw optimizer set"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adafactor",
+                "adamw_beta1": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "adamw hyperparameters found, but no adamw optimizer set"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adamw_bnb_8bit",
+                "adamw_beta1": 0.0001,
+                "adamw_beta2": 0.0001,
+                "adamw_epsilon": 0.0001,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adafactor",
+            }
+        )
+
+        validate_config(cfg)