ensure merged model matches the training dtype (#902)
Browse files* ensure merged model matches the training dtype
* Update src/axolotl/cli/__init__.py
* Update src/axolotl/cli/__init__.py
src/axolotl/cli/__init__.py
CHANGED
@@ -72,7 +72,7 @@ def do_merge_lora(
|
|
72 |
|
73 |
LOG.info("running merge of LoRA with base model")
|
74 |
model = model.merge_and_unload()
|
75 |
-
model.to(dtype=
|
76 |
|
77 |
if cfg.local_rank == 0:
|
78 |
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
|
|
72 |
|
73 |
LOG.info("running merge of LoRA with base model")
|
74 |
model = model.merge_and_unload()
|
75 |
+
model.to(dtype=cfg.torch_dtype)
|
76 |
|
77 |
if cfg.local_rank == 0:
|
78 |
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|