Spaces:
Running
Running
File size: 4,179 Bytes
0dd5c06 cf196b3 73e8b86 cf196b3 67812d2 73e8b86 a19f11e 3c495cc a19f11e 000d4f2 cf196b3 8ee349a cf196b3 000d4f2 3c495cc 47db0c3 cf196b3 3c495cc 47db0c3 3c495cc cf196b3 3c495cc 47db0c3 3c495cc 73e8b86 300b938 cf196b3 3c495cc cf196b3 3c495cc cf196b3 871741c cf196b3 3c495cc cf196b3 6b89337 000d4f2 6b89337 73e8b86 000d4f2 73e8b86 71d0339 3c495cc 71d0339 cf196b3 71d0339 73e8b86 3c495cc 000d4f2 3c495cc 000d4f2 73e8b86 a19f11e 73e8b86 6b89337 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
"""
This module provides a platform for comparing the responses of two LLMs.
"""
import enum
from uuid import uuid4
import firebase_admin
from firebase_admin import firestore
import gradio as gr
from leaderboard import build_leaderboard
import response
from response import get_responses
# firebase_admin.initialize_app() raises ValueError when the default app has
# already been initialized, which happens if this module is executed again in
# the same process. Reuse the existing default app when there is one.
# TODO(#21): Confirm this fully resolves the auto-reload issue related to the
# initialization of Firebase.
try:
    db_app = firebase_admin.get_app()
except ValueError:
    # No default app exists yet, so create it.
    db_app = firebase_admin.initialize_app()

# Firestore client used to store votes; it is also passed to the leaderboard.
db = firestore.client()
# Languages offered in both the source- and target-language dropdowns.
SUPPORTED_TRANSLATION_LANGUAGES = [
    "Korean",
    "English",
    "Chinese",
    "Japanese",
    "Spanish",
    "French",
]
class VoteOptions(enum.Enum):
    """Choices a user can vote for.

    Each member's value is the human-readable label of that choice; the
    lower-cased member name is what gets stored as the "winner" of a vote.
    """

    # The response labelled "Model A" was preferred.
    MODEL_A = "Model A is better"
    # The response labelled "Model B" was preferred.
    MODEL_B = "Model B is better"
    # Neither response was preferred.
    TIE = "Tie"
def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
         user_prompt, instruction, category, source_lang, target_lang):
    """Stores a user's vote for a pair of model responses in Firestore.

    Summarization votes are written to the "arena-summarizations" collection
    and translation votes to "arena-translations". Translation votes also
    record the lower-cased source and target languages.

    Args:
        vote_button: Label of the clicked button; must be a VoteOptions value.
        response_a: Response text shown as "Model A".
        response_b: Response text shown as "Model B".
        model_a_name: Name of the model behind response A.
        model_b_name: Name of the model behind response B.
        user_prompt: Prompt entered by the user.
        instruction: Instruction that accompanied the prompt.
        category: Selected category; compared against response.Category values.
        source_lang: Source language; required for translation votes.
        target_lang: Target language; required for translation votes.

    Raises:
        ValueError: If vote_button is not the value of any VoteOptions member.
        gr.Error: If the category is not one handled here, or if a translation
            vote is missing its source or target language.
    """
    winner = VoteOptions(vote_button).name.lower()

    is_summarize = category == response.Category.SUMMARIZE.value
    is_translate = category == response.Category.TRANSLATE.value

    # Previously an unset or unknown category made the vote vanish silently.
    # Tell the user instead of dropping their input.
    if not (is_summarize or is_translate):
        raise gr.Error("Please select a supported category before voting.")

    # Validate before building the document so that a missing language
    # produces a readable message rather than an AttributeError on None.lower().
    if is_translate and not (source_lang and target_lang):
        raise gr.Error(
            "Please select both the source and target languages before voting.")

    doc_id = uuid4().hex
    doc = {
        "id": doc_id,
        "prompt": user_prompt,
        "instruction": instruction,
        "model_a": model_a_name,
        "model_b": model_b_name,
        "model_a_response": response_a,
        "model_b_response": response_b,
        "winner": winner,
        "timestamp": firestore.SERVER_TIMESTAMP
    }

    if is_translate:
        collection_name = "arena-translations"
        doc["source_lang"] = source_lang.lower()
        doc["target_lang"] = target_lang.lower()
    else:
        collection_name = "arena-summarizations"

    db.collection(collection_name).document(doc_id).set(doc)
# Builds the Arena UI: category and language selectors, the prompt box, the
# two response boxes, the voting buttons, the model-name reveal and the
# leaderboard.
with gr.Blocks(title="Arena") as app:
    with gr.Row():
        category_radio = gr.Radio(
            [category.value for category in response.Category],
            label="Category",
            info="The chosen category determines the instruction sent to the LLMs.")

        # The language dropdowns start hidden; they are shown only while the
        # translation category is selected (see update_language_visibility).
        source_language = gr.Dropdown(
            choices=SUPPORTED_TRANSLATION_LANGUAGES,
            label="Source language",
            info="Choose the source language for translation.",
            interactive=True,
            visible=False)
        target_language = gr.Dropdown(
            choices=SUPPORTED_TRANSLATION_LANGUAGES,
            label="Target language",
            info="Choose the target language for translation.",
            interactive=True,
            visible=False)

    def update_language_visibility(category):
        """Shows the language dropdowns only for the translation category.

        Args:
            category: The currently selected value of the category radio.

        Returns:
            A dict mapping each language dropdown to its visibility update.
        """
        visible = category == response.Category.TRANSLATE.value
        return {
            source_language: gr.Dropdown(visible=visible),
            target_language: gr.Dropdown(visible=visible)
        }

    category_radio.change(update_language_visibility, category_radio,
                          [source_language, target_language])

    # Placeholders; each entry is replaced by a Textbox created in the rows
    # below.
    model_names = [gr.State(None), gr.State(None)]
    response_boxes = [gr.State(None), gr.State(None)]

    prompt = gr.TextArea(label="Prompt", lines=4)
    submit = gr.Button()

    with gr.Row():
        response_boxes[0] = gr.Textbox(label="Model A", interactive=False)
        response_boxes[1] = gr.Textbox(label="Model B", interactive=False)

    # TODO(#5): Display it only after the user submits the prompt.
    # TODO(#6): Block voting if the category is not set.
    # TODO(#6): Block voting if the user already voted.
    with gr.Row():
        # vote() parses the button label with VoteOptions(...), so every label
        # is taken from the enum to keep the two from drifting apart.
        option_a = gr.Button(VoteOptions.MODEL_A.value)
        option_b = gr.Button(VoteOptions.MODEL_B.value)
        tie = gr.Button(VoteOptions.TIE.value)

    # TODO(#7): Hide it until the user votes.
    with gr.Accordion("Show models", open=False):
        with gr.Row():
            model_names[0] = gr.Textbox(label="Model A", interactive=False)
            model_names[1] = gr.Textbox(label="Model B", interactive=False)

    # Holds the instruction produced by get_responses so that it can be passed
    # to vote() later.
    instruction_state = gr.State("")

    submit.click(get_responses,
                 [prompt, category_radio, source_language, target_language],
                 response_boxes + model_names + [instruction_state])

    # Inputs shared by all three vote buttons, in the order of vote()'s
    # parameters after vote_button.
    common_inputs = response_boxes + model_names + [
        prompt, instruction_state, category_radio, source_language,
        target_language
    ]
    # Each button is also passed as the first input, which vote() receives as
    # vote_button.
    option_a.click(vote, [option_a] + common_inputs)
    option_b.click(vote, [option_b] + common_inputs)
    tie.click(vote, [tie] + common_inputs)

    build_leaderboard(db)
if __name__ == "__main__":
    # Generators require the queue, so enable it before starting the server.
    app.queue().launch(debug=True)
|