Senqiao commited on
Commit
b0b7b8a
·
1 Parent(s): bce7a77

update the code to official test

Browse files
Files changed (2) hide show
  1. app.py +822 -3
  2. app_dev_debug.py +1 -1
app.py CHANGED
@@ -1,4 +1,823 @@
 
 
 
 
1
  import time
2
- print('hello world')# Make the app stop and avoid the port conflict
3
- while True:
4
- time.sleep(10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
  import time
6
+
7
+ import gradio as gr
8
+ import requests
9
+
10
+ from llava.conversation import (default_conversation, conv_templates,
11
+ SeparatorStyle)
12
+ from llava.constants import LOGDIR
13
+ from llava.utils import (build_logger, server_error_msg,
14
+ violates_moderation, moderation_msg)
15
+ import hashlib
16
+ import subprocess
17
+ import sys
18
+ import time
19
+
20
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
21
+
22
+ headers = {"User-Agent": "LLaVA Client"}
23
+
24
+ no_change_btn = gr.Button()
25
+ enable_btn = gr.Button(interactive=True)
26
+ disable_btn = gr.Button(interactive=False)
27
+
28
+ priority = {
29
+ "vicuna-13b": "aaaaaaa",
30
+ "koala-13b": "aaaaaab",
31
+ }
32
+
33
+
34
+ def get_conv_log_filename():
35
+ t = datetime.datetime.now()
36
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
37
+ return name
38
+
39
+
40
+ def get_model_list():
41
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
42
+ assert ret.status_code == 200
43
+ ret = requests.post(args.controller_url + "/list_models")
44
+ models = ret.json()["models"]
45
+ models.sort(key=lambda x: priority.get(x, x))
46
+ logger.info(f"Models: {models}")
47
+ return models
48
+
49
+
50
+ get_window_url_params = """
51
+ function() {
52
+ const params = new URLSearchParams(window.location.search);
53
+ url_params = Object.fromEntries(params);
54
+ console.log(url_params);
55
+ return url_params;
56
+ }
57
+ """
58
+
59
+
60
+ def load_demo(url_params, request: gr.Request):
61
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
62
+
63
+ dropdown_update = gr.Dropdown(visible=True)
64
+ if "model" in url_params:
65
+ model = url_params["model"]
66
+ if model in models:
67
+ dropdown_update = gr.Dropdown(value=model, visible=True)
68
+
69
+ state = default_conversation.copy()
70
+ return state, dropdown_update
71
+
72
+
73
+ def load_demo_refresh_model_list(request: gr.Request):
74
+ logger.info(f"load_demo. ip: {request.client.host}")
75
+ models = get_model_list()
76
+ state = default_conversation.copy()
77
+ dropdown_update = gr.Dropdown(
78
+ choices=models,
79
+ value=models[0] if len(models) > 0 else ""
80
+ )
81
+ return state, dropdown_update
82
+
83
+
84
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
85
+ with open(get_conv_log_filename(), "a") as fout:
86
+ data = {
87
+ "tstamp": round(time.time(), 4),
88
+ "type": vote_type,
89
+ "model": model_selector,
90
+ "state": state.dict(),
91
+ "ip": request.client.host,
92
+ }
93
+ fout.write(json.dumps(data) + "\n")
94
+
95
+
96
+ def upvote_last_response(state, model_selector, request: gr.Request):
97
+ logger.info(f"upvote. ip: {request.client.host}")
98
+ vote_last_response(state, "upvote", model_selector, request)
99
+ return ("",) + (disable_btn,) * 3
100
+
101
+
102
+ def downvote_last_response(state, model_selector, request: gr.Request):
103
+ logger.info(f"downvote. ip: {request.client.host}")
104
+ vote_last_response(state, "downvote", model_selector, request)
105
+ return ("",) + (disable_btn,) * 3
106
+
107
+
108
+ def flag_last_response(state, model_selector, request: gr.Request):
109
+ logger.info(f"flag. ip: {request.client.host}")
110
+ vote_last_response(state, "flag", model_selector, request)
111
+ return ("",) + (disable_btn,) * 3
112
+
113
+
114
+ def regenerate(state, masked_image, image_process_mode, request: gr.Request):
115
+ logger.info(f"regenerate. ip: {request.client.host}")
116
+ state.messages[-1][-1] = None
117
+ prev_human_msg = state.messages[-2]
118
+ if type(prev_human_msg[1]) in (tuple, list):
119
+ prev_human_msg[1] = (*prev_human_msg[1][:3], image_process_mode)
120
+ state.skip_next = False
121
+
122
+ state.messages[-2] = [
123
+ state.messages[-2][0],
124
+ (state.messages[-2][1][0],masked_image, state.messages[-2][1][2], state.messages[-2][1][3]) # Create a new tuple with the updated image
125
+ ]
126
+
127
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
128
+
129
+
130
+ def clear_history(request: gr.Request):
131
+ logger.info(f"clear_history. ip: {request.client.host}")
132
+ state = default_conversation.copy()
133
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
134
+
135
+
136
+
137
+ def add_text_wCLS(state, text, masked_image, image_process_mode, imagebox, request: gr.Request):
138
+ logger.info(f"add_text_withcls. ip: {request.client.host}. len: {len(text)}")
139
+
140
+ if len(text) <= 0 and masked_image is None and imagebox is None:
141
+ state.skip_next = True
142
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
143
+ if args.moderate:
144
+ flagged = violates_moderation(text)
145
+ if flagged:
146
+ state.skip_next = True
147
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
148
+ no_change_btn,) * 5
149
+
150
+ text = text[:1536]
151
+ if imagebox is not None:
152
+ text = text[:1200]
153
+ if '<image>' not in text:
154
+ text = text + '\n<image>'
155
+ text = (text, masked_image, imagebox, image_process_mode)
156
+ state = default_conversation.copy()
157
+ state.append_message(state.roles[0], text)
158
+ state.append_message(state.roles[1], None)
159
+ state.skip_next = False
160
+ state.cls=True
161
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
162
+
163
+
164
+ def add_text(state, text, masked_image, image_process_mode, imagebox, request: gr.Request):
165
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
166
+
167
+ if len(text) <= 0 and masked_image is None and imagebox is None:
168
+ state.skip_next = True
169
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
170
+ if args.moderate:
171
+ flagged = violates_moderation(text)
172
+ if flagged:
173
+ state.skip_next = True
174
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
175
+ no_change_btn,) * 5
176
+
177
+ text = text[:1536]
178
+ if imagebox is not None:
179
+ text = text[:1200]
180
+ if '<image>' not in text:
181
+ text = text + '\n<image>'
182
+ text = (text, masked_image, imagebox, image_process_mode)
183
+ state = default_conversation.copy()
184
+ state.append_message(state.roles[0], text)
185
+ state.append_message(state.roles[1], None)
186
+ state.skip_next = False
187
+ state.cls=False
188
+ return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
189
+
190
+
191
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, raw_tokens, request: gr.Request):
192
+ cls_flag = state.cls
193
+ print(f">>>>>>>>CLS_FLAG_{cls_flag}")
194
+ select_tokens = raw_tokens.strip('[]')
195
+ select_tokens = list(map(int, select_tokens.split()))
196
+ logger.info(f"http_bot. ip: {request.client.host}")
197
+ start_tstamp = time.time()
198
+ model_name = model_selector
199
+
200
+ if state.skip_next:
201
+ # This generate call is skipped due to invalid inputs
202
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
203
+ return
204
+
205
+ if len(state.messages) == state.offset + 2:
206
+ # First round of conversation
207
+ if "llava" in model_name.lower():
208
+ if 'llama-2' in model_name.lower():
209
+ template_name = "llava_llama_2"
210
+ elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
211
+ if 'orca' in model_name.lower():
212
+ template_name = "mistral_orca"
213
+ elif 'hermes' in model_name.lower():
214
+ template_name = "chatml_direct"
215
+ else:
216
+ template_name = "mistral_instruct"
217
+ elif 'llava-v1.6-34b' in model_name.lower():
218
+ template_name = "chatml_direct"
219
+ elif "v1" in model_name.lower():
220
+ if 'mmtag' in model_name.lower():
221
+ template_name = "v1_mmtag"
222
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
223
+ template_name = "v1_mmtag"
224
+ else:
225
+ template_name = "llava_v1"
226
+ elif "mpt" in model_name.lower():
227
+ template_name = "mpt"
228
+ else:
229
+ if 'mmtag' in model_name.lower():
230
+ template_name = "v0_mmtag"
231
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
232
+ template_name = "v0_mmtag"
233
+ else:
234
+ template_name = "llava_v0"
235
+ elif "mpt" in model_name:
236
+ template_name = "mpt_text"
237
+ elif "llama-2" in model_name:
238
+ template_name = "llama_2"
239
+ else:
240
+ template_name = "vicuna_v1"
241
+ new_state = conv_templates[template_name].copy()
242
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
243
+ new_state.append_message(new_state.roles[1], None)
244
+ state = new_state
245
+
246
+ # Query worker address
247
+ controller_url = args.controller_url
248
+ ret = requests.post(controller_url + "/get_worker_address",
249
+ json={"model": model_name})
250
+ worker_addr = ret.json()["address"]
251
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
252
+
253
+ # No available worker
254
+ if worker_addr == "":
255
+ state.messages[-1][-1] = server_error_msg
256
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
257
+ return
258
+
259
+ # Construct prompt
260
+ prompt = state.get_prompt()
261
+
262
+ all_images = state.get_images(return_pil=True)
263
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
264
+ for image, hash in zip(all_images, all_image_hash):
265
+ t = datetime.datetime.now()
266
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
267
+ if not os.path.isfile(filename):
268
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
269
+ image.save(filename)
270
+
271
+ # Make requests
272
+ pload = {
273
+ "model": model_name,
274
+ "prompt": prompt,
275
+ "temperature": float(temperature),
276
+ "top_p": float(top_p),
277
+ "max_new_tokens": min(int(max_new_tokens), 1536),
278
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
279
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
280
+ "select_tokens":select_tokens,
281
+ "cls_flag":cls_flag,
282
+ }
283
+ logger.info(f"==== request ====\n{pload}")
284
+ state.cls=cls_flag
285
+ pload['images'] = state.get_images()
286
+
287
+ state.messages[-1][-1] = "▌"
288
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
289
+
290
+ try:
291
+ # Stream output
292
+ response = requests.post(worker_addr + "/worker_generate_stream",
293
+ headers=headers, json=pload, stream=True, timeout=20)
294
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
295
+ if chunk:
296
+ data = json.loads(chunk.decode())
297
+ if data["error_code"] == 0:
298
+ output = data["text"][len(prompt):].strip()
299
+ state.messages[-1][-1] = output + "▌"
300
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
301
+ else:
302
+ output = data["text"] + f" (error_code: {data['error_code']})"
303
+ state.messages[-1][-1] = output
304
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
305
+ return
306
+ time.sleep(0.03)
307
+ except requests.exceptions.RequestException as e:
308
+ state.messages[-1][-1] = server_error_msg
309
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
310
+ return
311
+
312
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
313
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
314
+
315
+ finish_tstamp = time.time()
316
+ logger.info(f"{output}")
317
+
318
+ with open(get_conv_log_filename(), "a") as fout:
319
+ data = {
320
+ "tstamp": round(finish_tstamp, 4),
321
+ "type": "chat",
322
+ "model": model_name,
323
+ "start": round(start_tstamp, 4),
324
+ "finish": round(finish_tstamp, 4),
325
+ "state": state.dict(),
326
+ "images": all_image_hash,
327
+ "ip": request.client.host,
328
+ }
329
+ fout.write(json.dumps(data) + "\n")
330
+
331
+ title_markdown = ("""
332
+ # VisionZip: Longer is Better but Not Necessary in Vision Language Models
333
+ [[Code](https://github.com/dvlab-research/VisionZip)] [[Demo-Visualizer](http://202.104.135.156:11030)] [[Usage-Video](https://youtu.be/9GNIJy4U6-k?si=jcWIJ2O0IjB4aamm)] [[Intro-Video](https://youtu.be/sytaAzmxxpo?si=IieArmQ7YNf2dVyM)]
334
+
335
+ This demo allows users to manually select which visual tokens to send to the LLM to observe how different visual tokens impact the final response.
336
+
337
+ ### Instructions:
338
+ 1. Upload an image.
339
+ 2. Select the visual tokens.
340
+ 3. Generate the answer.
341
+
342
+ For a step-by-step guide, refer to the [Usage Video](https://youtu.be/9GNIJy4U6-k?si=jcWIJ2O0IjB4aamm).
343
+ """)
344
+
345
+ tos_markdown = ("""
346
+ ### Terms of use
347
+ By using this service, users are required to agree to the following terms:
348
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
349
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
350
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
351
+ """)
352
+
353
+
354
+ learn_more_markdown = ("""
355
+ ### License
356
+ The service is a research preview intended for non-commercial use only, subject to the [License](https://github.com/dvlab-research/VisionZip/blob/main/LICENSE) of VisionZip, model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
357
+ """)
358
+
359
+ block_css = """
360
+
361
+ #buttons button {
362
+ min-width: min(120px,100%);
363
+ }
364
+
365
+ """
366
+ import gradio as gr
367
+ import numpy as np
368
+ # Function to capture coordinates of the drawing on the image
369
+ import numpy as np
370
+ from PIL import Image, ImageDraw
371
+
372
+
373
+ def create_mask(image, grid_vet):
374
+ if image is None:
375
+ return None
376
+ # Resize the image to 336x336
377
+ image = image.resize((336, 336))
378
+
379
+ # Create a transparent overlay
380
+ overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
381
+ draw = ImageDraw.Draw(overlay)
382
+
383
+ grid_size = 14
384
+ grid_count = 24
385
+
386
+ for i in range(grid_count):
387
+ for j in range(grid_count):
388
+ # Calculate the bounding box of each grid cell
389
+ left = j * grid_size
390
+ top = i * grid_size
391
+ right = left + grid_size
392
+ bottom = top + grid_size
393
+
394
+ # If the value in grid_vet is 0, draw a white mask with 70% transparency
395
+ if grid_vet[i][j] == 0:
396
+ draw.rectangle([left, top, right, bottom], fill=(255, 255, 255, 178)) # 70% transparency
397
+
398
+ # Composite the image with the overlay
399
+ final_image = Image.alpha_composite(image.convert('RGBA'), overlay)
400
+
401
+ # Convert back to RGB if needed (remove alpha channel)
402
+ return final_image.convert('RGB')
403
+
404
+ def capture_coordinates(image, drawing):
405
+ outputs = drawing['layers'][0][:, :, -1] # Alpha channel (transparency)
406
+
407
+ non_zero_pixels = np.argwhere(outputs > 0) # Non-transparent pixels
408
+
409
+ grid_size = 14
410
+ grid_count = 24
411
+
412
+ grid_vector = np.zeros((grid_count, grid_count), dtype=int)
413
+
414
+ for y, x in non_zero_pixels:
415
+ grid_x = x // grid_size
416
+ grid_y = y // grid_size
417
+ grid_vector[grid_y, grid_x] = 1
418
+
419
+ grid_vector_flat = grid_vector.flatten()
420
+ index = np.where(grid_vector_flat==1)[0]
421
+ final_image = create_mask(image,grid_vector)
422
+
423
+
424
+ return str(index),final_image
425
+
426
+ def calculate_dominant_tokens_192(image, model_selector,state):
427
+ token_num=192
428
+ model_name = model_selector
429
+
430
+ controller_url = args.controller_url
431
+
432
+ ret = requests.post(controller_url + "/get_worker_address",
433
+ json={"model": model_name})
434
+ worker_addr = ret.json()["address"]
435
+
436
+ pload = {
437
+ "images": [state.process_image(image, "Default")],
438
+ "token_num":token_num,
439
+ }
440
+
441
+ response = requests.post(worker_addr + "/worker_get_visonzip",json=pload, timeout=20)
442
+
443
+ select_idx = response.json()['token_idx'][0]
444
+ grid_count=24
445
+ grid_vector = np.zeros((grid_count, grid_count), dtype=int)
446
+ for idx in select_idx:
447
+ row = idx // grid_count
448
+ col = idx % grid_count
449
+ grid_vector[row, col] = 1
450
+
451
+ final_image = create_mask(image,grid_vector)
452
+ select_idx = np.array(select_idx)
453
+
454
+ return str(select_idx), final_image
455
+
456
+ def calculate_dominant_tokens_128(image, model_selector,state):
457
+ ## Call the Model to get the visionzip
458
+ ## use the index to get the grid vector
459
+ token_num=128
460
+ model_name = model_selector
461
+
462
+ controller_url = args.controller_url
463
+
464
+ ret = requests.post(controller_url + "/get_worker_address",
465
+ json={"model": model_name})
466
+ worker_addr = ret.json()["address"]
467
+
468
+ pload = {
469
+ "images": [state.process_image(image, "Default")],
470
+ "token_num":token_num,
471
+ }
472
+
473
+ response = requests.post(worker_addr + "/worker_get_visonzip",json=pload, timeout=20)
474
+
475
+ select_idx = response.json()['token_idx'][0]
476
+ grid_count=24
477
+ grid_vector = np.zeros((grid_count, grid_count), dtype=int)
478
+ for idx in select_idx:
479
+ row = idx // grid_count
480
+ col = idx % grid_count
481
+ grid_vector[row, col] = 1
482
+
483
+ final_image = create_mask(image,grid_vector)
484
+ select_idx = np.array(select_idx)
485
+
486
+ return str(select_idx), final_image
487
+
488
+ def calculate_dominant_tokens_64(image, model_selector,state):
489
+ ## Call the Model to get the visionzip
490
+ ## use the index to get the grid vector
491
+ token_num=64
492
+ model_name = model_selector
493
+
494
+ controller_url = args.controller_url
495
+
496
+ ret = requests.post(controller_url + "/get_worker_address",
497
+ json={"model": model_name})
498
+ worker_addr = ret.json()["address"]
499
+
500
+ pload = {
501
+ "images": [state.process_image(image, "Default")],
502
+ "token_num":token_num,
503
+ }
504
+
505
+ response = requests.post(worker_addr + "/worker_get_visonzip",json=pload, timeout=20)
506
+
507
+ select_idx = response.json()['token_idx'][0]
508
+ grid_count=24
509
+ grid_vector = np.zeros((grid_count, grid_count), dtype=int)
510
+ for idx in select_idx:
511
+ row = idx // grid_count
512
+ col = idx % grid_count
513
+ grid_vector[row, col] = 1
514
+
515
+ final_image = create_mask(image,grid_vector)
516
+ select_idx = np.array(select_idx)
517
+
518
+ return str(select_idx), final_image
519
+
520
+ from PIL import Image
521
+
522
+ # Function to resize the image to 336x336 and return it
523
+ def resize_image(image):
524
+ if image is None:
525
+ return None
526
+ return image.resize((336, 336))
527
+
528
+ def default_img(image):
529
+ grid_count = 24
530
+ grid_vector = np.zeros((grid_count, grid_count), dtype=int)
531
+ default_image = create_mask(image,grid_vector)
532
+ return default_image
533
+
534
+ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
535
+ models = get_model_list()
536
+
537
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER (No CLS)", container=False)
538
+
539
+ with gr.Blocks(title="VisionZip", theme=gr.themes.Default(), css=block_css) as demo:
540
+ state = gr.State()
541
+
542
+ if not embed_mode:
543
+ gr.Markdown(title_markdown)
544
+
545
+ with gr.Row():
546
+ with gr.Column(scale=3):
547
+ with gr.Row(elem_id="model_selector_row"):
548
+ model_selector = gr.Dropdown(
549
+ choices=models,
550
+ value=models[0] if len(models) > 0 else "",
551
+ interactive=True,
552
+ show_label=False,
553
+ container=False)
554
+
555
+ imagebox = gr.Image(type="pil", label="Upload Image", interactive=True)
556
+ image_process_mode = gr.Radio(
557
+ ["Crop", "Resize", "Pad", "Default"],
558
+ value="Default",
559
+ label="Preprocess for non-square image", visible=False)
560
+
561
+
562
+ sketchbox = gr.Sketchpad(
563
+ label="Select on the Image",
564
+ height=250,
565
+ brush=gr.Brush(
566
+ colors=["#FF0000", "#0000FF", "#00FF00", "#FFFF00"], # Red, Blue, Green, Yellow, Black
567
+ default_color="#FF0000",
568
+ color_mode="defaults" # Fixed color mode (can also be "dynamic" for multiple colors)
569
+ )
570
+ )
571
+
572
+ get_coordinates_btn = gr.Button(value="Get the Selected Tokens")
573
+ with gr.Row(): # Add this new row to hold both buttons side by side
574
+ get_dominant64_btn = gr.Button(value="Get 64 Dominant Tokens")
575
+ get_dominant128_btn = gr.Button(value="Get 128 Dominant Tokens")
576
+ get_dominant192_btn = gr.Button(value="Get 192 Dominant Tokens")
577
+
578
+ coordinates_output = gr.Textbox(label="Select Tokens Index", interactive=False)
579
+
580
+ # Add the new image output area
581
+ masked_image_output = gr.Image(type="pil", label="Selected Visual Tokens", interactive=False)
582
+
583
+ get_coordinates_btn.click(
584
+ capture_coordinates,
585
+ [imagebox, sketchbox],
586
+ [coordinates_output,masked_image_output]
587
+ )
588
+ get_dominant64_btn.click(
589
+ calculate_dominant_tokens_64,
590
+ [imagebox,model_selector,state],
591
+ [coordinates_output,masked_image_output]
592
+
593
+ )
594
+ get_dominant128_btn.click(
595
+ calculate_dominant_tokens_128,
596
+ [imagebox,model_selector,state],
597
+ [coordinates_output,masked_image_output]
598
+
599
+ )
600
+ get_dominant192_btn.click(
601
+ calculate_dominant_tokens_192,
602
+ [imagebox,model_selector,state],
603
+ [coordinates_output,masked_image_output]
604
+
605
+ )
606
+ # Link the uploaded image to the sketchbox with resizing
607
+ imagebox.change(fn=lambda img: resize_image(img), inputs=imagebox, outputs=sketchbox)
608
+ # imagebox.change(fn=lambda img: default_img(img), inputs=imagebox, outputs=masked_image_output)
609
+
610
+ imagebox.change(
611
+ fn=lambda img: [default_img(img), ""] , # Reset coordinates_output to empty string
612
+ inputs=imagebox,
613
+ outputs=[masked_image_output, coordinates_output] # Include coordinates_output in outputs
614
+ )
615
+
616
+ # Example input examples
617
+ if cur_dir is None:
618
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
619
+ gr.Examples(examples=[
620
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
621
+ [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
622
+ ], inputs=[imagebox, textbox])
623
+
624
+ with gr.Accordion("Parameters", open=False) as parameter_row:
625
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
626
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
627
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")
628
+
629
+ with gr.Column(scale=8):
630
+ chatbot = gr.Chatbot(
631
+ elem_id="chatbot",
632
+ label="LLaVA Chatbot",
633
+ height=650,
634
+ layout="panel",
635
+ )
636
+ with gr.Row():
637
+ with gr.Column(scale=7):
638
+ textbox.render()
639
+ with gr.Column(scale=1, min_width=50):
640
+ CLS_btn = gr.Button(value="Add CLS", variant="primary")
641
+ with gr.Column(scale=1, min_width=50):
642
+ submit_btn = gr.Button(value="No CLS", variant="primary")
643
+ with gr.Row(elem_id="buttons") as button_row:
644
+ upvote_btn = gr.Button(value="��� Upvote", interactive=False)
645
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
646
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
647
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
648
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
649
+
650
+ # Register listeners
651
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
652
+ upvote_btn.click(
653
+ upvote_last_response,
654
+ [state, model_selector],
655
+ [textbox, upvote_btn, downvote_btn, flag_btn]
656
+ )
657
+ downvote_btn.click(
658
+ downvote_last_response,
659
+ [state, model_selector],
660
+ [textbox, upvote_btn, downvote_btn, flag_btn]
661
+ )
662
+ flag_btn.click(
663
+ flag_last_response,
664
+ [state, model_selector],
665
+ [textbox, upvote_btn, downvote_btn, flag_btn]
666
+ )
667
+
668
+ regenerate_btn.click(
669
+ regenerate,
670
+ [state, masked_image_output, image_process_mode], # No need for imagebox here, you already have masked_image_output
671
+ [state, chatbot, textbox] + btn_list # Use masked_image_output in the outputs
672
+ ).then(
673
+ http_bot,
674
+ [state, model_selector, temperature, top_p, max_output_tokens, coordinates_output],
675
+ [state, chatbot] + btn_list,
676
+ concurrency_limit=concurrency_count
677
+ )
678
+
679
+ clear_btn.click(
680
+ clear_history,
681
+ None,
682
+ [state, chatbot, textbox, imagebox] + btn_list,
683
+ queue=False
684
+ )
685
+
686
+ textbox.submit(
687
+ add_text,
688
+ [state, textbox, masked_image_output, image_process_mode, imagebox],
689
+ [state, chatbot, textbox] + btn_list,
690
+ queue=False
691
+ ).then(
692
+ http_bot,
693
+ [state, model_selector, temperature, top_p, max_output_tokens, coordinates_output],
694
+ [state, chatbot] + btn_list,
695
+ concurrency_limit=concurrency_count
696
+ )
697
+
698
+ submit_btn.click(
699
+ add_text,
700
+ [state, textbox, masked_image_output, image_process_mode, imagebox],
701
+ [state, chatbot, textbox] + btn_list
702
+ ).then(
703
+ http_bot,
704
+ [state, model_selector, temperature, top_p, max_output_tokens, coordinates_output],
705
+ [state, chatbot] + btn_list,
706
+ concurrency_limit=concurrency_count
707
+ )
708
+ CLS_btn.click(
709
+ add_text_wCLS,
710
+ [state, textbox, masked_image_output, image_process_mode, imagebox],
711
+ [state, chatbot, textbox] + btn_list
712
+ ).then(
713
+ http_bot,
714
+ [state, model_selector, temperature, top_p, max_output_tokens, coordinates_output],
715
+ [state, chatbot] + btn_list,
716
+ concurrency_limit=concurrency_count
717
+ )
718
+
719
+ if args.model_list_mode == "once":
720
+ demo.load(
721
+ load_demo,
722
+ [url_params],
723
+ [state, model_selector],
724
+ js=get_window_url_params
725
+ )
726
+ elif args.model_list_mode == "reload":
727
+ demo.load(
728
+ load_demo_refresh_model_list,
729
+ None,
730
+ [state, model_selector],
731
+ queue=False
732
+ )
733
+ else:
734
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
735
+
736
+ return demo
737
+
738
+ def start_demo(args):
739
+ demo = build_demo(args.embed)
740
+ demo.queue(
741
+ status_update_rate=10, api_open=False
742
+ ).launch(server_name=args.host, server_port=args.port, share=args.share)
743
+
744
+ def start_controller():
745
+ logger.info("Starting the controller")
746
+ controller_command = [
747
+ "python",
748
+ "-m",
749
+ "llava.serve.controller",
750
+ "--host",
751
+ "0.0.0.0",
752
+ "--port",
753
+ "10000",
754
+ ]
755
+ return subprocess.Popen(controller_command)
756
+
757
+ def start_worker():
758
+ return subprocess.Popen(['python', '-m', 'llava.serve.model_worker', '--host', '0.0.0.0', '--controller', 'http://localhost:10000', '--model-path', 'liuhaotian/llava-v1.5-7b'])
759
+ def download_llava():
760
+ command = ['huggingface-cli', 'download', '--resume-download', 'liuhaotian/llava-v1.5-7b']
761
+
762
+ # Capture the output and errors
763
+ result = subprocess.run(command, capture_output=True, text=True)
764
+
765
+ # Print output and error (if any)
766
+ print("STDOUT:", result.stdout)
767
+ print("STDERR:", result.stderr)
768
+
769
+ # Check if the command was successful (exit code 0 means success)
770
+ if result.returncode == 0:
771
+ print("Download completed successfully.")
772
+ else:
773
+ print("Download failed.")
774
+
775
+
776
+ def download_clip():
777
+ command = ['huggingface-cli', 'download', '--resume-download', 'openai/clip-vit-large-patch14-336']
778
+
779
+ # Capture the output and errors
780
+ result = subprocess.run(command, capture_output=True, text=True)
781
+
782
+ # Print output and error (if any)
783
+ print("STDOUT:", result.stdout)
784
+ print("STDERR:", result.stderr)
785
+
786
+ # Check if the command was successful (exit code 0 means success)
787
+ if result.returncode == 0:
788
+ print("Download completed successfully.")
789
+ else:
790
+ print("Download failed.")
791
+
792
+ if __name__ == "__main__":
793
+ parser = argparse.ArgumentParser()
794
+ parser.add_argument("--host", type=str, default="0.0.0.0")
795
+ parser.add_argument("--port", type=int)
796
+ parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
797
+ parser.add_argument("--concurrency-count", type=int, default=8)
798
+ parser.add_argument("--model-list-mode", type=str, default="reload",
799
+ choices=["once", "reload"])
800
+ parser.add_argument("--share", action="store_true")
801
+ parser.add_argument("--moderate", action="store_true")
802
+ parser.add_argument("--embed", action="store_true")
803
+ args = parser.parse_args()
804
+
805
+ logger.info(f"args: {args}")
806
+
807
+ download_clip()
808
+ download_llava()
809
+ controller_proc = start_controller()
810
+
811
+ worker_proc = start_worker()
812
+
813
+ time.sleep(100)
814
+ try:
815
+ start_demo(args)
816
+ except Exception as e:
817
+ print(e)
818
+ exit_status = 1
819
+ finally:
820
+ worker_proc.kill()
821
+ controller_proc.kill()
822
+
823
+ sys.exit(exit_status)
app_dev_debug.py CHANGED
@@ -738,7 +738,7 @@ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
738
  def start_demo(args):
739
  demo = build_demo(args.embed)
740
  demo.queue(
741
- concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
742
  ).launch(server_name=args.host, server_port=args.port, share=args.share)
743
 
744
  def start_controller():
 
738
  def start_demo(args):
739
  demo = build_demo(args.embed)
740
  demo.queue(
741
+ status_update_rate=10, api_open=False
742
  ).launch(server_name=args.host, server_port=args.port, share=args.share)
743
 
744
  def start_controller():