fix FSDP save of final model (#329)
Browse files- scripts/finetune.py +3 -1
scripts/finetune.py
CHANGED
@@ -344,7 +344,9 @@ def train(
|
|
344 |
|
345 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
346 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
347 |
-
if cfg.
|
|
|
|
|
348 |
if cfg.flash_optimum:
|
349 |
model = BetterTransformer.reverse(model)
|
350 |
model.save_pretrained(cfg.output_dir)
|
|
|
344 |
|
345 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
346 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
347 |
+
if cfg.fsdp:
|
348 |
+
model.save_pretrained(cfg.output_dir)
|
349 |
+
elif cfg.local_rank == 0:
|
350 |
if cfg.flash_optimum:
|
351 |
model = BetterTransformer.reverse(model)
|
352 |
model.save_pretrained(cfg.output_dir)
|