arena / app.py
suhyun.kang
update TODO issues
71d0339
raw
history blame
6.83 kB
"""
Provides a platform for comparing the responses of two LLMs.
"""
import enum
import json
from random import sample
from uuid import uuid4
from fastchat.serve import gradio_web_server
from fastchat.serve.gradio_web_server import bot_response
import firebase_admin
from firebase_admin import firestore
import gradio as gr
# Initializes the default Firebase app. No arguments are passed, so the
# credentials presumably come from the environment (Application Default
# Credentials) — NOTE(review): confirm the deployment provides them.
db_app = firebase_admin.initialize_app()
# Firestore client that `vote` writes the voting results to.
db = firestore.client()
# TODO(#1): Add more models.
# Model names handed to FastChat; `user` samples two distinct ones per prompt.
SUPPORTED_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gemini-pro"]
# TODO(#4): Add more languages.
# Choices offered by the source/target language dropdowns for translation.
SUPPORTED_TRANSLATION_LANGUAGES = ["Korean", "English"]
@enum.unique
class ResponseType(enum.Enum):
  """The kind of task both models are asked to perform.

  Member values double as the labels of the "Response type" radio buttons.
  """
  SUMMARIZE = "Summarize"
  TRANSLATE = "Translate"
@enum.unique
class VoteOptions(enum.Enum):
  """The choices a user can make after comparing the two responses.

  Member values are the vote button labels; the lower-cased member name
  (e.g. "model_a") is what `vote` stores as the winner.
  """
  MODEL_A = "Model A is better"
  MODEL_B = "Model B is better"
  TIE = "Tie"
def vote(state_a, state_b, vote_button, res_type, source_lang, target_lang):
  """Records the user's vote between the two model responses in Firestore.

  Summarization votes are written to the "arena-summarizations" collection
  and translation votes to "arena-translations"; translation votes also store
  the lower-cased source and target languages. If no response type is
  selected, nothing is recorded.

  Args:
    state_a: FastChat conversation state for model A, or None if no prompt
      has been submitted yet.
    state_b: FastChat conversation state for model B, or None if no prompt
      has been submitted yet.
    vote_button: Label of the clicked vote button; must be a VoteOptions
      value.
    res_type: The selected ResponseType value, or None if none is selected.
    source_lang: The selected source language; required for translation.
    target_lang: The selected target language; required for translation.

  Raises:
    gr.Error: If no prompt has been submitted yet, or if a translation vote
      is cast without both languages selected.
  """
  # The conversation states start as gr.State(None) until a prompt is
  # submitted; calling .dict() on None would crash with an opaque error.
  if state_a is None or state_b is None:
    raise gr.Error("Please submit a prompt before voting.")

  if res_type == ResponseType.SUMMARIZE.value:
    collection = "arena-summarizations"
    extra_fields = {}
  elif res_type == ResponseType.TRANSLATE.value:
    # The language dropdowns may be left unselected; calling .lower() on None
    # would crash, so report the problem to the user instead.
    if not source_lang or not target_lang:
      raise gr.Error("Please select both the source and target languages.")
    collection = "arena-translations"
    extra_fields = {
        "source_language": source_lang.lower(),
        "target_language": target_lang.lower(),
    }
  else:
    # No response type selected: there is nothing meaningful to record.
    return

  doc_id = uuid4().hex
  db.collection(collection).document(doc_id).set({
      "id": doc_id,
      "model_a": state_a.model_name,
      "model_b": state_b.model_name,
      # The 'messages' field in the state is an array of arrays, which is
      # not supported by Firestore. Therefore, we convert it to a JSON string.
      "model_a_conv": json.dumps(state_a.dict()),
      "model_b_conv": json.dumps(state_b.dict()),
      **extra_fields,
      "winner": VoteOptions(vote_button).name.lower(),
      "timestamp": firestore.SERVER_TIMESTAMP,
  })
def user(user_prompt):
  """Starts fresh conversations for two distinct, randomly chosen models.

  Each conversation receives the prompt as the user's turn, followed by an
  empty (None) turn for the model; `bot` later reads the model's reply from
  that last message.

  Args:
    user_prompt: The prompt text entered by the user.

  Returns:
    A list [state_a, state_b, model_name_a, model_name_b]: the two FastChat
    states followed by their model names.
  """
  model_a_name, model_b_name = sample(SUPPORTED_MODELS, 2)
  fresh_states = [
      gradio_web_server.State(model_a_name),
      gradio_web_server.State(model_b_name),
  ]

  for fresh_state in fresh_states:
    conversation = fresh_state.conv
    conversation.append_message(conversation.roles[0], user_prompt)
    conversation.append_message(conversation.roles[1], None)
    fresh_state.skip_next = False

  return fresh_states + [s.model_name for s in fresh_states]
def bot(state_a, state_b, request: gr.Request):
  """Streams the responses of both models, interleaving their progress.

  Args:
    state_a: FastChat conversation state for model A, as returned by `user`.
    state_b: FastChat conversation state for model B, as returned by `user`.
    request: The Gradio request, forwarded to FastChat's `bot_response`.

  Yields:
    Lists [state_a, state_b, response_a, response_b] after each round in
    which both generators are advanced once. A response stays None until its
    model has produced output. The final list is yielded once more after
    both generators are exhausted.

  Raises:
    Exception: Any error raised while creating or advancing a generator is
      printed and re-raised unchanged.
  """
  new_states = [state_a, state_b]
  generators = []
  for state in new_states:
    try:
      # TODO(#1): Allow user to set configuration.
      # bot_response returns a generator yielding states.
      generator = bot_response(state,
                               temperature=0.9,
                               top_p=0.9,
                               max_new_tokens=100,
                               request=request)
    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
      print(f"Error in bot_response: {e}")
      # A bare raise re-raises with the original traceback untouched.
      raise
    generators.append(generator)

  new_responses = [None, None]
  # It simulates concurrent response generation from two models by advancing
  # each generator one step per round.
  while True:
    stop = True
    for i, generator in enumerate(generators):
      try:
        yielded = next(generator)
      except StopIteration:
        # This model is done; keep its last state and response.
        continue
      # TODO(#1): Narrow down the exception type.
      except Exception as e:  # pylint: disable=broad-except
        print(f"Error in generator: {e}")
        raise

      # The generator yields a tuple, with the new state as the first item.
      new_state = yielded[0]
      new_states[i] = new_state
      # The last item from 'messages' represents the response to the prompt.
      # Each message in conv.messages is structured as [role, message], so we
      # extract the last message component.
      new_responses[i] = new_state.conv.messages[-1][-1]
      stop = False

    yield new_states + new_responses

    if stop:
      break
with gr.Blocks() as app:
  with gr.Row():
    response_type_radio = gr.Radio(
        [response_type.value for response_type in ResponseType],
        label="Response type",
        info="Choose the type of response you want from the model.")

    # The language selectors only apply to translation, so they start hidden
    # and are toggled by `update_language_visibility`.
    source_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Source language",
        info="Choose the source language for translation.",
        interactive=True,
        visible=False)
    target_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Target language",
        info="Choose the target language for translation.",
        interactive=True,
        visible=False)

    def update_language_visibility(response_type):
      """Shows the language dropdowns only when translation is selected."""
      visible = response_type == ResponseType.TRANSLATE.value
      return {
          source_language: gr.Dropdown(visible=visible),
          target_language: gr.Dropdown(visible=visible)
      }

    response_type_radio.change(update_language_visibility, response_type_radio,
                               [source_language, target_language])

  # states stores FastChat-specific conversation states.
  states = [gr.State(None), gr.State(None)]

  prompt = gr.TextArea(label="Prompt", lines=4)
  submit = gr.Button()

  with gr.Row():
    responses = [
        gr.Textbox(label="Model A", interactive=False),
        gr.Textbox(label="Model B", interactive=False)
    ]

  # TODO(#5): Display it only after the user submits the prompt.
  # TODO(#6): Block voting if the response_type is not set.
  # TODO(#6): Block voting if the user already voted.
  with gr.Row():
    # Inputs every vote handler receives after the clicked button itself.
    vote_inputs = [response_type_radio, source_language, target_language]
    # Button labels must be VoteOptions values: `vote` maps the clicked
    # button's label back to a VoteOptions member.
    option_a = gr.Button(VoteOptions.MODEL_A.value)
    option_a.click(vote, states + [option_a] + vote_inputs)
    option_b = gr.Button(VoteOptions.MODEL_B.value)
    option_b.click(vote, states + [option_b] + vote_inputs)
    tie = gr.Button(VoteOptions.TIE.value)
    tie.click(vote, states + [tie] + vote_inputs)

  # TODO(#7): Hide it until the user votes.
  with gr.Accordion("Show models", open=False):
    with gr.Row():
      model_names = [
          gr.Textbox(label="Model A", interactive=False),
          gr.Textbox(label="Model B", interactive=False)
      ]

  # `user` creates both conversations and fills in the model names; `bot`
  # then streams both models' replies into the response boxes.
  submit.click(user, prompt, states + model_names,
               queue=False).then(bot, states, states + responses)
if __name__ == "__main__":
  # Generator handlers such as `bot` only work with the queue enabled;
  # queue() returns the app itself, so launch() can be chained onto it.
  app.queue().launch(debug=True)