winglian committed on
Commit b9d07aa
1 Parent(s): 3b4d055

prepare does all this already for qlora?

Files changed (1)
  1. src/axolotl/utils/models.py +12 -12
src/axolotl/utils/models.py CHANGED
@@ -204,17 +204,17 @@ def load_model(
     """### Post-processing on the model
 
     Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
     """
-    if cfg.adapter == "qlora":
-        for param in model.parameters():
-            param.requires_grad = False  # freeze the model - train adapters later
-            if param.ndim == 1:
-                # cast the small parameters (e.g. layernorm) to fp32 for stability
-                param.data = param.data.to(torch.float32)
-        class CastOutputToFloat(nn.Sequential):
-            def forward(self, x):
-                return super().forward(x).to(torch.float32)
-
-        model.lm_head = CastOutputToFloat(model.lm_head)
+    # if cfg.adapter == "qlora":
+    #     for param in model.parameters():
+    #         param.requires_grad = False  # freeze the model - train adapters later
+    #         if param.ndim == 1:
+    #             # cast the small parameters (e.g. layernorm) to fp32 for stability
+    #             param.data = param.data.to(torch.float32)
+    #     class CastOutputToFloat(nn.Linear):
+    #         def forward(self, x):
+    #             return super().forward(x).to(torch.float32)
+    #
+    #     model.lm_head = CastOutputToFloat(model.lm_head.in_features, model.lm_head.out_features, model.lm_head.bias)
 
     if not tokenizer:
         try:
@@ -255,7 +255,7 @@ def load_model(
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
-    if cfg.adapter and load_in_8bit and not cfg.load_4bit:
+    if ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
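
For the question in the commit message: `peft.prepare_model_for_int8_training`, which the second hunk now also applies when `cfg.adapter == "qlora"`, already performs the freezing and dtype casting that the commented-out branch did by hand. The sketch below is a hedged approximation of what that helper covers, not a verbatim copy of the peft source; the function name `prepare_for_int8_sketch` and its `output_embedding_layer_name` parameter are illustrative stand-ins.

import torch
import torch.nn as nn


def prepare_for_int8_sketch(model, output_embedding_layer_name="lm_head"):
    """Rough approximation of peft.prepare_model_for_int8_training.

    Covers the same steps as the removed qlora branch: freeze every
    parameter, keep 1-D params (e.g. layer norms) in fp32 for stability,
    and cast the output head's logits back to fp32.
    """
    for param in model.parameters():
        param.requires_grad = False  # freeze the base model; adapters train later
        if param.ndim == 1:
            # small parameters (e.g. layernorm) stay in fp32 for stability
            param.data = param.data.to(torch.float32)

    if hasattr(model, output_embedding_layer_name):
        output_layer = getattr(model, output_embedding_layer_name)

        class CastOutputToFloat(nn.Sequential):
            def forward(self, x):
                # return logits in fp32 so the loss is computed in full precision
                return super().forward(x).to(torch.float32)

        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_layer))

    return model

Since the manual branch duplicates this work, routing qlora through `prepare_model_for_int8_training` and dropping the hand-written block is the simplification this commit is proposing.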