lorocksUMD commited on
Commit
19ba12a
1 Parent(s): f295437

Update llava/model/builder.py

Browse files
Files changed (1) hide show
  1. llava/model/builder.py +2 -2
llava/model/builder.py CHANGED
@@ -26,8 +26,8 @@ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, D
26
  def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27
  kwargs = {"device_map": device_map, **kwargs}
28
 
29
- if device != "cuda":
30
- kwargs['device_map'] = {"": device}
31
 
32
  # if load_8bit:
33
  # kwargs['load_in_8bit'] = True
 
26
  def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27
  kwargs = {"device_map": device_map, **kwargs}
28
 
29
+ # if device != "cuda":
30
+ # kwargs['device_map'] = {"": device}
31
 
32
  # if load_8bit:
33
  # kwargs['load_in_8bit'] = True