tmm1 committed
Commit
3a011ea
1 Parent(s): 1f613e5

fix condition and add logging

Files changed (1)
  1. src/axolotl/utils/models.py +2 -1
src/axolotl/utils/models.py CHANGED

@@ -355,7 +355,7 @@ def load_model(
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
-    needs_fa2_dtype = not cfg.adapter
+    needs_fa2_dtype = cfg.adapter is not None
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -369,6 +369,7 @@ def load_model(
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
     if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
+        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
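
The behavioral change in the first hunk comes down to Python truthiness versus an explicit None check: `not cfg.adapter` is True for an unset adapter and for an empty string, while `cfg.adapter is not None` is True exactly when some adapter value is set. A minimal sketch (plain Python, not axolotl code; the adapter values are illustrative):

    # Illustrative only: compare the removed and added expressions for a few
    # hypothetical cfg.adapter values.
    for adapter in (None, "", "lora", "qlora"):
        old = not adapter              # removed: needs_fa2_dtype = not cfg.adapter
        new = adapter is not None      # added:   needs_fa2_dtype = cfg.adapter is not None
        print(f"adapter={adapter!r}: old={old}, new={new}")

    # adapter=None: old=True, new=False
    # adapter='': old=True, new=True
    # adapter='lora': old=False, new=True
    # adapter='qlora': old=False, new=True

Under the new expression, configuring any adapter marks the model as needing the fp16/bf16 norm conversion in the second hunk from the start (that conversion is now accompanied by a LOG.info line); whether code between the two hunks also toggles needs_fa2_dtype is not visible in this diff.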