Update app.py
app.py CHANGED
@@ -2,20 +2,23 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig
 import re
+import copy
 from pathlib import Path
 import secrets
 import torch
+from PIL import Image, ImageDraw
 
-# Initialize the model and tokenizer
 model_name = "qwen/Qwen-VL-Chat"
 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).eval()
 model.generation_config = GenerationConfig.from_pretrained(model_name, trust_remote_code=True)
 
-# Set device for model
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
+BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
+PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
+
 def save_image(image_file, upload_dir: str) -> str:
     Path(upload_dir).mkdir(parents=True, exist_ok=True)
     filename = secrets.token_hex(10) + Path(image_file.name).suffix
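Neither new module-level constant is referenced anywhere in the hunks shown (draw_boxes below uses its own inline coordinate regex, and import copy is likewise unused here). For orientation, a minimal illustration of what the two box patterns match, run on a made-up grounding string rather than real model output:

    import re

    BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
    sample = "<ref>the duck</ref><box>(120,80),(340,260)</box>"

    # BOX_TAG_PATTERN captures the raw text between the tags
    re.findall(BOX_TAG_PATTERN, sample)
    # -> ['(120,80),(340,260)']

    # The inline pattern used in draw_boxes splits that span into coordinates
    re.findall(r'<box>\((\d+),(\d+)\),\((\d+),(\d+)\)</box>', sample)
    # -> [('120', '80', '340', '260')]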
@@ -29,27 +32,51 @@ def clean_response(response: str) -> str:
     return response
 
 def chat_with_model(image_path=None, text_query=None, history=None):
+    # Modify this function to use 'history' if your model requires it
     query_elements = []
     if image_path:
         query_elements.append({'image': image_path})
     if text_query:
         query_elements.append({'text': text_query})
-
+    # Add history processing here if needed
     query = tokenizer.from_list_format(query_elements)
     tokenized_inputs = tokenizer(query, return_tensors='pt').to(device)
     output = model.generate(**tokenized_inputs)
     response = tokenizer.decode(output[0], skip_special_tokens=True)
     cleaned_response = clean_response(response)
     return cleaned_response
+def draw_boxes(image_path, response):
+    image = Image.open(image_path)
+    draw = ImageDraw.Draw(image)
+    boxes = re.findall(r'<box>\((\d+),(\d+)\),\((\d+),(\d+)\)</box>', response)
+    for box in boxes:
+        x1, y1, x2, y2 = map(int, box)
+        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
+    return image
 
-def process_input(text, file):
+def process_input(text=None, file=None, task_history=None):
+    if task_history is None:
+        task_history = []
     image_path = None
     if file is not None:
         image_path = save_image(file, "uploaded_images")
-    response = chat_with_model(image_path=image_path, text_query=text)
-
-
+    response = chat_with_model(image_path=image_path, text_query=text, history=task_history)
+    task_history.append((text, response))
+
+    if "<box>" in response:
+        if image_path:
+            image_with_boxes = draw_boxes(image_path, response)
+            image_with_boxes_path = image_path.replace(".jpg", "_boxed.jpg")
+            image_with_boxes.save(image_with_boxes_path)
+            return [("bot", response), "image", image_with_boxes_path], task_history
+        else:
+            return [("bot", response), "text", None], task_history
+    else:
+        # Clean the response if it contains any box-like annotations
+        clean_response = re.sub(r'<ref>(.*?)</ref>(?:<box>.*?</box>)*(?:<quad>.*?</quad>)*', r'\1', response).strip()
+        return [("bot", clean_response), "text", None], task_history
 
+# Define Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("""
 # 🙋🏻♂️欢迎来到🌟Tonic 的🦆Qwen-VL-Chat🤩Bot!🚀
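Two caveats on the new drawing path. First, save_image keeps whatever suffix the upload had, so image_path.replace(".jpg", "_boxed.jpg") is a no-op for .png or .jpeg uploads, and image_with_boxes.save(...) then overwrites the original file. Second, Qwen-VL's grounding boxes are documented as normalized to a 0-1000 range, so drawing the raw values as pixels is only correct for images that happen to be 1000x1000. A suffix-safe, rescaling variant (a sketch under those assumptions; draw_boxes_scaled and boxed_path are hypothetical helpers, not part of this commit):

    import re
    from pathlib import Path
    from PIL import Image, ImageDraw

    def draw_boxes_scaled(image_path, response):
        image = Image.open(image_path)
        draw = ImageDraw.Draw(image)
        w, h = image.size
        for box in re.findall(r'<box>\((\d+),(\d+)\),\((\d+),(\d+)\)</box>', response):
            x1, y1, x2, y2 = map(int, box)
            # Rescale Qwen-VL's 0-1000 normalized coordinates to pixels
            draw.rectangle([x1 * w / 1000, y1 * h / 1000,
                            x2 * w / 1000, y2 * h / 1000],
                           outline="red", width=3)
        return image

    def boxed_path(image_path):
        # Suffix-safe replacement for str.replace(".jpg", "_boxed.jpg")
        p = Path(image_path)
        return str(p.with_name(p.stem + "_boxed" + p.suffix))

Separately, chat_with_model decodes the full generated sequence, prompt included, which is what clean_response then has to strip; the Qwen-VL-Chat model card instead drives generation through model.chat(tokenizer, query=query, history=None), which returns only the reply and threads history natively, so the still-unused history parameter here could be passed straight into it.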
@@ -69,10 +96,12 @@ Join us: TeamTonic is always making cool demos! Join our active builder's comm
     file_upload = gr.File(label="Upload Image")
     submit_btn = gr.Button("Submit")
 
+    task_history = []
+
     submit_btn.click(
         fn=process_input,
-        inputs=[query, file_upload],
-        outputs=chatbot
+        inputs=[query, file_upload, task_history],
+        outputs=[chatbot, task_history]
     )
 
     gr.Markdown("""
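As written, this wiring will fail at startup: task_history is the plain Python list defined a few lines up, but Gradio only accepts components in inputs= and outputs= (and even if it accepted a raw list, module-level state would be shared across all sessions). The idiomatic per-session carrier is gr.State. A sketch of the intended wiring, assuming process_input keeps its (update, task_history) return shape; the component constructors here are stand-ins for the ones defined earlier in the file, which this diff elides:

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        query = gr.Textbox(label="Query")
        file_upload = gr.File(label="Upload Image")
        submit_btn = gr.Button("Submit")
        task_history = gr.State([])  # per-session state, not a module-level list

        submit_btn.click(
            fn=process_input,
            inputs=[query, file_upload, task_history],
            outputs=[chatbot, task_history],
        )

Note also that gr.Chatbot renders a list of (user, assistant) pairs, so the [("bot", response), "image", path] value returned by process_input will not display as intended; returning the accumulated pair list (for example, task_history itself) for the chatbot output is the usual pattern.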
@@ -81,5 +110,7 @@ Join us: TeamTonic is always making cool demos! Join our active builder's comm
 Note: This demo is governed by the original license of Qwen-VL. We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content,
 including hate speech, violence, pornography, deception, etc. (Note: This demo is subject to the license agreement of Qwen-VL. We strongly advise users not to disseminate or allow others to disseminate the following content, including but not limited to hate speech, violence, pornography, and fraud-related harmful information.)
 """)
-
 demo.queue().launch()
+
+if __name__ == "__main__":
+    demo.launch()
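One last wrinkle in this hunk: demo.queue().launch() runs at module level and blocks when the file is executed as a script, so the guarded demo.launch() either never runs or starts a second instance without the queue. Collapsing to a single guarded entry point is the usual fix:

    if __name__ == "__main__":
        demo.queue().launch()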