tokestermw, winglian committed
Commit e4d1585
1 Parent(s): 70157cc

Fix DeepSpeed Zero 3 Saving (#709)


* Update train.py

* add zero3 check

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

Files changed (1)
  1. src/axolotl/train.py +17 -0
src/axolotl/train.py CHANGED

@@ -12,6 +12,7 @@ import torch
 import transformers.modelcard
 from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
+from transformers.deepspeed import is_deepspeed_zero3_enabled

 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
@@ -134,6 +135,22 @@ def train(
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
         trainer.save_model(cfg.output_dir)
+    elif cfg.deepspeed and is_deepspeed_zero3_enabled():
+        # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
+        trainer.accelerator.wait_for_everyone()
+        unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
+
+        # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
+        # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
+        # `zero3_save_16bit_model` is True in DeepSpeed Plugin.
+        # For Zero Stages 1 and 2, models are saved as usual in the output directory.
+        # The model name saved is `pytorch_model.bin`
+        unwrapped_model.save_pretrained(
+            cfg.output_dir,
+            is_main_process=trainer.accelerator.is_main_process,
+            save_function=trainer.accelerator.save,
+            state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
+        )
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
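
For reference, the `stage3_gather_16bit_weights_on_model_save` flag named in the new comments lives in the DeepSpeed JSON config that `cfg.deepspeed` points to; when true, ZeRO Stage-3 gathers the partitioned 16-bit weights onto the main process at save time so a whole `pytorch_model.bin` can be written. A minimal sketch of such a config, written from Python (the file name and the `bf16` block are illustrative assumptions, not part of this commit):

import json

# Illustrative ZeRO-3 config sketch. Only the gather flag is directly
# relevant to the saving branch added in this commit; other values are
# placeholder assumptions.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        # Consolidate partitioned 16-bit weights on rank 0 when saving,
        # so save_pretrained() can emit an unpartitioned checkpoint.
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "bf16": {"enabled": "auto"},
}

with open("ds_zero3.json", "w") as f:
    json.dump(ds_config, f, indent=2)

If the flag is false, ZeRO-3 leaves the weights sharded across ranks and the saved output is not a single whole-model file.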