ensure enable_input_require_grads is called on model before getting the peft model (#345)
Browse files
src/axolotl/utils/models.py
CHANGED
@@ -391,6 +391,8 @@ def load_adapter(model, cfg, adapter):
|
|
391 |
|
392 |
if adapter is None:
|
393 |
return model, None
|
|
|
|
|
394 |
if adapter in ["lora", "qlora"]:
|
395 |
return load_lora(model, cfg)
|
396 |
if adapter == "llama-adapter":
|
|
|
391 |
|
392 |
if adapter is None:
|
393 |
return model, None
|
394 |
+
if hasattr(model, "enable_input_require_grads"):
|
395 |
+
model.enable_input_require_grads()
|
396 |
if adapter in ["lora", "qlora"]:
|
397 |
return load_lora(model, cfg)
|
398 |
if adapter == "llama-adapter":
|