Fix(save): Save as safetensors (#363)
Browse files- scripts/finetune.py +10 -5
scripts/finetune.py
CHANGED
|
@@ -257,6 +257,8 @@ def train(
|
|
| 257 |
LOG.info("loading model and peft_config...")
|
| 258 |
model, peft_config = load_model(cfg, tokenizer)
|
| 259 |
|
|
|
|
|
|
|
| 260 |
if "merge_lora" in kwargs and cfg.adapter is not None:
|
| 261 |
LOG.info("running merge of LoRA with base model")
|
| 262 |
model = model.merge_and_unload()
|
|
@@ -264,7 +266,10 @@ def train(
|
|
| 264 |
|
| 265 |
if cfg.local_rank == 0:
|
| 266 |
LOG.info("saving merged model")
|
| 267 |
- model.save_pretrained(
|
|
|
|
|
|
|
|
|
|
| 268 |
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
| 269 |
return
|
| 270 |
|
|
@@ -280,7 +285,7 @@ def train(
|
|
| 280 |
return
|
| 281 |
|
| 282 |
if "shard" in kwargs:
|
| 283 |
- model.save_pretrained(cfg.output_dir)
|
| 284 |
return
|
| 285 |
|
| 286 |
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
|
@@ -302,7 +307,7 @@ def train(
|
|
| 302 |
def terminate_handler(_, __, model):
|
| 303 |
if cfg.flash_optimum:
|
| 304 |
model = BetterTransformer.reverse(model)
|
| 305 |
- model.save_pretrained(cfg.output_dir)
|
| 306 |
sys.exit(0)
|
| 307 |
|
| 308 |
signal.signal(
|
|
@@ -342,11 +347,11 @@ def train(
|
|
| 342 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
| 343 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
| 344 |
if cfg.fsdp:
|
| 345 |
- model.save_pretrained(cfg.output_dir)
|
| 346 |
elif cfg.local_rank == 0:
|
| 347 |
if cfg.flash_optimum:
|
| 348 |
model = BetterTransformer.reverse(model)
|
| 349 |
- model.save_pretrained(cfg.output_dir)
|
| 350 |
|
| 351 |
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
| 352 |
|
|
|
|
| 257 |
LOG.info("loading model and peft_config...")
|
| 258 |
model, peft_config = load_model(cfg, tokenizer)
|
| 259 |
|
| 260 |
+ safe_serialization = cfg.save_safetensors is True
|
| 261 |
+
|
| 262 |
if "merge_lora" in kwargs and cfg.adapter is not None:
|
| 263 |
LOG.info("running merge of LoRA with base model")
|
| 264 |
model = model.merge_and_unload()
|
|
|
|
| 266 |
|
| 267 |
if cfg.local_rank == 0:
|
| 268 |
LOG.info("saving merged model")
|
| 269 |
+ model.save_pretrained(
|
| 270 |
+ str(Path(cfg.output_dir) / "merged"),
|
| 271 |
+ safe_serialization=safe_serialization,
|
| 272 |
+ )
|
| 273 |
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
| 274 |
return
|
| 275 |
|
|
|
|
| 285 |
return
|
| 286 |
|
| 287 |
if "shard" in kwargs:
|
| 288 |
+ model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
| 289 |
return
|
| 290 |
|
| 291 |
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
|
|
|
| 307 |
def terminate_handler(_, __, model):
|
| 308 |
if cfg.flash_optimum:
|
| 309 |
model = BetterTransformer.reverse(model)
|
| 310 |
+ model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
| 311 |
sys.exit(0)
|
| 312 |
|
| 313 |
signal.signal(
|
|
|
|
| 347 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
| 348 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
| 349 |
if cfg.fsdp:
|
| 350 |
+ model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
| 351 |
elif cfg.local_rank == 0:
|
| 352 |
if cfg.flash_optimum:
|
| 353 |
model = BetterTransformer.reverse(model)
|
| 354 |
+ model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
| 355 |
|
| 356 |
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
| 357 |
|