Spaces:
Runtime error
Runtime error
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import copy
|
6 |
+
import gradio as gr
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import secrets
|
10 |
+
import tempfile
|
11 |
+
|
12 |
+
from PIL import Image
|
13 |
+
from monkey_model.modeling_monkey import MonkeyLMHeadModel
|
14 |
+
from monkey_model.tokenization_qwen import QWenTokenizer
|
15 |
+
from monkey_model.configuration_monkey import MonkeyConfig
|
16 |
+
|
17 |
+
import shutil
|
18 |
+
from pathlib import Path
|
19 |
+
import json
|
20 |
+
DEFAULT_CKPT_PATH = 'echo840/Monkey' # '/home/zhangli/demo/'
|
21 |
+
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
|
22 |
+
PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
|
23 |
+
title_markdown = ("""
|
24 |
+
# Welcome to Monkey
|
25 |
+
|
26 |
+
Hello! I'm Monkey, a Large Language and Vision Assistant. Before talking to me, please read the **Operation Guide** and **Terms of Use**.
|
27 |
+
|
28 |
+
> Note: This demo represents a more advanced iteration of the chat system, building upon the previous version to deliver an enhanced interactive experience. As a result, we cannot guarantee that the question-answering scenarios presented in the paper can be replicated accurately using this updated version.
|
29 |
+
|
30 |
+
## Operation Guide
|
31 |
+
|
32 |
+
Click the **Upload** button to upload an image. Then, you can get Monkey's answer in two ways:
|
33 |
+
- Click the **Generate** and Monkey will generate a description of the image.
|
34 |
+
- Enter the question in the dialog box, click the **Submit**, and Monkey will answer the question based on the image.
|
35 |
+
- Click **Clear History** to clear the current image and Q&A content.
|
36 |
+
|
37 |
+
""")
|
38 |
+
|
39 |
+
policy_markdown = ("""
|
40 |
+
## Terms of Use
|
41 |
+
|
42 |
+
By using this service, users are required to agree to the following terms:
|
43 |
+
|
44 |
+
- Monkey is for research use only and unauthorized commercial use is prohibited. For any query, please contact the author.
|
45 |
+
- Monkey's generation capabilities are limited, so we recommend that users do not rely entirely on its answers.
|
46 |
+
- Monkey's security measures are limited, so we cannot guarantee that the output is completely appropriate. We strongly recommend that users do not intentionally guide Monkey to generate harmful content, including hate speech, discrimination, violence, pornography, deception, etc.
|
47 |
+
|
48 |
+
""")
|
49 |
+
|
50 |
+
# ## Some Prompt Examples
|
51 |
+
|
52 |
+
# In order to generate more detailed captions, we provide some input examples so that you can conduct more interesting explorations.
|
53 |
+
|
54 |
+
# - Generate the detailed caption in English.
|
55 |
+
# - Explain the visual content of the image in great detail.
|
56 |
+
# - Analyze the image in a comprehensive and detailed manner.
|
57 |
+
# - Describe the image in as much detail as possible in English without duplicating it.
|
58 |
+
# - Describe the image in as much detail as possible in English, including as many elements from the image as possible, but without repetition.
|
59 |
+
|
60 |
+
|
61 |
+
def _get_args():
|
62 |
+
parser = ArgumentParser()
|
63 |
+
parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
|
64 |
+
help="Checkpoint name or path, default to %(default)r")
|
65 |
+
parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
|
66 |
+
|
67 |
+
parser.add_argument("--share", action="store_true", default=True,
|
68 |
+
help="Create a publicly shareable link for the interface.")
|
69 |
+
parser.add_argument("--inbrowser", action="store_true", default=False,
|
70 |
+
help="Automatically launch the interface in a new tab on the default browser.")
|
71 |
+
parser.add_argument("--server-port", type=int, default=8000,
|
72 |
+
help="Demo server port.")
|
73 |
+
parser.add_argument("--server-name", type=str, default="127.0.0.1",
|
74 |
+
help="Demo server name.")
|
75 |
+
|
76 |
+
args = parser.parse_args()
|
77 |
+
return args
|
78 |
+
|
79 |
+
|
80 |
+
def _load_model_tokenizer(args):
|
81 |
+
tokenizer = QWenTokenizer.from_pretrained(
|
82 |
+
args.checkpoint_path, trust_remote_code=True)
|
83 |
+
|
84 |
+
if args.cpu_only:
|
85 |
+
device_map = "cpu"
|
86 |
+
else:
|
87 |
+
device_map = "cuda"
|
88 |
+
|
89 |
+
model = MonkeyLMHeadModel.from_pretrained(
|
90 |
+
args.checkpoint_path,
|
91 |
+
device_map=device_map,
|
92 |
+
trust_remote_code=True,
|
93 |
+
).eval()
|
94 |
+
# model.generation_config = GenerationConfig.from_pretrained(
|
95 |
+
# args.checkpoint_path, trust_remote_code=True, resume_download=True,
|
96 |
+
# )
|
97 |
+
tokenizer.padding_side = 'left'
|
98 |
+
tokenizer.pad_token_id = tokenizer.eod_id
|
99 |
+
return model, tokenizer
|
100 |
+
|
101 |
+
|
102 |
+
def _parse_text(text):
|
103 |
+
lines = text.split("\n")
|
104 |
+
lines = [line for line in lines if line != ""]
|
105 |
+
count = 0
|
106 |
+
for i, line in enumerate(lines):
|
107 |
+
if "```" in line:
|
108 |
+
count += 1
|
109 |
+
items = line.split("`")
|
110 |
+
if count % 2 == 1:
|
111 |
+
lines[i] = f'<pre><code class="language-{items[-1]}">'
|
112 |
+
else:
|
113 |
+
lines[i] = f"<br></code></pre>"
|
114 |
+
else:
|
115 |
+
if i > 0:
|
116 |
+
if count % 2 == 1:
|
117 |
+
line = line.replace("`", r"\`")
|
118 |
+
line = line.replace("<", "<")
|
119 |
+
line = line.replace(">", ">")
|
120 |
+
line = line.replace(" ", " ")
|
121 |
+
line = line.replace("*", "*")
|
122 |
+
line = line.replace("_", "_")
|
123 |
+
line = line.replace("-", "-")
|
124 |
+
line = line.replace(".", ".")
|
125 |
+
line = line.replace("!", "!")
|
126 |
+
line = line.replace("(", "(")
|
127 |
+
line = line.replace(")", ")")
|
128 |
+
line = line.replace("$", "$")
|
129 |
+
lines[i] = "<br>" + line
|
130 |
+
text = "".join(lines)
|
131 |
+
return text
|
132 |
+
|
133 |
+
|
134 |
+
def _launch_demo(args, model, tokenizer):
|
135 |
+
def predict(_chatbot, task_history):
|
136 |
+
chat_query = _chatbot[-1][0]
|
137 |
+
query = task_history[-1][0]
|
138 |
+
question = _parse_text(query)
|
139 |
+
# print("User: " + _parse_text(query))
|
140 |
+
full_response = ""
|
141 |
+
|
142 |
+
|
143 |
+
img_path = _chatbot[0][0][0]
|
144 |
+
try:
|
145 |
+
Image.open(img_path)
|
146 |
+
except:
|
147 |
+
response = "Please upload a picture."
|
148 |
+
_chatbot[-1] = (_parse_text(chat_query), response)
|
149 |
+
full_response = _parse_text(response)
|
150 |
+
|
151 |
+
task_history[-1] = (query, full_response)
|
152 |
+
# print("Monkey: " + _parse_text(full_response))
|
153 |
+
return _chatbot
|
154 |
+
|
155 |
+
query = f'<img>{img_path}</img> {question} Answer: '
|
156 |
+
print(query)
|
157 |
+
|
158 |
+
all_history = query
|
159 |
+
all_history_0 = ''
|
160 |
+
if len(_chatbot) > 2:
|
161 |
+
all_history = ''
|
162 |
+
for conv in _chatbot[1:-1]:
|
163 |
+
q = conv[0]
|
164 |
+
a = conv[1]
|
165 |
+
all_history_0 = all_history + f'{q} Answer: {a} '
|
166 |
+
all_history = all_history_0 + f'<img>{img_path}</img> ' # 1288 tokens
|
167 |
+
all_history = all_history + f'{question} Answer: '
|
168 |
+
print(all_history)
|
169 |
+
tokens = all_history.split()
|
170 |
+
last_2048_tokens = tokens[-760:]
|
171 |
+
all_history = " ".join(last_2048_tokens)
|
172 |
+
print(all_history)
|
173 |
+
|
174 |
+
# input_ids = tokenizer(query, return_tensors='pt', padding='longest')
|
175 |
+
input_ids = tokenizer(all_history, return_tensors='pt', padding='longest')
|
176 |
+
|
177 |
+
attention_mask = input_ids.attention_mask
|
178 |
+
input_ids = input_ids.input_ids
|
179 |
+
|
180 |
+
pred = model.generate(
|
181 |
+
input_ids=input_ids.cuda(),
|
182 |
+
attention_mask=attention_mask.cuda(),
|
183 |
+
do_sample=False,
|
184 |
+
num_beams=1,
|
185 |
+
max_new_tokens=512,
|
186 |
+
min_new_tokens=1,
|
187 |
+
length_penalty=3,
|
188 |
+
num_return_sequences=1,
|
189 |
+
output_hidden_states=True,
|
190 |
+
use_cache=True,
|
191 |
+
pad_token_id=tokenizer.eod_id,
|
192 |
+
eos_token_id=tokenizer.eod_id,
|
193 |
+
)
|
194 |
+
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
|
195 |
+
|
196 |
+
_chatbot[-1] = (_parse_text(chat_query), response)
|
197 |
+
full_response = _parse_text(response)
|
198 |
+
|
199 |
+
# with open('./history/question_answer.jsonl', 'a',encoding="utf-8") as file: # 使用 'a' 模式打开文件,表示以追加模式写入
|
200 |
+
# data = {query:response}
|
201 |
+
# json_line = json.dumps(data)
|
202 |
+
# file.write(json_line + '\n')
|
203 |
+
# with open('./history/all_history_together.jsonl', 'a',encoding="utf-8") as file: # 使用 'a' 模式打开文件,表示以追加模式写入
|
204 |
+
# data = f'<img>{img_path}</img> ' + all_history_0 + f'{question} Answer: {full_response}'
|
205 |
+
# json_line = json.dumps(data)
|
206 |
+
# file.write(json_line + '\n')
|
207 |
+
|
208 |
+
|
209 |
+
task_history[-1] = (query, full_response)
|
210 |
+
print("Monkey: " + _parse_text(full_response))
|
211 |
+
return _chatbot
|
212 |
+
|
213 |
+
def caption(_chatbot, task_history):
|
214 |
+
|
215 |
+
query = "Describe the image in as much detail as possible in English, including as many elements from the image as possible, but without repetition. Answer: "
|
216 |
+
chat_query = "Describe the image in as much detail as possible in English, including as many elements from the image as possible, but without repetition. Answer: "
|
217 |
+
|
218 |
+
question = _parse_text(query)
|
219 |
+
print("User: " + _parse_text(query))
|
220 |
+
|
221 |
+
full_response = ""
|
222 |
+
|
223 |
+
try:
|
224 |
+
img_path = _chatbot[0][0][0]
|
225 |
+
Image.open(img_path)
|
226 |
+
except:
|
227 |
+
response = "Please upload a picture."
|
228 |
+
|
229 |
+
_chatbot.append((None, response))
|
230 |
+
full_response = _parse_text(response)
|
231 |
+
|
232 |
+
task_history.append((None, full_response))
|
233 |
+
print("Monkey: " + _parse_text(full_response))
|
234 |
+
return _chatbot
|
235 |
+
img_path = _chatbot[0][0][0]
|
236 |
+
query = f'<img>{img_path}</img> {chat_query} '
|
237 |
+
print(query)
|
238 |
+
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
|
239 |
+
attention_mask = input_ids.attention_mask
|
240 |
+
input_ids = input_ids.input_ids
|
241 |
+
|
242 |
+
pred = model.generate(
|
243 |
+
input_ids=input_ids.cuda(),
|
244 |
+
attention_mask=attention_mask.cuda(),
|
245 |
+
do_sample=True,
|
246 |
+
temperature=0.7,
|
247 |
+
max_new_tokens=250,
|
248 |
+
min_new_tokens=1,
|
249 |
+
length_penalty=3,
|
250 |
+
num_return_sequences=1,
|
251 |
+
output_hidden_states=True,
|
252 |
+
use_cache=True,
|
253 |
+
pad_token_id=tokenizer.eod_id,
|
254 |
+
eos_token_id=tokenizer.eod_id,
|
255 |
+
)
|
256 |
+
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
|
257 |
+
|
258 |
+
_chatbot.append((None, response))
|
259 |
+
full_response = _parse_text(response)
|
260 |
+
|
261 |
+
task_history.append((None, full_response))
|
262 |
+
print("Monkey: " + _parse_text(full_response))
|
263 |
+
return _chatbot
|
264 |
+
|
265 |
+
def add_text(history, task_history, text):
|
266 |
+
task_text = text
|
267 |
+
if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
|
268 |
+
task_text = text[:-1]
|
269 |
+
history = history + [(_parse_text(text), None)]
|
270 |
+
task_history = task_history + [(task_text, None)]
|
271 |
+
# print(history, task_history, text)
|
272 |
+
return history, task_history, ""
|
273 |
+
|
274 |
+
def add_file(history, task_history, file):
|
275 |
+
save_path = os.path.join("./history/test_image",file.name.split("/")[-2])
|
276 |
+
Path(save_path).mkdir(exist_ok=True,parents=True)
|
277 |
+
shutil.copy(file.name,save_path)
|
278 |
+
history = [((file.name,), None)]
|
279 |
+
task_history = [((file.name,), None)]
|
280 |
+
# print(history, task_history, file)
|
281 |
+
return history, task_history
|
282 |
+
|
283 |
+
def reset_user_input():
|
284 |
+
return gr.update(value="")
|
285 |
+
|
286 |
+
def reset_state(task_history):
|
287 |
+
# with open('./history/all_history_separate.jsonl', 'a',encoding="utf-8") as file: # 使用 'a' 模式打开文件,表示以追加模式写入
|
288 |
+
# data = task_history
|
289 |
+
# json_line = json.dumps(data)
|
290 |
+
# file.write(json_line + '\n')
|
291 |
+
task_history.clear()
|
292 |
+
return []
|
293 |
+
|
294 |
+
|
295 |
+
with gr.Blocks() as demo:
|
296 |
+
gr.Markdown(title_markdown)
|
297 |
+
|
298 |
+
chatbot = gr.Chatbot(label='Monkey', elem_classes="control-height", height=600,avatar_images=("./images/logo_user.png","./images/logo_monkey.png"),layout="bubble",bubble_full_width=False,show_copy_button=True)
|
299 |
+
query = gr.Textbox(lines=1, label='Input')
|
300 |
+
task_history = gr.State([])
|
301 |
+
|
302 |
+
with gr.Row():
|
303 |
+
empty_bin = gr.Button("Clear History")
|
304 |
+
submit_btn = gr.Button("Submit")
|
305 |
+
|
306 |
+
generate_btn_en = gr.Button("Generate")
|
307 |
+
addfile_btn = gr.UploadButton("Upload", file_types=["image"])
|
308 |
+
|
309 |
+
submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
|
310 |
+
predict, [chatbot, task_history], [chatbot], show_progress=True
|
311 |
+
)
|
312 |
+
generate_btn_en.click(caption, [chatbot, task_history], [chatbot], show_progress=True)
|
313 |
+
|
314 |
+
submit_btn.click(reset_user_input, [], [query])
|
315 |
+
empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
|
316 |
+
|
317 |
+
addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True,scroll_to_output=True)
|
318 |
+
|
319 |
+
with gr.Row(variant="compact"):
|
320 |
+
with gr.Column(scale=2):
|
321 |
+
with gr.Row():
|
322 |
+
a = gr.Image(Image.open("./images/logo_monkey.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False,render=False)
|
323 |
+
b = gr.Image(Image.open("./images/logo_hust.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False,render=False)
|
324 |
+
with gr.Column(scale=4):
|
325 |
+
with gr.Row():
|
326 |
+
a = gr.Image(Image.open("./images/logo_monkey.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False)
|
327 |
+
c = gr.Image(Image.open("./images/logo_vlr.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False)
|
328 |
+
b = gr.Image(Image.open("./images/logo_hust.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False)
|
329 |
+
b = gr.Image(Image.open("./images/logo_king.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False)
|
330 |
+
with gr.Column(scale=2):
|
331 |
+
with gr.Row():
|
332 |
+
a = gr.Image(Image.open("./images/logo_monkey.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False,render=False)
|
333 |
+
b = gr.Image(Image.open("./images/logo_hust.png"),height=100,width=100,show_download_button=False,label="Generated images", show_label=False,render=False)
|
334 |
+
|
335 |
+
gr.Markdown(policy_markdown)
|
336 |
+
|
337 |
+
demo.queue().launch(
|
338 |
+
server_name="0.0.0.0",
|
339 |
+
server_port=7682,
|
340 |
+
share=True
|
341 |
+
)
|
342 |
+
|
343 |
+
|
344 |
+
def main():
|
345 |
+
args = _get_args()
|
346 |
+
|
347 |
+
model, tokenizer = _load_model_tokenizer(args)
|
348 |
+
_launch_demo(args, model, tokenizer)
|
349 |
+
|
350 |
+
|
351 |
+
if __name__ == '__main__':
|
352 |
+
main()
|