winglian committed
Commit 29936bb
1 Parent(s): 7882181

fix fsdp training args

Files changed (1)
  1. src/axolotl/utils/trainer.py  +4 -2
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -34,6 +34,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
     else:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+    if cfg.fsdp:
+        training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
+    if cfg.fsdp_transformer_layer_cls_to_wrap:
+        training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap
 
 
     # deepspeed
@@ -64,8 +68,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer,
         lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
         weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
-        fsdp=cfg.fsdp.split(" ") if cfg.fsdp else None,
-        fsdp_transformer_layer_cls_to_wrap=cfg.fsdp_transformer_layer_cls_to_wrap if cfg.fsdp_transformer_layer_cls_to_wrap else None,
         **training_arguments_kwargs,
     )
 
73