Jan Philipp Harries
Jan Philipp Harries
commited on
Commit
•
be75668
1
Parent(s):
aeec7c4
set fsdp state dict (#584)
Browse filesCo-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
- src/axolotl/train.py +4 -0
src/axolotl/train.py
CHANGED
@@ -117,6 +117,10 @@ def train(
|
|
117 |
|
118 |
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
119 |
|
|
|
|
|
|
|
|
|
120 |
if cfg.relora_steps:
|
121 |
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
122 |
model = model.merge_and_unload()
|
|
|
117 |
|
118 |
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
119 |
|
120 |
+
if trainer.is_fsdp_enabled:
|
121 |
+
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
122 |
+
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
123 |
+
|
124 |
if cfg.relora_steps:
|
125 |
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
126 |
model = model.merge_and_unload()
|