fix(model): apply gate fp32 only for mixtral (#1241)
* fix(model): apply gate fp32 only for mixtral
* Update src/axolotl/utils/models.py
* fix gate layer check
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
src/axolotl/utils/models.py CHANGED
```diff
@@ -676,7 +676,7 @@ def load_model(
     if not cfg.fsdp:
         # FSDP doesn't like mixed Float and BFloat16
         for name, module in model.named_modules():
-            if "norm" in name or "gate" in name:
+            if "norm" in name or name.endswith(".gate"):
                 module.to(torch.float32)
     if model_config.model_type == "btlm":
         # don't upcast lm_head for btlm
```
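Why the old check was too broad: a substring test like `"gate" in name` also matches Llama-style `gate_proj` MLP projections, so large Linear layers were upcast to float32 on non-MoE models, while `name.endswith(".gate")` matches only the small MoE router gate that Mixtral exposes. A minimal sketch of the difference, assuming typical Hugging Face transformers module names (the names below are illustrative, not taken from this PR):

```python
# Representative module names (illustrative, not from the PR).
names = [
    "model.layers.0.input_layernorm",        # LayerNorm/RMSNorm weight
    "model.layers.0.mlp.gate_proj",          # Llama MLP gate projection (large Linear)
    "model.layers.0.block_sparse_moe.gate",  # Mixtral MoE router gate (small Linear)
]

for name in names:
    old_check = "norm" in name or "gate" in name            # before this PR
    new_check = "norm" in name or name.endswith(".gate")    # after this PR
    print(f"{name}: old={old_check} new={new_check}")

# Expected output:
# model.layers.0.input_layernorm: old=True new=True
# model.layers.0.mlp.gate_proj: old=True new=False
# model.layers.0.block_sparse_moe.gate: old=True new=True
```

The net effect is that norm layers are still upcast everywhere, but the gate upcast now applies only to Mixtral's router, not to every module whose name happens to contain "gate".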