Tazik Shahjahan winglian committed on
Commit
3678a6c
1 Parent(s): f8ae59b

Fix: bf16 support for inference (#981)

Browse files

* Fix: bf16 torch dtype

* simplify casting to device and dtype

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

Files changed (1) hide show
  1. src/axolotl/cli/__init__.py +2 -2
src/axolotl/cli/__init__.py CHANGED
@@ -103,7 +103,7 @@ def do_inference(
103
  importlib.import_module("axolotl.prompters"), prompter
104
  )
105
 
106
- model = model.to(cfg.device)
107
 
108
  while True:
109
  print("=" * 80)
@@ -168,7 +168,7 @@ def do_inference_gradio(
168
  importlib.import_module("axolotl.prompters"), prompter
169
  )
170
 
171
- model = model.to(cfg.device)
172
 
173
  def generate(instruction):
174
  if not instruction:
 
103
  importlib.import_module("axolotl.prompters"), prompter
104
  )
105
 
106
+ model = model.to(cfg.device, dtype=cfg.torch_dtype)
107
 
108
  while True:
109
  print("=" * 80)
 
168
  importlib.import_module("axolotl.prompters"), prompter
169
  )
170
 
171
+ model = model.to(cfg.device, dtype=cfg.torch_dtype)
172
 
173
  def generate(instruction):
174
  if not instruction: