lorocksUMD committed
Commit 7124b11
1 Parent(s): 599f85f

Update app.py

Files changed (1)
  1. app.py +129 -1
app.py CHANGED
@@ -3,13 +3,50 @@ from huggingface_hub import InferenceClient
  from transformers import AutoTokenizer
  from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM
  from llava.model.builder import load_pretrained_model
- from llava.mm_utils import get_model_name_from_path
+ from llava.mm_utils import (
+     process_images,
+     tokenizer_image_token,
+     get_model_name_from_path,
+ )
+ from llava.constants import (
+     IMAGE_TOKEN_INDEX,
+     DEFAULT_IMAGE_TOKEN,
+     DEFAULT_IM_START_TOKEN,
+     DEFAULT_IM_END_TOKEN,
+     IMAGE_PLACEHOLDER,
+ )
+ from llava.conversation import conv_templates, SeparatorStyle
+
+
+

  """
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
  """
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

+ # Functions for inference
+ def image_parser(args):
+     out = args.image_file.split(args.sep)
+     return out
+
+
+ def load_image(image_file):
+     if image_file.startswith("http") or image_file.startswith("https"):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert("RGB")
+     else:
+         image = Image.open(image_file).convert("RGB")
+     return image
+
+
+ def load_images(image_files):
+     out = []
+     for image_file in image_files:
+         image = load_image(image_file)
+         out.append(image)
+     return out
+
  model_path = "liuhaotian/llava-v1.6-mistral-7b"
  model_name = get_model_name_from_path(model_path)
  # tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -19,10 +56,101 @@ model_name = get_model_name_from_path(model_path)
  # # offload_folder="/content/sample_data"
  # )

+ prompt = "What are the things I should be cautious about when I visit here?"
+ image_file = "https://llava-vl.github.io/static/images/view.jpg"
+
+ args = type('Args', (), {
+     "model_path": model_path,
+     "model_base": None,
+     "model_name": get_model_name_from_path(model_path),
+     "query": prompt,
+     "conv_mode": None,
+     "image_file": image_file,
+     "sep": ",",
+     "temperature": 0,
+     "top_p": None,
+     "num_beams": 1,
+     "max_new_tokens": 512
+ })()
+
+
  tokenizer, model, image_processor, context_len = load_pretrained_model(
      model_path, None, model_name
  )

+ qs = args.query
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+ if IMAGE_PLACEHOLDER in qs:
+     if model.config.mm_use_im_start_end:
+         qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
+     else:
+         qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
+ else:
+     if model.config.mm_use_im_start_end:
+         qs = image_token_se + "\n" + qs
+     else:
+         qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+ if "llama-2" in model_name.lower():
+     conv_mode = "llava_llama_2"
+ elif "mistral" in model_name.lower():
+     conv_mode = "mistral_instruct"
+ elif "v1.6-34b" in model_name.lower():
+     conv_mode = "chatml_direct"
+ elif "v1" in model_name.lower():
+     conv_mode = "llava_v1"
+ elif "mpt" in model_name.lower():
+     conv_mode = "mpt"
+ else:
+     conv_mode = "llava_v0"
+
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
+     print(
+         "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
+             conv_mode, args.conv_mode, args.conv_mode
+         )
+     )
+ else:
+     args.conv_mode = conv_mode
+
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ image_files = image_parser(args)
+ images = load_images(image_files)
+ image_sizes = [x.size for x in images]
+ images_tensor = process_images(
+     images,
+     image_processor,
+     model.config
+ ).to(model.device, dtype=torch.float16)
+
+ input_ids = (
+     tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+     .unsqueeze(0)
+     .cuda()
+ )
+
+ with torch.inference_mode():
+     output_ids = model.generate(
+         input_ids,
+         images=images_tensor,
+         image_sizes=image_sizes,
+         do_sample=True if args.temperature > 0 else False,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         num_beams=args.num_beams,
+         max_new_tokens=args.max_new_tokens,
+         use_cache=True,
+     )
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+ print(outputs)
+ # End Llava inference
+
+

  def respond(
      message,
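
Note: the inference code added in this commit calls requests, Image, BytesIO, re, and torch, none of which are imported in the hunks shown. If app.py does not already import them earlier in the file (outside the visible context), a minimal sketch of the imports the added code assumes would be:

# Assumed imports for the added LLaVA inference code; not part of this commit's diff.
import re

import requests
import torch
from io import BytesIO
from PIL import Image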