Commit
•
e4d1585
1
Parent(s):
70157cc
Fix DeepSpeed Zero 3 Saving (#709)
Browse files
* Update train.py
* add zero3 check
* chore: lint
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
- src/axolotl/train.py +17 -0
src/axolotl/train.py
CHANGED
@@ -12,6 +12,7 @@ import torch
|
|
12 |
import transformers.modelcard
|
13 |
from datasets import Dataset
|
14 |
from optimum.bettertransformer import BetterTransformer
|
|
|
15 |
|
16 |
from axolotl.common.cli import TrainerCliArgs
|
17 |
from axolotl.logging_config import configure_logging
|
@@ -134,6 +135,22 @@ def train(
|
|
134 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
135 |
if cfg.fsdp:
|
136 |
trainer.save_model(cfg.output_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
elif cfg.local_rank == 0:
|
138 |
if cfg.flash_optimum:
|
139 |
model = BetterTransformer.reverse(model)
|
|
|
12 |
import transformers.modelcard
|
13 |
from datasets import Dataset
|
14 |
from optimum.bettertransformer import BetterTransformer
|
15 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
16 |
|
17 |
from axolotl.common.cli import TrainerCliArgs
|
18 |
from axolotl.logging_config import configure_logging
|
|
|
135 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
136 |
if cfg.fsdp:
|
137 |
trainer.save_model(cfg.output_dir)
|
138 |
+
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
139 |
+
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
140 |
+
trainer.accelerator.wait_for_everyone()
|
141 |
+
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
142 |
+
|
143 |
+
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
144 |
+
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
145 |
+
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
146 |
+
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
|
147 |
+
# The model name saved is `pytorch_model.bin`
|
148 |
+
unwrapped_model.save_pretrained(
|
149 |
+
cfg.output_dir,
|
150 |
+
is_main_process=trainer.accelerator.is_main_process,
|
151 |
+
save_function=trainer.accelerator.save,
|
152 |
+
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
153 |
+
)
|
154 |
elif cfg.local_rank == 0:
|
155 |
if cfg.flash_optimum:
|
156 |
model = BetterTransformer.reverse(model)
|