winglian commited on
Commit
894cba0
·
unverified ·
1 Parent(s): 41a4d15

fix FSDP save of final model (#329)

Browse files
Files changed (1) hide show
  1. scripts/finetune.py +3 -1
scripts/finetune.py CHANGED
@@ -344,7 +344,9 @@ def train(
344
 
345
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
346
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
347
- if cfg.local_rank == 0:
 
 
348
  if cfg.flash_optimum:
349
  model = BetterTransformer.reverse(model)
350
  model.save_pretrained(cfg.output_dir)
 
344
 
345
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
346
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
347
+ if cfg.fsdp:
348
+ model.save_pretrained(cfg.output_dir)
349
+ elif cfg.local_rank == 0:
350
  if cfg.flash_optimum:
351
  model = BetterTransformer.reverse(model)
352
  model.save_pretrained(cfg.output_dir)