File size: 4,179 Bytes
0dd5c06
 
 
 
cf196b3
 
73e8b86
cf196b3
 
67812d2
73e8b86
a19f11e
3c495cc
 
a19f11e
000d4f2
cf196b3
 
 
8ee349a
 
 
cf196b3
 
 
 
 
 
 
 
000d4f2
3c495cc
47db0c3
cf196b3
 
3c495cc
 
 
 
 
 
 
 
 
 
 
 
 
47db0c3
3c495cc
cf196b3
 
3c495cc
47db0c3
3c495cc
 
 
73e8b86
 
300b938
cf196b3
3c495cc
 
 
 
cf196b3
 
 
 
 
 
 
 
 
 
 
 
 
 
3c495cc
 
cf196b3
871741c
 
cf196b3
 
3c495cc
 
cf196b3
6b89337
000d4f2
6b89337
 
 
 
73e8b86
000d4f2
 
73e8b86
71d0339
3c495cc
71d0339
cf196b3
 
 
 
 
71d0339
73e8b86
 
 
 
 
3c495cc
 
 
 
 
000d4f2
 
3c495cc
 
000d4f2
 
 
 
73e8b86
a19f11e
 
73e8b86
 
 
6b89337
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
It provides a platform for comparing the responses of two LLMs. 
"""

import enum
from uuid import uuid4

import firebase_admin
from firebase_admin import firestore
import gradio as gr

from leaderboard import build_leaderboard
import response
from response import get_responses

# TODO(#21): Fix auto-reload issue related to the initialization of Firebase.
# Firebase is initialized once, at import time. `initialize_app()` is called
# without explicit credentials or options, so they must come from the runtime
# environment (presumably Application Default Credentials — confirm for the
# deployment). NOTE(review): re-executing this line in the same process likely
# fails because the default app already exists; presumably that is #21.
db_app = firebase_admin.initialize_app()
# Firestore client used by `vote` to store votes and handed to the leaderboard.
db = firestore.client()

# Languages offered by the source/target dropdowns for the translation
# category.
SUPPORTED_TRANSLATION_LANGUAGES = [
    "Korean", "English", "Chinese", "Japanese", "Spanish", "French"
]


@enum.unique
class VoteOptions(enum.Enum):
  """Possible outcomes of a single vote.

  Each value is the text of the corresponding vote button, so a clicked
  button's label can be turned back into a member with `VoteOptions(label)`.
  """

  MODEL_A = "Model A is better"  # Left-hand response preferred.
  MODEL_B = "Model B is better"  # Right-hand response preferred.
  TIE = "Tie"  # Neither response preferred.


def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
         user_prompt, instruction, category, source_lang, target_lang):
  """Stores one user vote in the Firestore collection for its category.

  Args:
    vote_button: Label of the clicked vote button; must be a VoteOptions value.
    response_a: Response text shown as "Model A".
    response_b: Response text shown as "Model B".
    model_a_name: Name of the model that produced `response_a`.
    model_b_name: Name of the model that produced `response_b`.
    user_prompt: Prompt the user submitted.
    instruction: Instruction that was sent to the models with the prompt.
    category: Value of the selected `response.Category`; picks the collection.
    source_lang: Source language; only stored for translation votes.
    target_lang: Target language; only stored for translation votes.

  Raises:
    gr.Error: If no responses have been generated yet, the category is unset
      or unsupported, or a translation vote is missing a language. Previously
      these cases were silently dropped, stored as junk, or crashed.
    ValueError: If `vote_button` is not a VoteOptions value.
  """
  # Without model names the vote cannot be attributed, so it would only add
  # an unusable record (the buttons are shown before any submission, TODO #5).
  if not (model_a_name and model_b_name):
    raise gr.Error("Please submit a prompt before voting.")

  collections_by_category = {
      response.Category.SUMMARIZE.value: "arena-summarizations",
      response.Category.TRANSLATE.value: "arena-translations",
  }
  collection = collections_by_category.get(category)
  if collection is None:
    # Tell the user instead of silently discarding the vote (TODO #6).
    raise gr.Error("Please select a category before voting.")

  doc_id = uuid4().hex
  winner = VoteOptions(vote_button).name.lower()

  doc = {
      "id": doc_id,
      "prompt": user_prompt,
      "instruction": instruction,
      "model_a": model_a_name,
      "model_b": model_b_name,
      "model_a_response": response_a,
      "model_b_response": response_b,
      "winner": winner,
      "timestamp": firestore.SERVER_TIMESTAMP
  }

  if category == response.Category.TRANSLATE.value:
    # An unselected dropdown yields None, which would crash on `.lower()`.
    if not (source_lang and target_lang):
      raise gr.Error(
          "Please select the source and target languages before voting.")
    doc["source_lang"] = source_lang.lower()
    doc["target_lang"] = target_lang.lower()

  db.collection(collection).document(doc_id).set(doc)


with gr.Blocks(title="Arena") as app:
  with gr.Row():
    category_radio = gr.Radio(
        [category.value for category in response.Category],
        label="Category",
        info="The chosen category determines the instruction sent to the LLMs.")

    # The language pickers only matter for translation, so they start hidden
    # and are toggled by `update_language_visibility` below.
    source_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Source language",
        info="Choose the source language for translation.",
        interactive=True,
        visible=False)
    target_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Target language",
        info="Choose the target language for translation.",
        interactive=True,
        visible=False)

    def update_language_visibility(category):
      """Shows both language dropdowns only for the translation category."""
      visible = category == response.Category.TRANSLATE.value
      return {
          source_language: gr.Dropdown(visible=visible),
          target_language: gr.Dropdown(visible=visible)
      }

    category_radio.change(update_language_visibility, category_radio,
                          [source_language, target_language])

  prompt = gr.TextArea(label="Prompt", lines=4)
  submit = gr.Button()

  with gr.Row():
    response_boxes = [
        gr.Textbox(label="Model A", interactive=False),
        gr.Textbox(label="Model B", interactive=False),
    ]

  # TODO(#5): Display it only after the user submits the prompt.
  # TODO(#6): Block voting if the category is not set.
  # TODO(#6): Block voting if the user already voted.
  with gr.Row():
    # Labels come from VoteOptions so `vote` can always map them back.
    option_a = gr.Button(VoteOptions.MODEL_A.value)
    option_b = gr.Button(VoteOptions.MODEL_B.value)
    tie = gr.Button(VoteOptions.TIE.value)

  # TODO(#7): Hide it until the user votes.
  with gr.Accordion("Show models", open=False):
    with gr.Row():
      model_names = [
          gr.Textbox(label="Model A", interactive=False),
          gr.Textbox(label="Model B", interactive=False),
      ]

  # Holds the instruction `get_responses` used, so it can be stored with a vote.
  instruction_state = gr.State("")

  submit.click(get_responses,
               [prompt, category_radio, source_language, target_language],
               response_boxes + model_names + [instruction_state])

  # NOTE(review): category and languages are read at vote time, not at submit
  # time, so changing them after submitting changes what the vote records.
  common_inputs = response_boxes + model_names + [
      prompt, instruction_state, category_radio, source_language,
      target_language
  ]
  # Each button passes its own label first so `vote` can identify the winner.
  for vote_button in (option_a, option_b, tie):
    vote_button.click(vote, [vote_button] + common_inputs)

  build_leaderboard(db)

if __name__ == "__main__":
  # Queueing must be turned on for generator event handlers to work;
  # `queue()` hands back the app, so launching is chained onto it.
  app.queue().launch(debug=True)