Fix: bf16 support for inference (#981)
* Fix: bf16 torch dtype
* simplify casting to device and dtype
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
src/axolotl/cli/__init__.py
CHANGED
@@ -103,7 +103,7 @@ def do_inference(
|
|
103 |
importlib.import_module("axolotl.prompters"), prompter
|
104 |
)
|
105 |
|
106 |
-
model = model.to(cfg.device)
|
107 |
|
108 |
while True:
|
109 |
print("=" * 80)
|
@@ -168,7 +168,7 @@ def do_inference_gradio(
|
|
168 |
importlib.import_module("axolotl.prompters"), prompter
|
169 |
)
|
170 |
|
171 |
-
model = model.to(cfg.device)
|
172 |
|
173 |
def generate(instruction):
|
174 |
if not instruction:
|
|
|
103 |
importlib.import_module("axolotl.prompters"), prompter
|
104 |
)
|
105 |
|
106 |
+
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
107 |
|
108 |
while True:
|
109 |
print("=" * 80)
|
|
|
168 |
importlib.import_module("axolotl.prompters"), prompter
|
169 |
)
|
170 |
|
171 |
+
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
172 |
|
173 |
def generate(instruction):
|
174 |
if not instruction:
|