Update elm/model.py
Browse files- elm/model.py +3 -0
elm/model.py
CHANGED
@@ -343,6 +343,9 @@ def init_elm_model(model_args=ModelArgs(), device="cuda", model_config_dict=None
|
|
343 |
|
344 |
dtype = torch.bfloat16 if device=="cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
345 |
|
|
|
|
|
|
|
346 |
model = ELM(model_args=model_args).to(dtype=dtype)
|
347 |
|
348 |
return model
|
|
|
343 |
|
344 |
dtype = torch.bfloat16 if device=="cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
345 |
|
346 |
+
if not torch.cuda.is_available():
|
347 |
+
dtype = torch.bfloat16
|
348 |
+
|
349 |
model = ELM(model_args=model_args).to(dtype=dtype)
|
350 |
|
351 |
return model
|