tmm1 committed
Commit f319b0b
1 Parent(s): 7fd662d

rename var and reformat

Files changed (1)
  1. src/axolotl/utils/models.py +3 -5
src/axolotl/utils/models.py CHANGED
@@ -355,7 +355,7 @@ def load_model(
                 if hasattr(module, "weight"):
                     module.to(torch.float32)
 
-    fix_dtype = not cfg.adapter
+    needs_fa2_dtype = not cfg.adapter
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -364,13 +364,11 @@ def load_model(
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
-        fix_dtype = True
+        needs_fa2_dtype = True
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if fix_dtype and (
-        cfg.flash_attention and cfg.is_llama_derived_model
-    ):
+    if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
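
Note: the renamed needs_fa2_dtype flag guards the norm-layer dtype fixup described in the comment above. A minimal standalone sketch of that pattern follows; the helper name, the example checkpoint, and the plain torch_dtype argument are illustrative assumptions, not part of axolotl's API.

import torch
from transformers import AutoModelForCausalLM

def cast_norms_for_flash_attn(model, torch_dtype=torch.bfloat16):
    # After prepare_model_for_kbit_training (or a full finetune), RMSNorm
    # layers can be left in fp32; flash-attn kernels expect fp16/bf16
    # activations, so cast the norm modules back to the training dtype.
    # (Illustrative helper; not axolotl's implementation.)
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch_dtype)
    return model

# Example usage (checkpoint name is only an example):
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)
model = cast_norms_for_flash_attn(model, torch.bfloat16)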