Spaces:
Running
Running
import spaces | |
import argparse | |
from ast import parse | |
import datetime | |
import json | |
import os | |
import time | |
import hashlib | |
import re | |
import gradio as gr | |
import requests | |
import random | |
from filelock import FileLock | |
from io import BytesIO | |
from PIL import Image, ImageDraw, ImageFont | |
from constants import LOGDIR | |
from utils import ( | |
build_logger, | |
server_error_msg, | |
violates_moderation, | |
moderation_msg, | |
load_image_from_base64, | |
get_log_filename, | |
) | |
from conversation import Conversation | |
# Module-wide logger, created by the project's build_logger helper.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# NOTE(review): http_bot builds its own local `headers`, so this module-level
# value does not appear to be read by the code visible in this file.
headers = {"User-Agent": "InternVL-Chat Client"}

# Button values returned by the event handlers below: leave a button as it is,
# make it clickable, or grey it out.
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
def make_zerogpu_happy():
    """Do nothing and return ``None``.

    NOTE(review): judging by the name, this placeholder presumably exists for
    the `spaces` (ZeroGPU) runtime -- confirm before removing it.
    """
    return None
def write2file(path, content):
    """Append ``content`` to the file at ``path`` while holding a file lock.

    The lock file is ``<path>.lock``; it is held for the duration of the
    write so appends made through this function do not interleave.

    Args:
        path: Destination file; created if it does not exist (append mode).
        content: Text to append. No newline is added.
    """
    # Explicit UTF-8 so the log encoding does not depend on the host locale.
    with FileLock(f"{path}.lock"), open(path, "a", encoding="utf-8") as fout:
        fout.write(content)
# JavaScript source (kept as a Python string) for a function that returns the
# page's query-string parameters as a plain object.
# `url_params` is declared with `const`; assigning it undeclared would create
# an implicit global and throws a ReferenceError in strict-mode scripts.
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""
def init_state(state=None):
    """Return a brand-new ``Conversation``.

    Args:
        state: Previous conversation, if any. It is accepted so callers can
            pass the state they are replacing, but it is not used: the old
            object is simply dropped when the caller rebinds its variable.

    Returns:
        A freshly constructed ``Conversation``.
    """
    # The former `del state` only unbound the local name; it released nothing
    # the caller still referenced, so it has been removed.
    return Conversation()
def _parse_ref_boxes(response):
    """Extract ``(label, boxes)`` pairs from ``<ref>..</ref><box>[[..]]</box>`` spans.

    Each box is returned as a 4-tuple of finite floats taken from the first
    four entries of the parsed list. Spans whose box text is not a valid
    Python literal, and boxes with fewer than four numeric entries, are
    skipped instead of raising.
    """
    # Function-scope imports: the module itself only imports `parse` from `ast`.
    import ast
    import math

    pattern = re.compile(r"<ref>\s*(.*?)\s*</ref>\s*<box>\s*(\[\[.*?\]\])\s*</box>")
    labelled = []
    for label, raw_boxes in pattern.findall(response):
        try:
            # SECURITY: the text comes from model output, so it is parsed as a
            # literal only; it must never be passed to eval().
            candidates = ast.literal_eval(raw_boxes)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        if not isinstance(candidates, (list, tuple)):
            continue
        boxes = []
        for candidate in candidates:
            if not isinstance(candidate, (list, tuple)) or len(candidate) < 4:
                continue
            try:
                box = tuple(float(value) for value in candidate[:4])
            except (TypeError, ValueError, OverflowError):
                continue
            if all(math.isfinite(value) for value in box):
                boxes.append(box)
        if boxes:
            labelled.append((label, boxes))
    return labelled


def find_bounding_boxes(state, response):
    """Draw the boxes referenced in ``response`` onto the latest user image.

    The response is scanned for ``<ref>label</ref><box>[[x0, y0, x1, y1], ...]</box>``
    spans. Box values are divided by 1000 and multiplied by the image width
    (x values) or height (y values) to obtain pixel coordinates. Every label
    gets one random colour (each channel in 0..128), shared by all its boxes;
    the label text is drawn in white on a filled strip above each box.

    Args:
        state: Conversation object; ``state.get_images(source=state.USER)``
            supplies the images and the last one is annotated. It is not
            touched when ``response`` contains no usable box.
        response: Text produced by the model.

    Returns:
        An annotated copy of the latest user image, or ``None`` when the
        response holds no usable box or there is no user image.

    Raises:
        OSError: If the font file ``assets/SimHei.ttf`` cannot be loaded.
    """
    labelled_boxes = _parse_ref_boxes(response)
    if not labelled_boxes:
        return None

    user_images = state.get_images(source=state.USER)
    if not user_images:
        return None

    annotated = user_images[-1].copy()
    width, height = annotated.size
    draw = ImageDraw.Draw(annotated)

    # Line width and font size depend only on the image size, so compute them
    # (and load the font) once instead of once per box.
    line_width = max(1, int(min(width, height) / 200))
    font = ImageFont.truetype("assets/SimHei.ttf", int(20 * line_width / 2))

    for category_name, boxes in labelled_boxes:
        color = (
            random.randint(0, 128),
            random.randint(0, 128),
            random.randint(0, 128),
        )
        text_bbox = font.getbbox(category_name)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        for x0, y0, x1, y1 in boxes:
            box = (
                int(x0 / 1000 * width),
                int(y0 / 1000 * height),
                int(x1 / 1000 * width),
                int(y1 / 1000 * height),
            )
            draw.rectangle(box, outline=color, width=line_width)

            # Put the label just above the box, clamped to the top edge.
            text_position = (box[0], max(0, box[1] - text_height))
            draw.rectangle(
                [
                    text_position,
                    (text_position[0] + text_width, text_position[1] + text_height),
                ],
                fill=color,
            )
            draw.text(text_position, category_name, fill="white", font=font)

    return annotated
def vote_last_response(state, liked, request: gr.Request):
    """Append one feedback record for the conversation to the log file.

    Args:
        state: Conversation whose ``dict()`` snapshot is stored in the record.
        liked: Feedback value stored under ``"like"`` (callers in this file
            pass ``True``, ``False`` or ``"flag"``).
        request: Gradio request; its client host is stored under ``"ip"``.
    """
    record = {
        "tstamp": round(time.time(), 4),
        "like": liked,
        "model": 'InternVL2.5-78B',
        "state": state.dict(),
        "ip": request.client.host,
    }
    log_path = get_log_filename()
    # One JSON object per line.
    write2file(log_path, json.dumps(record) + "\n")
def upvote_last_response(state, request: gr.Request):
    """Log an upvote for the conversation.

    Returns:
        A 4-tuple: an emptied, interactive textbox followed by three disabled
        button values.
    """
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, True, request)
    cleared_textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (cleared_textbox, disable_btn, disable_btn, disable_btn)
def downvote_last_response(state, request: gr.Request):
    """Log a downvote for the conversation.

    Returns:
        A 4-tuple: an emptied, interactive textbox followed by three disabled
        button values.
    """
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, False, request)
    cleared_textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (cleared_textbox, disable_btn, disable_btn, disable_btn)
def vote_selected_response(state, request: gr.Request, data: gr.LikeData):
    """Log a like/dislike given to one specific chatbot message.

    Args:
        state: Conversation whose ``dict()`` snapshot is stored in the record.
        request: Gradio request; its client host is logged and stored.
        data: Gradio like event; ``liked`` and ``index`` are stored in the
            record, and ``value`` is written to the log message only.

    Returns:
        None.
    """
    client_ip = request.client.host
    logger.info(
        f"Vote: {data.liked}, index: {data.index}, value: {data.value} , ip: {client_ip}"
    )
    record = {
        "tstamp": round(time.time(), 4),
        "like": data.liked,
        "index": data.index,
        "model": 'InternVL2.5-78B',
        "state": state.dict(),
        "ip": request.client.host,
    }
    log_path = get_log_filename()
    write2file(log_path, json.dumps(record) + "\n")
def flag_last_response(state, request: gr.Request):
    """Log a "flag" report for the conversation.

    Returns:
        A 4-tuple: an emptied, interactive textbox followed by three disabled
        button values.
    """
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", request)
    cleared_textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (cleared_textbox, disable_btn, disable_btn, disable_btn)
def regenerate(state, image_process_mode, request: gr.Request):
    """Blank the last assistant message so the reply can be generated again.

    Args:
        state: Conversation to modify in place.
        image_process_mode: Value written as the third element of the
            previous message's payload when that payload is a tuple or list.
            NOTE(review): build_demo wires the system-prompt textbox into
            this parameter -- confirm that is intended.
        request: Gradio request, used for logging the client host.

    Returns:
        ``(state, chatbot value, emptied textbox)`` followed by five disabled
        button values.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.update_message(Conversation.ASSISTANT, content='', image=None, idx=-1)

    previous_turn = state.messages[-2]
    payload = previous_turn[1]
    if type(payload) in (tuple, list):
        # Keep the first two payload elements and replace the mode.
        previous_turn[1] = (*payload[:2], image_process_mode)

    state.skip_next = False
    cleared_textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (state, state.to_gradio_chatbot(), cleared_textbox) + (disable_btn,) * 5
def clear_history(request: gr.Request):
    """Start over with a fresh conversation.

    Returns:
        ``(new state, chatbot value, emptied textbox)`` followed by five
        disabled button values.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = init_state()
    cleared_textbox = gr.MultimodalTextbox(value=None, interactive=True)
    return (fresh_state, fresh_state.to_gradio_chatbot(), cleared_textbox) + (
        disable_btn,
    ) * 5
def add_text(state, message, system_prompt, request: gr.Request):
    """Store the submitted text and images as a new user turn.

    Args:
        state: Current conversation; a new one is created when it is falsy.
        message: Mapping read through its ``"files"`` (paths opened as
            images) and ``"text"`` keys.
        system_prompt: Passed to ``state.set_system_message``.
        request: Gradio request, used for logging the client host.

    Returns:
        ``(state, chatbot value, textbox)`` followed by five button values.
        When the input is empty or flagged by moderation, ``state.skip_next``
        is set to ``True`` and the buttons are left unchanged.
    """
    logger.info(f"state: {state}")
    if not state:
        state = init_state()

    images = message.get("files", [])
    text = message.get("text", "").strip()
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

    textbox = gr.MultimodalTextbox(value=None, interactive=False)
    if len(text) <= 0 and len(images) == 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5

    # `args` is only bound when this file runs as a script; look it up
    # defensively so importing the module does not end in a NameError here.
    cli_args = globals().get("args")
    if cli_args is not None and cli_args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            textbox = gr.MultimodalTextbox(
                value={"text": moderation_msg}, interactive=True
            )
            return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5

    # convert() returns an independent image, so the file handle opened by
    # Image.open can be closed right away.
    loaded_images = []
    for path in images:
        with Image.open(path) as opened:
            loaded_images.append(opened.convert("RGB"))
    images = loaded_images

    # A new upload while the conversation already holds user images starts a
    # fresh conversation.
    if len(images) > 0 and len(state.get_images(source=state.USER)) > 0:
        state = init_state(state)

    state.set_system_message(system_prompt)
    state.append_message(Conversation.USER, text, images)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), textbox) + (
        disable_btn,
    ) * 5
def http_bot(
    state,
    temperature,
    top_p,
    repetition_penalty,
    max_new_tokens,
    max_input_tiles,
    request: gr.Request,
):
    """Stream the model's reply for the current conversation.

    This is a generator: every yielded tuple is
    ``(state, chatbot value, textbox, five button values)``.

    The worker URL is read from the ``WORKER_ADDR`` environment variable and
    the ``Authorization`` header value from ``API_TOKEN``. The request is
    posted with ``"stream": True`` and the response is consumed line by line
    until a ``data: [DONE]`` line arrives; each line's
    ``choices[0].delta.content`` is appended to the running output.

    Args:
        state: Conversation; modified in place.
        temperature: Sampling temperature (sent as ``float``).
        top_p: Nucleus sampling value (sent as ``float``).
        repetition_penalty: Sent unchanged as ``"repetition_penalty"``.
        max_new_tokens: Sent unchanged as ``"max_tokens"``.
        max_input_tiles: Passed to ``state.get_prompt_v2`` as
            ``max_dynamic_patch``.
        request: Gradio request, used for logging the client host.
    """
    model_name = 'InternVL2.5-78B'
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()

    if hasattr(state, "skip_next") and state.skip_next:
        # The preceding step rejected the input (empty or moderated). Nothing
        # is generated, so hand an interactive textbox back; returning a
        # disabled one here would leave the user unable to type again.
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=True),
        ) + (no_change_btn,) * 5
        return

    worker_addr = os.environ.get("WORKER_ADDR", "")
    api_token = os.environ.get("API_TOKEN", "")
    request_headers = {
        "Authorization": f"{api_token}",
        "Content-Type": "application/json",
    }

    # No available worker
    if worker_addr == "":
        # NOTE(review): no assistant message has been appended yet on this
        # path (append_message runs further down) -- confirm update_message
        # targets the intended entry.
        state.update_message(Conversation.ASSISTANT, server_error_msg)
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=False),
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    all_images = state.get_images(source=state.USER)
    all_image_paths = [state.save_image(image) for image in all_images]

    # Make requests
    pload = {
        "model": model_name,
        # The keyword is spelled `inlude_image` here; presumably it matches
        # the parameter declared by get_prompt_v2 -- verify before renaming.
        "messages": state.get_prompt_v2(inlude_image=True, max_dynamic_patch=max_input_tiles),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "stream": True
    }
    logger.info(f"==== request ====\n{pload}")

    state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=False),
    ) + (disable_btn,) * 5

    final_output = ''
    try:
        # Stream output. The context manager releases the connection even
        # when the generator is closed early.
        with requests.post(
            worker_addr, json=pload, headers=request_headers, stream=True, timeout=40
        ) as response:
            # Route HTTP error statuses to the handler below instead of
            # trying to parse an error body as a stream chunk.
            response.raise_for_status()
            for raw_line in response.iter_lines(decode_unicode=False, delimiter=b"\n"):
                if not raw_line:
                    continue
                line = raw_line.decode()
                if line == 'data: [DONE]':
                    break
                if line.startswith("data:"):
                    line = line[5:]
                delta = json.loads(line)['choices'][0]['delta']
                # A delta without "content" (or with a null one) adds nothing.
                final_output += delta.get('content') or ''
                state.update_message(
                    Conversation.ASSISTANT,
                    final_output + state.streaming_placeholder,
                    None,
                )
                yield (
                    state,
                    state.to_gradio_chatbot(),
                    gr.MultimodalTextbox(interactive=False),
                ) + (disable_btn,) * 5
    except (
        requests.exceptions.RequestException,
        ValueError,  # includes json.JSONDecodeError and UnicodeDecodeError
        KeyError,
        IndexError,
        TypeError,
        AttributeError,
    ) as e:
        logger.info(f"http_bot request failed: {e!r}")
        state.update_message(Conversation.ASSISTANT, server_error_msg, None)
        yield (
            state,
            state.to_gradio_chatbot(),
            gr.MultimodalTextbox(interactive=True),
        ) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    ai_response = state.return_last_message()
    if "<ref>" in ai_response:
        # The reply contains grounding tags: attach the annotated image.
        returned_image = find_bounding_boxes(state, ai_response)
        returned_image = [returned_image] if returned_image else []
        state.update_message(Conversation.ASSISTANT, ai_response, returned_image)

    state.end_of_current_turn()

    yield (
        state,
        state.to_gradio_chatbot(),
        gr.MultimodalTextbox(interactive=True),
    ) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{final_output}")
    data = {
        "tstamp": round(finish_tstamp, 4),
        "like": None,
        "model": model_name,
        "start": round(start_tstamp, 4),
        # Previously this stored the start timestamp a second time.
        "finish": round(finish_tstamp, 4),
        "state": state.dict(),
        "images": all_image_paths,
        "ip": request.client.host,
    }
    write2file(get_log_filename(), json.dumps(data) + "\n")
# HTML shown at the top of the left column: the logo image followed by four
# links (blog, official demo, quick start, API).
title_html = """
<img src="https://internvl.opengvlab.com/assets/logo-47b364d3.jpg" style="width: 280px; height: 70px;">
<a href="https://internvl.github.io/blog/2024-07-02-InternVL-2.0/">[📜 InternVL2 Blog]</a>
<a href="https://internvl.opengvlab.com/">[🌟 Official Demo]</a>
<a href="https://github.com/OpenGVLab/InternVL?tab=readme-ov-file#quick-start-with-huggingface">[🚀 Quick Start]</a>
<a href="https://github.com/OpenGVLab/InternVL/blob/main/document/How_to_use_InternVL_API.md">[🌐 API]</a>
"""
# .gradio-container {margin: 5px 10px 0 10px !important}; | |
# Custom CSS handed to gr.Blocks(css=...).
# The first rule used to end in "};". A stray ";" between rules becomes part
# of the next rule's selector and makes that rule (#buttons button) invalid,
# so it has been removed.
block_css = """
.gradio-container {margin: 0.1% 1% 0 1% !important; max-width: 98% !important;}
#buttons button {
    min-width: min(120px,100%);
}
.gradient-text {
    font-size: 28px;
    width: auto;
    font-weight: bold;
    background: linear-gradient(45deg, red, orange, yellow, green, blue, indigo, violet);
    background-clip: text;
    -webkit-background-clip: text;
    color: transparent;
}
.plain-text {
    font-size: 22px;
    width: auto;
    font-weight: bold;
}
"""
# JavaScript source (kept as a Python string). Every 200 ms it recolours the
# element with id "text" using a rotating rainbow gradient, and it returns
# the page's query-string parameters as a plain object.
# Fixes relative to the earlier version: `url_params` is declared with
# `const` (it used to be an implicit global), and the animation is only
# started when the "text" element exists, so a missing element no longer
# throws on every tick.
# NOTE(review): build_demo in this file does not pass this string to
# gr.Blocks -- confirm whether it is still needed.
js = """
function createWaveAnimation() {
    const text = document.getElementById('text');
    var i = 0;
    if (text) {
        setInterval(function() {
            const colors = [
                'red, orange, yellow, green, blue, indigo, violet, purple',
                'orange, yellow, green, blue, indigo, violet, purple, red',
                'yellow, green, blue, indigo, violet, purple, red, orange',
                'green, blue, indigo, violet, purple, red, orange, yellow',
                'blue, indigo, violet, purple, red, orange, yellow, green',
                'indigo, violet, purple, red, orange, yellow, green, blue',
                'violet, purple, red, orange, yellow, green, blue, indigo',
                'purple, red, orange, yellow, green, blue, indigo, violet',
            ];
            const angle = 45;
            const colorIndex = i % colors.length;
            text.style.background = `linear-gradient(${angle}deg, ${colors[colorIndex]})`;
            text.style.webkitBackgroundClip = 'text';
            text.style.backgroundClip = 'text';
            text.style.color = 'transparent';
            text.style.fontSize = '28px';
            text.style.width = 'auto';
            text.textContent = 'InternVL2';
            text.style.fontWeight = 'bold';
            i += 1;
        }, 200);
    }
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    // console.log(url_params);
    // console.log('hello world...');
    // console.log(window.location.search);
    // console.log('hello world...');
    // alert(window.location.search)
    // alert(url_params);
    return url_params;
}
"""
def build_demo():
    """Assemble the Gradio Blocks UI and register its event handlers.

    Layout: a left column (title HTML, a collapsed "Settings" accordion with
    the system prompt and sampling sliders, example prompts) and a right
    column (chatbot, input row, feedback/regenerate/clear buttons).

    NOTE(review): the nesting below was reconstructed from a copy of the file
    whose indentation was lost -- compare against the deployed layout.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    # Created outside the Blocks context so the examples can reference it
    # before it is rendered into the input row.
    textbox = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image", "video"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )

    example_rows = [
        [{"files": [image_path], "text": prompt}]
        for image_path, prompt in (
            ("gallery/14.jfif", "Please help me analyze this picture."),
            ("gallery/1-2.PNG", "Implement this flow chart using python"),
            ("gallery/15.PNG", "Please help me analyze this picture."),
        )
    ]

    with gr.Blocks(
        title="InternVL-Chat",
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        state = gr.State()

        with gr.Row():
            with gr.Column(scale=2):
                gr.HTML(title_html)

                with gr.Accordion("Settings", open=False):
                    system_prompt = gr.Textbox(
                        value="请尽可能详细地回答用户的问题。",
                        label="System Prompt",
                        interactive=True,
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        interactive=True,
                        label="Top P",
                    )
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=1.5,
                        value=1.1,
                        step=0.02,
                        interactive=True,
                        label="Repetition penalty",
                    )
                    max_output_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=1024,
                        step=64,
                        interactive=True,
                        label="Max output tokens",
                    )
                    max_input_tiles = gr.Slider(
                        minimum=1,
                        maximum=32,
                        value=12,
                        step=1,
                        interactive=True,
                        label="Max input tiles (control the image size)",
                    )

                gr.Examples(examples=example_rows, inputs=[textbox])

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="InternVL2",
                    height=580,
                    show_copy_button=True,
                    show_share_button=True,
                    avatar_images=[
                        "assets/human.png",
                        "assets/assistant.png",
                    ],
                    bubble_full_width=False,
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons"):
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", interactive=False
                    )
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        # Hidden JSON component; nothing in this function reads it.
        gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        vote_outputs = [textbox, upvote_btn, downvote_btn, flag_btn]
        chat_outputs = [state, chatbot, textbox] + btn_list
        generation_inputs = [
            state,
            temperature,
            top_p,
            repetition_penalty,
            max_output_tokens,
            max_input_tiles,
        ]

        def then_generate(event):
            # Chain the streaming reply after a step that prepared the state.
            return event.then(http_bot, list(generation_inputs), list(chat_outputs))

        upvote_btn.click(upvote_last_response, [state], list(vote_outputs))
        downvote_btn.click(downvote_last_response, [state], list(vote_outputs))
        chatbot.like(vote_selected_response, [state], [])
        flag_btn.click(flag_last_response, [state], list(vote_outputs))

        # NOTE(review): `system_prompt` is passed where regenerate() expects
        # `image_process_mode` -- confirm this wiring is intended.
        then_generate(
            regenerate_btn.click(
                regenerate,
                [state, system_prompt],
                list(chat_outputs),
            )
        )

        clear_btn.click(clear_history, None, list(chat_outputs))

        then_generate(
            textbox.submit(
                add_text,
                [state, textbox, system_prompt],
                list(chat_outputs),
            )
        )
        then_generate(
            submit_btn.click(
                add_text,
                [state, textbox, system_prompt],
                list(chat_outputs),
            )
        )

    return demo
if __name__ == "__main__":
    # Command-line entry point: parse options, build the UI, launch the server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    # `args` must stay a module-level name: add_text reads `args.moderate`.
    args = parser.parse_args()
    # Logged once; the second, unlabelled log of the same object was dropped.
    logger.info(f"args: {args}")

    demo = build_demo()
    demo.queue(api_open=False).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=args.concurrency_count,
    )