remove unnecessary local variable
Browse files
src/axolotl/utils/models.py
CHANGED
@@ -87,7 +87,6 @@ def load_model(
|
|
87 |
base_model = cfg.base_model
|
88 |
base_model_config = cfg.base_model_config
|
89 |
model_type = cfg.model_type
|
90 |
-
adapter = cfg.adapter
|
91 |
|
92 |
# TODO refactor as a kwarg
|
93 |
load_in_8bit = cfg.load_in_8bit
|
@@ -359,7 +358,7 @@ def load_model(
|
|
359 |
if hasattr(module, "weight"):
|
360 |
module.to(torch_dtype)
|
361 |
|
362 |
-
model, lora_config = load_adapter(model, cfg, adapter)
|
363 |
|
364 |
if cfg.ddp and not load_in_8bit:
|
365 |
model.to(f"cuda:{cfg.local_rank}")
|
|
|
87 |
base_model = cfg.base_model
|
88 |
base_model_config = cfg.base_model_config
|
89 |
model_type = cfg.model_type
|
|
|
90 |
|
91 |
# TODO refactor as a kwarg
|
92 |
load_in_8bit = cfg.load_in_8bit
|
|
|
358 |
if hasattr(module, "weight"):
|
359 |
module.to(torch_dtype)
|
360 |
|
361 |
+
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
362 |
|
363 |
if cfg.ddp and not load_in_8bit:
|
364 |
model.to(f"cuda:{cfg.local_rank}")
|