fix fsdp training args
src/axolotl/utils/trainer.py
CHANGED
@@ -34,6 +34,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
         else:
             training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+    if cfg.fsdp:
+        training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
+        if cfg.fsdp_transformer_layer_cls_to_wrap:
+            training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap
 
 
     # deepspeed
@@ -64,8 +68,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer,
         lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
         weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
-        fsdp=cfg.fsdp.split(" ") if cfg.fsdp else None,
-        fsdp_transformer_layer_cls_to_wrap=cfg.fsdp_transformer_layer_cls_to_wrap if cfg.fsdp_transformer_layer_cls_to_wrap else None,
         **training_arguments_kwargs,
     )
 
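For reference, a minimal standalone sketch of the kwargs plumbing the patch introduces: the FSDP options are only added to training_arguments_kwargs when they are set, instead of being passed unconditionally (possibly as None) to the TrainingArguments constructor. The cfg values below are hypothetical examples, and SimpleNamespace is just a stand-in for axolotl's real config object.

from types import SimpleNamespace

# Hypothetical config values; the real ones come from the axolotl YAML config.
cfg = SimpleNamespace(
    fsdp="full_shard auto_wrap",
    fsdp_transformer_layer_cls_to_wrap="LlamaDecoderLayer",
    gradient_checkpointing=True,
)

training_arguments_kwargs = {}
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing

# Same pattern as the patch: only add the FSDP kwargs when they are configured,
# so TrainingArguments never receives explicit None values for them.
if cfg.fsdp:
    training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
    if cfg.fsdp_transformer_layer_cls_to_wrap:
        training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = (
            cfg.fsdp_transformer_layer_cls_to_wrap
        )

print(training_arguments_kwargs)
# {'gradient_checkpointing': True,
#  'fsdp': ['full_shard', 'auto_wrap'],
#  'fsdp_transformer_layer_cls_to_wrap': 'LlamaDecoderLayer'}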