lorocksUMD committed on
Commit
d5d1509
1 Parent(s): 3b00bc2

Update llava/model/builder.py

Files changed (1)
  1. llava/model/builder.py +4 -3
llava/model/builder.py CHANGED
@@ -31,9 +31,10 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
 
     load_8bit = True
     if load_8bit:
-        kwargs['load_in_8bit'] = True
+        # kwargs['load_in_8bit'] = True
+        kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
     elif load_4bit:
-        kwargs['load_in_4bit'] = True
+        # kwargs['load_in_4bit'] = True
         kwargs['quantization_config'] = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.float16,
@@ -41,7 +42,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
             bnb_4bit_quant_type='nf4'
         )
     else:
-        kwargs['torch_dtype'] = torch.float16
+        kwargs['torch_dtype'] = torch.float32
 
     if use_flash_attn:
         kwargs['attn_implementation'] = 'flash_attention_2'
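
For reference, a minimal sketch of how kwargs assembled this way are typically forwarded to transformers' from_pretrained. The call site and model class are outside this hunk, so AutoModelForCausalLM, build_load_kwargs, and the placeholder model path below are assumptions for illustration, not the file's actual code:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    def build_load_kwargs(load_8bit=False, load_4bit=False, use_flash_attn=False):
        # Hypothetical helper that mirrors the kwargs logic in the diff above.
        kwargs = {'device_map': 'auto'}
        if load_8bit:
            # Quantization settings are passed through a BitsAndBytesConfig object
            # rather than the bare load_in_8bit keyword.
            kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
        elif load_4bit:
            kwargs['quantization_config'] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type='nf4',
            )
        else:
            kwargs['torch_dtype'] = torch.float32
        if use_flash_attn:
            kwargs['attn_implementation'] = 'flash_attention_2'
        return kwargs

    # Example usage (replace the path with an actual checkpoint):
    # model = AutoModelForCausalLM.from_pretrained(model_path, **build_load_kwargs(load_8bit=True))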