Nanobit committed on
Commit c3e8165
1 Parent(s): 7f38175

fix: torch_dtype mistral default to fp32 (#1050)

Files changed (1)
  1. src/axolotl/utils/models.py +4 -1
src/axolotl/utils/models.py CHANGED
@@ -599,7 +599,10 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
+    if needs_fa2_dtype or (
+        cfg.flash_attention
+        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
+    ):
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name: