winglian committed
Commit 2520ecd · 1 Parent(s): c5b0af1

split up llama model loading so the config can be loaded from the base config and the model weights can be loaded from a separate path

Files changed (1):
  src/axolotl/utils/models.py  +3 -1
src/axolotl/utils/models.py CHANGED

@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import AutoModelForCausalLM  # noqa: F401
+from transformers import AutoModelForCausalLM, LlamaConfig  # noqa: F401
 from transformers import PreTrainedModel  # noqa: F401
 from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig

@@ -172,8 +172,10 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        config = LlamaConfig.from_pretrained(base_model_config)
         model = LlamaForCausalLM.from_pretrained(
             base_model,
+            config=config,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
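
For context, the change lets the architecture config come from `base_model_config` while the weights come from `base_model`, so the two no longer have to be the same path. A minimal sketch of the pattern, assuming Hugging Face `transformers` is installed; the repo paths below are placeholders, not values from the commit:

```python
# Sketch of the config/weights split introduced by this commit.
# Any Llama-architecture checkpoint and config location would work;
# the point is that they can be two different paths.
from transformers import LlamaConfig, LlamaForCausalLM

base_model = "huggyllama/llama-7b"          # placeholder: where the weights live
base_model_config = "huggyllama/llama-7b"   # placeholder: where the config lives

# Load the architecture config from one location...
config = LlamaConfig.from_pretrained(base_model_config)

# ...then load the weights from another, reusing that config
# instead of re-deriving it from the weights path.
model = LlamaForCausalLM.from_pretrained(
    base_model,
    config=config,
)
```

This is useful when the weights sit in a local directory or a fine-tuned fork that ships no (or a stale) `config.json`: the config is resolved once from the known-good base repo, and `from_pretrained` only has to fetch the weights.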