Nanobit and winglian committed on
Commit 2d65f47
1 Parent(s): dfd1885

fix(model): apply gate fp32 only for mixtral (#1241)


* fix(model): apply gate fp32 only for mixtral

* Update src/axolotl/utils/models.py

* fix gate layer check

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

Files changed (1)
  1. src/axolotl/utils/models.py +1 -1
src/axolotl/utils/models.py CHANGED
@@ -676,7 +676,7 @@ def load_model(
     if not cfg.fsdp:
         # FSDP doesn't like mixed Float and BFloat16
         for name, module in model.named_modules():
-            if any(m in name for m in ["norm", "gate"]):
+            if "norm" in name or name.endswith(".gate"):
                 module.to(torch.float32)
             if model_config.model_type == "btlm":
                 # don't upcast lm_head for btlm
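
For context, here is a minimal sketch (not part of the commit; the sample module names are illustrative, assuming typical Transformers naming) of what the tightened check changes. The old substring match upcast any module whose name contained "gate", which also caught Llama-style gate_proj MLP projections; the new check upcasts norms plus only modules literally named ".gate", such as Mixtral's MoE router gate.

def old_check(name: str) -> bool:
    # Pre-fix: substring match catches anything containing "gate"
    return any(m in name for m in ["norm", "gate"])

def new_check(name: str) -> bool:
    # Post-fix: norms, plus modules whose name ends in ".gate"
    return "norm" in name or name.endswith(".gate")

samples = [
    "model.layers.0.input_layernorm",        # norm layer: upcast by both checks
    "model.layers.0.block_sparse_moe.gate",  # Mixtral MoE router: upcast by both
    "model.layers.0.mlp.gate_proj",          # Llama-style MLP proj: old check only
]
for name in samples:
    print(f"{name}: old={old_check(name)} new={new_check(name)}")

Running this prints old=True new=False for the gate_proj entry, which is exactly the unintended fp32 upcast the commit removes for non-Mixtral models.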