Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ from threading import Thread
|
|
6 |
# import time
|
7 |
import cv2
|
8 |
|
9 |
-
|
10 |
# import copy
|
11 |
import torch
|
12 |
|
@@ -34,8 +34,6 @@ from llava.mm_utils import (
|
|
34 |
|
35 |
from serve_constants import html_header
|
36 |
|
37 |
-
from PIL import Image
|
38 |
-
|
39 |
import requests
|
40 |
from PIL import Image
|
41 |
from io import BytesIO
|
@@ -46,6 +44,9 @@ import gradio_client
|
|
46 |
import subprocess
|
47 |
import sys
|
48 |
|
|
|
|
|
|
|
49 |
def install_gradio_4_35_0():
|
50 |
current_version = gr.__version__
|
51 |
if current_version != "4.35.0":
|
@@ -64,6 +65,11 @@ import gradio_client
|
|
64 |
print(f"Gradio version: {gr.__version__}")
|
65 |
print(f"Gradio-client version: {gradio_client.__version__}")
|
66 |
|
|
|
|
|
|
|
|
|
|
|
67 |
class InferenceDemo(object):
|
68 |
def __init__(
|
69 |
self, args, model_path, tokenizer, model, image_processor, context_len
|
@@ -113,6 +119,16 @@ def is_valid_video_filename(name):
|
|
113 |
else:
|
114 |
return False
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
def sample_frames(video_file, num_frames):
|
118 |
video = cv2.VideoCapture(video_file)
|
@@ -193,9 +209,14 @@ def bot(history):
|
|
193 |
if type(message[0]) is tuple:
|
194 |
images_this_term.append(message[0][0])
|
195 |
if is_valid_video_filename(message[0][0]):
|
|
|
|
|
196 |
num_new_images += our_chatbot.num_frames
|
197 |
-
|
|
|
198 |
num_new_images += 1
|
|
|
|
|
199 |
else:
|
200 |
num_new_images = 0
|
201 |
|
@@ -209,8 +230,11 @@ def bot(history):
|
|
209 |
for f in images_this_term:
|
210 |
if is_valid_video_filename(f):
|
211 |
image_list += sample_frames(f, our_chatbot.num_frames)
|
212 |
-
|
213 |
image_list.append(load_image(f))
|
|
|
|
|
|
|
214 |
image_tensor = [
|
215 |
our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
|
216 |
0
|
@@ -219,6 +243,24 @@ def bot(history):
|
|
219 |
.to(our_chatbot.model.device)
|
220 |
for f in image_list
|
221 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
image_tensor = torch.stack(image_tensor)
|
224 |
image_token = DEFAULT_IMAGE_TOKEN * num_new_images
|
@@ -280,7 +322,19 @@ def bot(history):
|
|
280 |
our_chatbot.conversation.messages[-1][-1] = outputs
|
281 |
|
282 |
history[-1] = [text, outputs]
|
283 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
return history
|
285 |
# generate_kwargs = dict(
|
286 |
# inputs=input_ids,
|
@@ -345,7 +399,7 @@ with gr.Blocks(
|
|
345 |
|
346 |
with gr.Column():
|
347 |
with gr.Row():
|
348 |
-
chatbot = gr.Chatbot([], elem_id="
|
349 |
|
350 |
with gr.Row():
|
351 |
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
@@ -560,8 +614,8 @@ if __name__ == "__main__":
|
|
560 |
argparser.add_argument("--model-base", type=str, default=None)
|
561 |
argparser.add_argument("--num-gpus", type=int, default=1)
|
562 |
argparser.add_argument("--conv-mode", type=str, default=None)
|
563 |
-
argparser.add_argument("--temperature", type=float, default=0.
|
564 |
-
argparser.add_argument("--max-new-tokens", type=int, default=
|
565 |
argparser.add_argument("--num_frames", type=int, default=16)
|
566 |
argparser.add_argument("--load-8bit", action="store_true")
|
567 |
argparser.add_argument("--load-4bit", action="store_true")
|
|
|
6 |
# import time
|
7 |
import cv2
|
8 |
|
9 |
+
import datetime
|
10 |
# import copy
|
11 |
import torch
|
12 |
|
|
|
34 |
|
35 |
from serve_constants import html_header
|
36 |
|
|
|
|
|
37 |
import requests
|
38 |
from PIL import Image
|
39 |
from io import BytesIO
|
|
|
44 |
import subprocess
|
45 |
import sys
|
46 |
|
47 |
+
external_log_dir = "./logs"
|
48 |
+
LOGDIR = external_log_dir
|
49 |
+
|
50 |
def install_gradio_4_35_0():
|
51 |
current_version = gr.__version__
|
52 |
if current_version != "4.35.0":
|
|
|
65 |
print(f"Gradio version: {gr.__version__}")
|
66 |
print(f"Gradio-client version: {gradio_client.__version__}")
|
67 |
|
68 |
+
def get_conv_log_filename():
|
69 |
+
t = datetime.datetime.now()
|
70 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
|
71 |
+
return name
|
72 |
+
|
73 |
class InferenceDemo(object):
|
74 |
def __init__(
|
75 |
self, args, model_path, tokenizer, model, image_processor, context_len
|
|
|
119 |
else:
|
120 |
return False
|
121 |
|
122 |
+
def is_valid_image_filename(name):
|
123 |
+
image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
|
124 |
+
|
125 |
+
ext = name.split(".")[-1].lower()
|
126 |
+
|
127 |
+
if ext in image_extensions:
|
128 |
+
return True
|
129 |
+
else:
|
130 |
+
return False
|
131 |
+
|
132 |
|
133 |
def sample_frames(video_file, num_frames):
|
134 |
video = cv2.VideoCapture(video_file)
|
|
|
209 |
if type(message[0]) is tuple:
|
210 |
images_this_term.append(message[0][0])
|
211 |
if is_valid_video_filename(message[0][0]):
|
212 |
+
# 不接受视频
|
213 |
+
raise ValueError("Video is not supported")
|
214 |
num_new_images += our_chatbot.num_frames
|
215 |
+
elif is_valid_image_filename(message[0][0]):
|
216 |
+
print("#### Load image from local file",message[0][0])
|
217 |
num_new_images += 1
|
218 |
+
else:
|
219 |
+
raise ValueError("Invalid image file")
|
220 |
else:
|
221 |
num_new_images = 0
|
222 |
|
|
|
230 |
for f in images_this_term:
|
231 |
if is_valid_video_filename(f):
|
232 |
image_list += sample_frames(f, our_chatbot.num_frames)
|
233 |
+
elif is_valid_image_filename(f):
|
234 |
image_list.append(load_image(f))
|
235 |
+
else:
|
236 |
+
raise ValueError("Invalid image file")
|
237 |
+
|
238 |
image_tensor = [
|
239 |
our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
|
240 |
0
|
|
|
243 |
.to(our_chatbot.model.device)
|
244 |
for f in image_list
|
245 |
]
|
246 |
+
all_image_hash = []
|
247 |
+
for image_path in image_list:
|
248 |
+
with open(image_path, "rb") as image_file:
|
249 |
+
image_data = image_file.read()
|
250 |
+
image_hash = hashlib.md5(image_data).hexdigest()
|
251 |
+
all_image_hash.append(image_hash)
|
252 |
+
image = PIL.Image.open(image_path).convert("RGB")
|
253 |
+
all_images.append(image)
|
254 |
+
t = datetime.datetime.now()
|
255 |
+
filename = os.path.join(
|
256 |
+
LOGDIR,
|
257 |
+
"serve_images",
|
258 |
+
f"{t.year}-{t.month:02d}-{t.day:02d}",
|
259 |
+
f"{image_hash}.jpg",
|
260 |
+
)
|
261 |
+
if not os.path.isfile(filename):
|
262 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
263 |
+
image.save(filename)
|
264 |
|
265 |
image_tensor = torch.stack(image_tensor)
|
266 |
image_token = DEFAULT_IMAGE_TOKEN * num_new_images
|
|
|
322 |
our_chatbot.conversation.messages[-1][-1] = outputs
|
323 |
|
324 |
history[-1] = [text, outputs]
|
325 |
+
print("#### history",history)
|
326 |
+
|
327 |
+
with open(get_conv_log_filename(), "a") as fout:
|
328 |
+
data = {
|
329 |
+
"tstamp": round(finish_tstamp, 4),
|
330 |
+
"type": "chat",
|
331 |
+
"model": "Pangea-7b",
|
332 |
+
"start": round(start_tstamp, 4),
|
333 |
+
"finish": round(start_tstamp, 4),
|
334 |
+
"state": history,
|
335 |
+
"images": all_image_hash,
|
336 |
+
}
|
337 |
+
fout.write(json.dumps(data) + "\n")
|
338 |
return history
|
339 |
# generate_kwargs = dict(
|
340 |
# inputs=input_ids,
|
|
|
399 |
|
400 |
with gr.Column():
|
401 |
with gr.Row():
|
402 |
+
chatbot = gr.Chatbot([], elem_id="Pangea", bubble_full_width=False, height=750)
|
403 |
|
404 |
with gr.Row():
|
405 |
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
|
|
614 |
argparser.add_argument("--model-base", type=str, default=None)
|
615 |
argparser.add_argument("--num-gpus", type=int, default=1)
|
616 |
argparser.add_argument("--conv-mode", type=str, default=None)
|
617 |
+
argparser.add_argument("--temperature", type=float, default=0.7)
|
618 |
+
argparser.add_argument("--max-new-tokens", type=int, default=4096)
|
619 |
argparser.add_argument("--num_frames", type=int, default=16)
|
620 |
argparser.add_argument("--load-8bit", action="store_true")
|
621 |
argparser.add_argument("--load-4bit", action="store_true")
|