fix torch_dtype for model load
src/axolotl/utils/models.py
@@ -62,9 +62,12 @@ def load_model(
         logging.info("patching with xformers attention")
         hijack_llama_attention()
 
-
-
-
+    if cfg.bf16:
+        torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16:
+        torch_dtype = torch.float16
+    else:
+        torch_dtype = torch.float32
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
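For context, here is a minimal sketch of how a dtype selection like this is typically wired into a Hugging Face model load. The SimpleNamespace stand-in for cfg, the placeholder model name, and the from_pretrained call are illustrative assumptions, not the exact axolotl code; only the bf16 / load_in_8bit-or-fp16 / fp32 precedence mirrors the diff above.

# Sketch only: shows how the selected torch_dtype would feed a model load.
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM

# Hypothetical config object; axolotl's real cfg carries these same flags.
cfg = SimpleNamespace(bf16=False, fp16=True, load_in_8bit=False)

# Same precedence as the patch: bf16 wins, then 8-bit/fp16, else fp32.
if cfg.bf16:
    torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16:
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-llama-checkpoint",  # placeholder model name
    torch_dtype=torch_dtype,
    load_in_8bit=cfg.load_in_8bit,
)

The practical effect of the patch is that torch_dtype is always defined before the try block that loads the model, instead of only being set on some config paths.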