Maxime winglian committed on
Commit
1991946
1 Parent(s): f51c9c5

fix: bad dtype for full finetune (#504)


* fix: bad dtype for full finetune

* Update src/axolotl/utils/models.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update models.py

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

Files changed (1)
  1. src/axolotl/utils/models.py +1 -1
src/axolotl/utils/models.py CHANGED
@@ -371,7 +371,7 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
+    if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name: