winglian committed
Commit 6dfdd2d
Parent: 29936bb

don't load models in 8-bit unless they are using an adapter; also fix tokenizer load in the exceptional case
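Why gate 8-bit loading on the adapter: bitsandbytes 8-bit weights are effectively frozen, so they only make sense when a PEFT adapter (e.g. LoRA) carries the trainable parameters; a full fine-tune needs the base weights in a trainable dtype. Below is a minimal sketch of the condition the diff inlines at each `from_pretrained` call site; the helper name and the `SimpleNamespace` stand-in for axolotl's cfg are illustrative, not project API.

from types import SimpleNamespace

def should_load_in_8bit(cfg) -> bool:
    # Mirrors the inlined expression from the diff: only honor
    # load_in_8bit when an adapter (e.g. LoRA) is configured.
    return bool(cfg.load_in_8bit) and cfg.adapter is not None

# LoRA run: 8-bit base weights are fine, the adapter holds the trainable params.
assert should_load_in_8bit(SimpleNamespace(load_in_8bit=True, adapter="lora"))
# Full fine-tune: 8-bit is silently dropped even if the config asks for it.
assert not should_load_in_8bit(SimpleNamespace(load_in_8bit=True, adapter=None))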

Files changed (1)
  1. src/axolotl/utils/models.py +5 -5
src/axolotl/utils/models.py CHANGED
@@ -109,7 +109,7 @@ def load_model(
         else:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
-                load_in_8bit=cfg.load_in_8bit,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
@@ -117,14 +117,14 @@ def load_model(
     elif model_type:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
-            load_in_8bit=cfg.load_in_8bit,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
     else:
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            load_in_8bit=cfg.load_in_8bit,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
@@ -135,7 +135,7 @@ def load_model(
         logging.exception(e)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            load_in_8bit=cfg.load_in_8bit,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
@@ -147,7 +147,7 @@ def load_model(
         else:
             tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
     except:
-        tokenizer = AutoTokenizer.from_pretrained(base_model)
+        tokenizer = AutoTokenizer.from_pretrained(base_model_config)
 
     logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
     logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")