fixes to make qlora actually work

src/axolotl/utils/models.py
CHANGED
@@ -248,7 +248,7 @@ def load_model(
 
     if (
         (cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora"
-    ) and not cfg.load_4bit:
+    ) and not cfg.load_4bit and (load_in_8bit or cfg.load_in_4bit):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
@@ -297,7 +297,7 @@ def load_adapter(model, cfg, adapter):
 
     if adapter is None:
         return model, None
-    if adapter == "lora":
+    if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
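
For context, a minimal sketch (using transformers + peft directly, outside axolotl) of the flow these two hunks guard: `prepare_model_for_int8_training` should only run when the base model was actually loaded quantized (8-bit for lora, 4-bit for qlora), and a `"qlora"` adapter now takes the same LoRA-wrapping path as `"lora"`. The model id and LoRA hyperparameters below are placeholders, not values from this repo.

```python
# Sketch only: mirrors the condition load_model now checks before calling
# prepare_model_for_int8_training, i.e. the weights were loaded quantized.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder model id
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # qlora-style 4-bit load
    torch_dtype=torch.float16,
    device_map="auto",
)

# Guarded call: freezes/casts layers and enables input grads for k-bit training.
model = prepare_model_for_int8_training(model)

# "qlora" now routes through the same LoRA wrapping as "lora".
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # placeholder target modules
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
```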
src/axolotl/utils/trainer.py
CHANGED
@@ -205,7 +205,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter == "lora":  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
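
This hunk registers the PEFT-saving callback for qlora runs as well. As a rough sketch of what a callback like `SavePeftModelCallback` typically does (assumed pattern, not copied from this repo's implementation): on each checkpoint save, write only the adapter weights via `save_pretrained`, so the quantized base model is not duplicated into every checkpoint.

```python
# Hedged sketch of a PEFT checkpoint callback; axolotl's SavePeftModelCallback
# may differ in details.
import os
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class SavePeftModelCallbackSketch(TrainerCallback):  # hypothetical name
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        peft_model_path = os.path.join(checkpoint_dir, "adapter_model")
        # kwargs["model"] is the PEFT-wrapped model; save_pretrained writes only
        # the LoRA/QLoRA adapter weights and config.
        kwargs["model"].save_pretrained(peft_model_path)
        return control
```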