lorocksUMD commited on
Commit
2533bec
1 Parent(s): e3e9e4b

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +12 -12
script.py CHANGED
@@ -25,6 +25,16 @@ from io import BytesIO
25
  import re
26
 
27
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  """
30
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
@@ -53,17 +63,8 @@ def load_images(image_files):
53
  out.append(image)
54
  return out
55
 
56
- model_path = "liuhaotian/llava-v1.6-mistral-7b"
57
- model_name = get_model_name_from_path(model_path)
58
- # tokenizer = AutoTokenizer.from_pretrained(model_path)
59
- # model = LlavaMistralForCausalLM.from_pretrained(
60
- # model_path,
61
- # low_cpu_mem_usage=True,
62
- # # offload_folder="/content/sample_data"
63
- # )
64
 
65
- prompt = "What are the things I should be cautious about when I visit here?"
66
- image_file = "Great-Room-4.jpg"
67
 
68
  args = type('Args', (), {
69
  "model_path": model_path,
@@ -79,9 +80,8 @@ args = type('Args', (), {
79
  "max_new_tokens": 512
80
  })()
81
 
82
-
83
  tokenizer, model, image_processor, context_len = load_pretrained_model(
84
- model_path, None, model_name, device_map="cpu"
85
  )
86
 
87
  qs = args.query
 
25
  import re
26
 
27
 
28
+ # Line 138 uncomment the cuda() to use GPUs
29
+
30
+ device = "cpu"
31
+ # device = "auto"
32
+
33
+ prompt = "What are the things I should be cautious about when I visit here?"
34
+ image_file = "Great-Room-4.jpg"
35
+
36
+ model_path = "liuhaotian/llava-v1.6-mistral-7b"
37
+
38
 
39
  """
40
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 
63
  out.append(image)
64
  return out
65
 
 
 
 
 
 
 
 
 
66
 
67
+ model_name = get_model_name_from_path(model_path)
 
68
 
69
  args = type('Args', (), {
70
  "model_path": model_path,
 
80
  "max_new_tokens": 512
81
  })()
82
 
 
83
  tokenizer, model, image_processor, context_len = load_pretrained_model(
84
+ model_path, None, model_name, device_map=device
85
  )
86
 
87
  qs = args.query