split up llama model loading so config can be loaded from base config and models can be loaded from a path
src/axolotl/utils/models.py
CHANGED
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401

 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import AutoModelForCausalLM  # noqa: F401
+from transformers import AutoModelForCausalLM, LlamaConfig  # noqa: F401
 from transformers import PreTrainedModel  # noqa: F401
 from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig

@@ -172,8 +172,10 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        config = LlamaConfig.from_pretrained(base_model_config)
         model = LlamaForCausalLM.from_pretrained(
             base_model,
+            config=config,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
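
A minimal sketch of the loading flow this change enables, with the config read from its own path and the weights from the model path. The example paths and the torch_dtype value are illustrative, not from the diff; only the LlamaConfig.from_pretrained / config= pattern mirrors the commit.

    # Sketch only: load the Llama config from one location and the weights from another,
    # as load_model() now does for llama-derived models.
    import torch
    from transformers import LlamaConfig, LlamaForCausalLM

    base_model = "/path/to/llama-weights"          # weights path or hub id (example)
    base_model_config = "huggyllama/llama-7b"      # where config.json lives; may differ (example)

    config = LlamaConfig.from_pretrained(base_model_config)  # config loaded from base config
    model = LlamaForCausalLM.from_pretrained(
        base_model,                # weights loaded from the model path
        config=config,             # explicit config overrides the one next to the weights
        torch_dtype=torch.float16, # example dtype; the real code passes torch_dtype from cfg
    )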