simplify load_model signature
- scripts/finetune.py +1 -8
- src/axolotl/utils/models.py +7 -4
scripts/finetune.py
CHANGED
@@ -252,14 +252,7 @@ def train(
 
     # Load the model and tokenizer
     LOG.info("loading model and peft_config...")
-    model, peft_config = load_model(
-        cfg.base_model,
-        cfg.base_model_config,
-        cfg.model_type,
-        tokenizer,
-        cfg,
-        adapter=cfg.adapter,
-    )
+    model, peft_config = load_model(cfg, tokenizer)
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
         LOG.info("running merge of LoRA with base model")
src/axolotl/utils/models.py
CHANGED
@@ -77,12 +77,15 @@ def load_tokenizer(
 
 
 def load_model(
-    base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
-):
-    # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    cfg, tokenizer
+):  # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
-    Load a model
+    Load a model for a given configuration and tokenizer.
     """
+    base_model = cfg.base_model
+    base_model_config = cfg.base_model_config
+    model_type = cfg.model_type
+    adapter = cfg.adapter
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit