dev-slx commited on
Commit
d3e3703
·
verified ·
1 Parent(s): ca88460

Update elm/model.py

Browse files
Files changed (1) hide show
  1. elm/model.py +3 -0
elm/model.py CHANGED
@@ -343,6 +343,9 @@ def init_elm_model(model_args=ModelArgs(), device="cuda", model_config_dict=None
343
 
344
  dtype = torch.bfloat16 if device=="cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
345
 
 
 
 
346
  model = ELM(model_args=model_args).to(dtype=dtype)
347
 
348
  return model
 
343
 
344
  dtype = torch.bfloat16 if device=="cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
345
 
346
+ if not torch.cuda.is_available():
347
+ dtype = torch.bfloat16
348
+
349
  model = ELM(model_args=model_args).to(dtype=dtype)
350
 
351
  return model