winglian committed
Commit c7021e1 · unverified · 2 parents: 876edd8, e3c494c

Merge pull request #120 from OpenAccess-AI-Collective/model-from-path


Split up llama model loading so the config can be loaded from base_model_config and the model weights can be loaded from a separate path.

Files changed (3)
  1. README.md +3 -0
  2. scripts/finetune.py +3 -2
  3. src/axolotl/utils/models.py +12 -6
README.md CHANGED
@@ -171,6 +171,9 @@ base_model_ignore_patterns:
 # if the base_model repo on hf hub doesn't include configuration .json files,
 # you can set that here, or leave this empty to default to base_model
 base_model_config: ./llama-7b-hf
+# Optional tokenizer configuration override in case you want to use a different tokenizer
+# than the one defined in the base model
+tokenizer_config:
 # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
 model_type: AutoModelForCausalLM
 # Corresponding tokenizer for the model AutoTokenizer is a good choice
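For orientation, a hedged sketch of how the new option might look in a training config; the paths and repo names below are placeholders rather than values from this PR:

# weights live at a local path that lacks config/tokenizer .json files (placeholder path)
base_model: ./checkpoints/llama-7b
# pull the missing configuration .json files from here instead (placeholder repo)
base_model_config: huggyllama/llama-7b
# optional override: load the tokenizer from a different source; leave empty to fall back to base_model_config
tokenizer_config: huggyllama/llama-7b
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer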
scripts/finetune.py CHANGED
@@ -173,8 +173,9 @@ def train(
     cfg.bf16 = False
 
     # load the tokenizer first
-    logging.info("loading tokenizer...")
-    tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    logging.info(f"loading tokenizer... {tokenizer_config}")
+    tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
     if check_not_in(
         ["inference", "shard", "merge_lora"], kwargs
src/axolotl/utils/models.py CHANGED
@@ -10,9 +10,14 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import AutoModelForCausalLM  # noqa: F401
 from transformers import PreTrainedModel  # noqa: F401
-from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
+from transformers import (  # noqa: F401
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    LlamaConfig,
+)
 
 try:
     from transformers import LlamaForCausalLM
@@ -25,24 +30,23 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
     from peft import PeftConfig  # noqa: F401
-    from transformers import PreTrainedTokenizer  # noqa: F401
 
     from axolotl.utils.dict import DictDefault  # noqa: F401
 
 
 def load_tokenizer(
-    base_model_config,
+    tokenizer_config,
     tokenizer_type,
     cfg,
 ):
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-            base_model_config,
+            tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
         )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
-            base_model_config,
+            tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
         )
 
@@ -172,8 +176,10 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        config = LlamaConfig.from_pretrained(base_model_config)
         model = LlamaForCausalLM.from_pretrained(
             base_model,
+            config=config,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
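Taken together, the models.py change is what lets the config and the weights come from different locations. A hedged sketch of the same pattern in plain transformers, with placeholder paths:

import torch
from transformers import LlamaConfig, LlamaForCausalLM

base_model = "./checkpoints/llama-7b"      # local directory holding only weight shards (placeholder)
base_model_config = "huggyllama/llama-7b"  # source that provides config.json (placeholder)

# load the config from its own source, then hand it to from_pretrained so the
# weights path does not need to contain configuration .json files
config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained(
    base_model,
    config=config,
    torch_dtype=torch.float16,
)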