Lint models.py
src/axolotl/utils/models.py (+34 -30)
@@ -1,13 +1,16 @@
+"""Module for models and model loading"""
+
+
 import logging
 import math
 import os
 from pathlib import Path
-from typing import Optional, Tuple, TYPE_CHECKING
+from typing import Optional, Tuple, TYPE_CHECKING  # noqa: F401
 
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import (
+from transformers import (  # noqa: F401
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
@@ -18,9 +21,8 @@ from transformers import (
 try:
     from transformers import (
         LlamaForCausalLM,
-        LlamaTokenizer,
     )
-except:
+except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
     )
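For reference, the except ImportError change narrows the guard so only a missing Llama import is tolerated; below is a minimal sketch of that optional-import pattern (the None fallback is illustrative, not from this file).

import logging

try:
    from transformers import LlamaForCausalLM  # present only in newer transformers releases
except ImportError:
    # Only a missing import is swallowed; any other failure still propagates.
    LlamaForCausalLM = None
    logging.warning(
        "This version of transformers does not support Llama. Consider upgrading."
    )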
@@ -28,9 +30,9 @@ except:
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
-    from peft import
-    from axolotl.utils.dict import DictDefault
-    from transformers import PreTrainedTokenizer
+    from peft import PeftConfig  # noqa: F401
+    from axolotl.utils.dict import DictDefault  # noqa: F401
+    from transformers import PreTrainedTokenizer  # noqa: F401
 
 
 def load_tokenizer(
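These imports sit under TYPE_CHECKING, so they are only evaluated by static type checkers, and the noqa: F401 markers keep flake8 from flagging them as unused. A small self-contained sketch of the same pattern, assuming nothing beyond transformers itself:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime.
    from transformers import PreTrainedTokenizer  # noqa: F401


def describe_tokenizer(tokenizer: "PreTrainedTokenizer") -> str:
    # The string annotation defers resolution, so no runtime import is needed.
    return f"vocab size: {tokenizer.vocab_size}"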
@@ -62,8 +64,8 @@ def load_tokenizer(
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
     if cfg.special_tokens:
-        for k,
-            tokenizer.add_special_tokens({k:
+        for k, val in cfg.special_tokens.items():
+            tokenizer.add_special_tokens({k: val})
     if cfg.tokens:
         tokenizer.add_tokens(list(cfg.tokens))
 
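The reflowed loop applies each configured special token to the tokenizer one key at a time. A rough, runnable illustration of the same transformers calls follows; the token values and the gpt2 checkpoint are hypothetical stand-ins for whatever the axolotl config provides.

from transformers import AutoTokenizer

special_tokens = {"pad_token": "[PAD]", "eos_token": "</s>"}  # stand-in for cfg.special_tokens
extra_tokens = ["<|user|>", "<|assistant|>"]  # stand-in for cfg.tokens

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Register special tokens one key at a time, mirroring the loop above...
for k, val in special_tokens.items():
    tokenizer.add_special_tokens({k: val})
# ...then add ordinary tokens in a single call.
tokenizer.add_tokens(extra_tokens)
print(len(tokenizer))  # vocabulary grows by the number of newly added tokens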
@@ -80,6 +82,9 @@ def load_model(
     inference=False,
 ):
     # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    """
+    Load a model from a base model and a model type.
+    """
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
@@ -115,9 +120,9 @@ def load_model(
 
         replace_peft_model_with_int4_lora_model()
         from peft import prepare_model_for_int8_training
-    except Exception as
-        logging.exception(
-        raise
+    except Exception as err:
+        logging.exception(err)
+        raise err
 
     model_kwargs = {}
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
@@ -155,7 +160,7 @@ def load_model(
                 "unable to find a cached model file, this will likely fail..."
             )
             model_path = str(cache_model_path)
-        except:
+        except Exception:  # pylint: disable=broad-exception-caught
             model_path = cfg.base_model
         model, _ = load_llama_model_4bit_low_ram(
             base_model_config if base_model_config else base_model,
@@ -210,13 +215,13 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
-                trust_remote_code=
+                trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
-                trust_remote_code=
+                trust_remote_code=cfg.trust_remote_code or False,
             )
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
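The trust_remote_code lines coerce an unset config value to False, so custom modeling code only runs when the config explicitly enables it. A minimal sketch of that coercion, with cfg replaced by a plain variable and gpt2 as a placeholder checkpoint:

from transformers import AutoConfig, AutoModelForCausalLM

cfg_trust_remote_code = None  # e.g. the key is absent from the training config
trust_remote_code = cfg_trust_remote_code or False  # None/missing -> False, True stays True

config = AutoConfig.from_pretrained("gpt2", trust_remote_code=trust_remote_code)
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    config=config,
    trust_remote_code=trust_remote_code,
)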
@@ -225,30 +230,29 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
-                trust_remote_code=
+                trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
-    except Exception as
+    except Exception as err:  # pylint: disable=broad-exception-caught
         logging.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
         )
-        logging.exception(
+        logging.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
-            trust_remote_code=
+            trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
 
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
-    if (
-        (
-        and
-        and (load_in_8bit or cfg.load_in_4bit)
+    if not cfg.gptq and (
+        (cfg.adapter == "lora" and load_in_8bit)
+        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
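Two things happen in this hunk: the token-embedding table is resized up to the next multiple of 32, and prepare_model_for_int8_training is applied only to non-GPTQ LoRA/QLoRA runs. A tiny worked example of the rounding; the vocab size is made up:

import math

vocab_len = 32003  # hypothetical: base vocab plus a few added special tokens
embeddings_len = math.ceil(vocab_len / 32) * 32
print(embeddings_len)  # 32032, the next multiple of 32
# The loader then calls model.resize_token_embeddings(embeddings_len).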
@@ -261,14 +265,14 @@ def load_model(
     if cfg.gptq:
         # Scales to half
         logging.info("Fitting 4bit scales and zeros to half")
-        for
-            if "Autograd4bitQuantLinear" in str(type(
-                type(
+        for _, module in model.named_modules():
+            if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
+                type(module)
             ):
-                if hasattr(
-
-
-
+                if hasattr(module, "is_v1_model") and module.is_v1_model:
+                    module.zeros = module.zeros.half()
+                module.scales = module.scales.half()
+                module.bias = module.bias.half()
 
     if (
         torch.cuda.device_count() > 1