Spaces:
Runtime error
Runtime error
""" | |
The gradio demo server with multiple tabs. | |
It supports chatting with a single model or chatting with two models side-by-side. | |
""" | |
import argparse | |
from typing import List | |
import gradio as gr | |
from serve.gradio_block_arena_anony import ( | |
build_side_by_side_ui_anony, | |
load_demo_side_by_side_anony, | |
set_global_vars_anony, | |
) | |
from serve.gradio_block_arena_named import ( | |
build_side_by_side_ui_named, | |
load_demo_side_by_side_named, | |
set_global_vars_named, | |
) | |
from serve.gradio_block_arena_vision import ( | |
build_single_vision_language_model_ui, | |
) | |
from serve.gradio_block_arena_vision_anony import ( | |
build_side_by_side_vision_ui_anony, | |
load_demo_side_by_side_vision_anony, | |
) | |
from serve.gradio_block_arena_vision_named import ( | |
build_side_by_side_vision_ui_named, | |
load_demo_side_by_side_vision_named, | |
) | |
from serve.gradio_global_state import Context | |
from serve.gradio_web_server import ( | |
set_global_vars, | |
block_css, | |
build_single_model_ui, | |
get_model_list, | |
load_demo_single, | |
get_ip, | |
) | |
from serve.utils import ( | |
build_logger, | |
get_window_url_params_js, | |
get_window_url_params_with_tos_js, | |
alert_js, | |
parse_gradio_auth_creds, | |
) | |
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") | |
def load_demo(context: Context, request: gr.Request): | |
ip = get_ip(request) | |
logger.info(f"load_demo. ip: {ip}. params: {request.query_params}") | |
inner_selected = 0 | |
if "arena" in request.query_params: | |
inner_selected = 0 | |
elif "vision" in request.query_params: | |
inner_selected = 0 | |
elif "compare" in request.query_params: | |
inner_selected = 1 | |
elif "direct" in request.query_params or "model" in request.query_params: | |
inner_selected = 2 | |
elif "leaderboard" in request.query_params: | |
inner_selected = 3 | |
elif "about" in request.query_params: | |
inner_selected = 4 | |
if args.model_list_mode == "reload": | |
context.text_models, context.all_text_models = get_model_list( | |
args.controller_url, | |
args.register_api_endpoint_file, | |
vision_arena=False, | |
) | |
context.vision_models, context.all_vision_models = get_model_list( | |
args.controller_url, | |
args.register_api_endpoint_file, | |
vision_arena=True, | |
) | |
# Text models | |
if args.vision_arena: | |
side_by_side_anony_updates = load_demo_side_by_side_vision_anony() | |
side_by_side_named_updates = load_demo_side_by_side_vision_named( | |
context, | |
) | |
direct_chat_updates = load_demo_single(context, request.query_params) | |
else: | |
direct_chat_updates = load_demo_single(context, request.query_params) | |
side_by_side_anony_updates = load_demo_side_by_side_anony( | |
context.all_text_models, request.query_params | |
) | |
side_by_side_named_updates = load_demo_side_by_side_named( | |
context.text_models, request.query_params | |
) | |
tabs_list = ( | |
[gr.Tabs(selected=inner_selected)] | |
+ side_by_side_anony_updates | |
+ side_by_side_named_updates | |
+ direct_chat_updates | |
) | |
return tabs_list | |
def build_demo( | |
context: Context, elo_results_file: str, leaderboard_table_file, arena_hard_table | |
): | |
if args.show_terms_of_use: | |
load_js = get_window_url_params_with_tos_js | |
else: | |
load_js = get_window_url_params_js | |
head_js = """ | |
<script src="https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.min.js"></script> | |
""" | |
if args.ga_id is not None: | |
head_js += f""" | |
<script async src="https://www.googletagmanager.com/gtag/js?id={args.ga_id}"></script> | |
<script> | |
window.dataLayer = window.dataLayer || []; | |
function gtag(){{dataLayer.push(arguments);}} | |
gtag('js', new Date()); | |
gtag('config', '{args.ga_id}'); | |
window.__gradio_mode__ = "app"; | |
</script> | |
""" | |
# head_js = """""" | |
text_size = gr.themes.sizes.text_lg | |
with gr.Blocks( | |
title="Chatbot Arena 日本語版α", | |
theme=gr.themes.Default(text_size=text_size), | |
css=block_css, | |
head=head_js, | |
) as demo: | |
with gr.Tabs() as inner_tabs: | |
if args.vision_arena: | |
with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab: | |
arena_tab.select(None, None, None, js=load_js) | |
side_by_side_anony_list = build_side_by_side_vision_ui_anony( | |
context, | |
random_questions=args.random_questions, | |
) | |
with gr.Tab("⚔️ Arena (side-by-side)", id=1) as side_by_side_tab: | |
side_by_side_tab.select(None, None, None, js=alert_js) | |
side_by_side_named_list = build_side_by_side_vision_ui_named( | |
context, random_questions=args.random_questions | |
) | |
with gr.Tab("💬 Direct Chat", id=2) as direct_tab: | |
direct_tab.select(None, None, None, js=alert_js) | |
single_model_list = build_single_vision_language_model_ui( | |
context, | |
add_promotion_links=True, | |
random_questions=args.random_questions, | |
) | |
else: | |
with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab: | |
arena_tab.select(None, None, None, js=load_js) | |
side_by_side_anony_list = build_side_by_side_ui_anony( | |
context.all_text_models | |
) | |
with gr.Tab("⚔️ Arena (side-by-side)", id=1) as side_by_side_tab: | |
side_by_side_tab.select(None, None, None, js=alert_js) | |
side_by_side_named_list = build_side_by_side_ui_named( | |
context.text_models | |
) | |
with gr.Tab("💬 Direct Chat", id=2) as direct_tab: | |
direct_tab.select(None, None, None, js=alert_js) | |
single_model_list = build_single_model_ui( | |
context.text_models, add_promotion_links=True | |
) | |
demo_tabs = ( | |
[inner_tabs] | |
+ side_by_side_anony_list | |
+ side_by_side_named_list | |
+ single_model_list | |
) | |
# if elo_results_file: | |
# with gr.Tab("🏆 Leaderboard", id=3): | |
# build_leaderboard_tab( | |
# elo_results_file, | |
# leaderboard_table_file, | |
# arena_hard_table, | |
# show_plot=True, | |
# ) | |
# with gr.Tab("ℹ️ About Us", id=4): | |
# about = build_about() | |
context_state = gr.State(context) | |
url_params = gr.JSON(visible=False) | |
if args.model_list_mode not in ["once", "reload"]: | |
raise ValueError( | |
f"Unknown model list mode: {args.model_list_mode}") | |
demo.load( | |
load_demo, | |
[context_state], | |
demo_tabs, | |
js=load_js, | |
) | |
return demo | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--host", type=str, default="0.0.0.0") | |
parser.add_argument("--port", type=int) | |
parser.add_argument( | |
"--share", | |
action="store_true", | |
help="Whether to generate a public, shareable link", | |
) | |
parser.add_argument( | |
"--controller-url", | |
type=str, | |
default="http://localhost:21001", | |
help="The address of the controller", | |
) | |
parser.add_argument( | |
"--concurrency-count", | |
type=int, | |
default=10, | |
help="The concurrency count of the gradio queue", | |
) | |
parser.add_argument( | |
"--model-list-mode", | |
type=str, | |
default="once", | |
choices=["once", "reload"], | |
help="Whether to load the model list once or reload the model list every time.", | |
) | |
parser.add_argument( | |
"--moderate", | |
action="store_true", | |
help="Enable content moderation to block unsafe inputs", | |
) | |
parser.add_argument( | |
"--show-terms-of-use", | |
action="store_true", | |
help="Shows term of use before loading the demo", | |
) | |
parser.add_argument( | |
"--vision-arena", action="store_true", help="Show tabs for vision arena." | |
) | |
parser.add_argument( | |
"--random-questions", type=str, help="Load random questions from a JSON file" | |
) | |
parser.add_argument( | |
"--register-api-endpoint-file", | |
type=str, | |
help="Register API-based model endpoints from a JSON file", | |
default="api_endpoints.json", | |
) | |
parser.add_argument( | |
"--gradio-auth-path", | |
type=str, | |
help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', | |
default=None, | |
) | |
parser.add_argument( | |
"--elo-results-file", type=str, help="Load leaderboard results and plots" | |
) | |
parser.add_argument( | |
"--leaderboard-table-file", type=str, help="Load leaderboard results and plots" | |
) | |
parser.add_argument( | |
"--arena-hard-table", type=str, help="Load leaderboard results and plots" | |
) | |
parser.add_argument( | |
"--gradio-root-path", | |
type=str, | |
help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix", | |
) | |
parser.add_argument( | |
"--ga-id", | |
type=str, | |
help="the Google Analytics ID", | |
default=None, | |
) | |
parser.add_argument( | |
"--use-remote-storage", | |
action="store_true", | |
default=False, | |
help="Uploads image files to google cloud storage if set to true", | |
) | |
parser.add_argument( | |
"--password", | |
type=str, | |
help="Set the password for the gradio web server", | |
) | |
args = parser.parse_args() | |
logger.info(f"args: {args}") | |
# Set global variables | |
set_global_vars(args.controller_url, args.moderate, | |
args.use_remote_storage) | |
set_global_vars_named(args.moderate) | |
set_global_vars_anony(args.moderate) | |
text_models, all_text_models = get_model_list( | |
args.controller_url, | |
args.register_api_endpoint_file, | |
vision_arena=False, | |
) | |
vision_models, all_vision_models = get_model_list( | |
args.controller_url, | |
args.register_api_endpoint_file, | |
vision_arena=True, | |
) | |
models = text_models + [ | |
model for model in vision_models if model not in text_models | |
] | |
all_models = all_text_models + [ | |
model for model in all_vision_models if model not in all_text_models | |
] | |
context = Context( | |
text_models, | |
all_text_models, | |
vision_models, | |
all_vision_models, | |
models, | |
all_models, | |
) | |
# Set authorization credentials | |
auth = None | |
if args.gradio_auth_path is not None: | |
auth = parse_gradio_auth_creds(args.gradio_auth_path) | |
# Launch the demo | |
demo = build_demo( | |
context, | |
args.elo_results_file, | |
args.leaderboard_table_file, | |
args.arena_hard_table, | |
) | |
demo.queue( | |
default_concurrency_limit=args.concurrency_count, | |
status_update_rate=10, | |
api_open=False, | |
).launch( | |
server_name=args.host, | |
server_port=args.port, | |
share=args.share, | |
max_threads=200, | |
auth=auth, | |
root_path=args.gradio_root_path, | |
show_api=False, | |
) | |