winglian committed on
Commit
6d0ee4b
1 Parent(s): a81f52d

support adamw and grad norm hyperparams

Files changed (1)
src/axolotl/utils/trainer.py +9 -0
src/axolotl/utils/trainer.py CHANGED
@@ -115,6 +115,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         # TODO search Path("./") for one
         training_arguments_kwargs["deepspeed"] = "./ds_config.json"
 
+    if cfg.adam_beta1:
+        training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
+    if cfg.adam_beta2:
+        training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
+    if cfg.adam_epsilon:
+        training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
+    if cfg.max_grad_norm:
+        training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
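
The pattern the commit follows is to forward an optional optimizer or gradient-clipping hyperparameter to transformers.TrainingArguments only when it is set in the config, so the transformers defaults apply otherwise. A minimal standalone sketch of that pattern (the build_training_args helper, the plain-dict cfg, and the literal values below are illustrative, not part of this commit):

import transformers

def build_training_args(cfg: dict, output_dir: str = "./out") -> transformers.TrainingArguments:
    """Forward optional optimizer/grad-norm hyperparameters only when set."""
    training_arguments_kwargs = {}
    if cfg.get("adam_beta1"):
        training_arguments_kwargs["adam_beta1"] = cfg["adam_beta1"]
    if cfg.get("adam_beta2"):
        training_arguments_kwargs["adam_beta2"] = cfg["adam_beta2"]
    if cfg.get("adam_epsilon"):
        training_arguments_kwargs["adam_epsilon"] = cfg["adam_epsilon"]
    if cfg.get("max_grad_norm"):
        training_arguments_kwargs["max_grad_norm"] = cfg["max_grad_norm"]
    # Anything left unset falls back to the transformers defaults
    # (adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-8, max_grad_norm=1.0).
    return transformers.TrainingArguments(output_dir=output_dir, **training_arguments_kwargs)

# Example: override only beta2 and the gradient-clipping norm.
args = build_training_args({"adam_beta2": 0.95, "max_grad_norm": 0.3})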