File size: 5,139 Bytes
0dd5c06
 
 
cf196b3
 
73e8b86
cf196b3
67812d2
73e8b86
a19f11e
076f69b
3c495cc
 
a19f11e
8ee349a
 
 
cf196b3
 
 
 
 
 
 
 
000d4f2
3c495cc
47db0c3
cf196b3
 
a089fa0
f00b7ff
a089fa0
3c495cc
 
 
 
 
 
 
 
 
 
 
 
 
47db0c3
3c495cc
a089fa0
f00b7ff
cf196b3
3c495cc
a089fa0
 
 
47db0c3
a089fa0
 
3c495cc
73e8b86
f00b7ff
a089fa0
 
 
73e8b86
1505f99
 
 
 
 
 
 
 
 
300b938
cf196b3
3c495cc
 
 
 
cf196b3
 
 
 
 
 
 
 
 
 
 
 
 
 
3c495cc
 
cf196b3
871741c
 
cf196b3
 
3c495cc
 
cf196b3
6b89337
000d4f2
6b89337
 
 
 
f00b7ff
 
1505f99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f00b7ff
 
 
 
73e8b86
4b31650
cf196b3
a089fa0
 
cf196b3
a089fa0
3c495cc
 
a089fa0
 
f00b7ff
4b31650
000d4f2
 
3c495cc
 
000d4f2
f00b7ff
 
 
 
73e8b86
076f69b
a19f11e
73e8b86
 
 
6b89337
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
Gradio arena for comparing the responses of two LLMs side by side and voting on the better one.
"""
import enum
from uuid import uuid4

from firebase_admin import firestore
import gradio as gr

from leaderboard import build_leaderboard
from leaderboard import db
import response
from response import get_responses

# Display names offered in the source/target language dropdowns for the
# translation category; a vote stores the chosen names lower-cased.
SUPPORTED_TRANSLATION_LANGUAGES = [
    "Korean",
    "English",
    "Chinese",
    "Japanese",
    "Spanish",
    "French",
]


class VoteOptions(enum.Enum):
  """The choices a user can make after comparing two responses.

  Each member's value doubles as the label of its vote button. `vote` maps
  a clicked button's label back to a member and records the lower-cased
  member name ("model_a", "model_b" or "tie") as the winner.
  """

  # The response shown under "Model A" was preferred.
  MODEL_A = "Model A is better"
  # The response shown under "Model B" was preferred.
  MODEL_B = "Model B is better"
  # No preference between the two responses.
  TIE = "Tie"


def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
         user_prompt, instruction, category, source_lang, target_lang):
  """Persists a user's vote between two model responses to Firestore.

  The document goes to "arena-summarizations" or "arena-translations"
  depending on `category`. Translation votes also carry the lower-cased
  source and target language names.

  Args:
    vote_button: Label of the clicked button; must be a `VoteOptions` value.
    response_a: Text of the response shown as "Model A".
    response_b: Text of the response shown as "Model B".
    model_a_name: Name of the model behind response A.
    model_b_name: Name of the model behind response B.
    user_prompt: The prompt the user submitted.
    instruction: Instruction text associated with the chosen category.
    category: Selected category value, compared against `response.Category`.
    source_lang: Source language display name; required for translations.
    target_lang: Target language display name; required for translations.

  Returns:
    A list of Gradio updates: three non-interactive buttons followed by a
    visible row.

  Raises:
    ValueError: If `vote_button` is not a `VoteOptions` value.
    gr.Error: If a translation vote lacks a source or target language, or
      if the category is neither summarize nor translate.
  """
  record_id = uuid4().hex
  winner = VoteOptions(vote_button).name.lower()

  # Lock all three vote buttons and reveal the row wired as the last output.
  ui_updates = [gr.Button(interactive=False) for _ in range(3)]
  ui_updates.append(gr.Row(visible=True))

  record = {
      "id": record_id,
      "prompt": user_prompt,
      "instruction": instruction,
      "model_a": model_a_name,
      "model_b": model_b_name,
      "model_a_response": response_a,
      "model_b_response": response_b,
      "winner": winner,
      # Sentinel: Firestore substitutes the server's write time.
      "timestamp": firestore.SERVER_TIMESTAMP
  }

  if category == response.Category.SUMMARIZE.value:
    collection_name = "arena-summarizations"
  elif category == response.Category.TRANSLATE.value:
    if not (source_lang and target_lang):
      raise gr.Error("Please select source and target languages.")
    collection_name = "arena-translations"
    record["source_language"] = source_lang.lower()
    record["target_language"] = target_lang.lower()
  else:
    raise gr.Error("Please select a response type.")

  db.collection(collection_name).document(record_id).set(record)
  return ui_updates


def scroll_to_bottom_js(elem_id):
  """Builds a JS snippet that scrolls a Gradio textbox to its last line.

  The returned source is an arrow function suitable for the `js=` argument
  of a Gradio event listener. It finds the `<textarea>` inside the element
  whose HTML id is `elem_id` and moves its scroll position to the bottom,
  e.g. to keep the newest output visible when the textbox changes.

  Args:
    elem_id: HTML id given to the component via its `elem_id=` argument.

  Returns:
    The JavaScript source as a string.
  """
  # The null guard keeps the browser from throwing an uncaught TypeError
  # when no matching textarea is currently in the DOM.
  return f"""
  () => {{
    const element = document.querySelector("#{elem_id} textarea");
    if (!element) return;
    element.scrollTop = element.scrollHeight;
  }}
  """


# Arena UI: category/language selectors, a prompt box, two side-by-side
# response panes, hidden model names, vote buttons and the leaderboard.
with gr.Blocks(title="Arena") as app:
  with gr.Row():
    # Category choices come straight from the `response.Category` values.
    category_radio = gr.Radio(
        [category.value for category in response.Category],
        label="Category",
        info="The chosen category determines the instruction sent to the LLMs.")

    # Language pickers start hidden; `update_language_visibility` shows them
    # only while the translate category is selected.
    source_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Source language",
        info="Choose the source language for translation.",
        interactive=True,
        visible=False)
    target_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Target language",
        info="Choose the target language for translation.",
        interactive=True,
        visible=False)

    def update_language_visibility(category):
      """Shows both language dropdowns only for the translate category.

      Args:
        category: The currently selected value of `category_radio`.

      Returns:
        A dict mapping each language dropdown to a visibility update.
      """
      visible = category == response.Category.TRANSLATE.value
      return {
          source_language: gr.Dropdown(visible=visible),
          target_language: gr.Dropdown(visible=visible)
      }

    category_radio.change(update_language_visibility, category_radio,
                          [source_language, target_language])

  # Placeholders: both lists are overwritten with Textbox components below.
  # NOTE(review): these gr.State objects are created inside the Blocks but
  # never wired to any event; `[None, None]` would likely suffice.
  model_names = [gr.State(None), gr.State(None)]
  response_boxes = [gr.State(None), gr.State(None)]

  prompt = gr.TextArea(label="Prompt", lines=4)
  # NOTE(review): no label is passed, so the button shows Gradio's default
  # text; consider an explicit label such as "Submit".
  submit = gr.Button()

  with gr.Group():
    with gr.Row():
      # Read-only panes for the two responses. Each keeps its textarea
      # scrolled to the bottom whenever its content changes.
      response_a_elem_id = "responseA"
      response_a_textbox = gr.Textbox(label="Model A",
                                      interactive=False,
                                      elem_id=response_a_elem_id)
      response_a_textbox.change(fn=None,
                                js=scroll_to_bottom_js(response_a_elem_id))
      response_boxes[0] = response_a_textbox

      response_b_elem_id = "responseB"
      response_b_textbox = gr.Textbox(label="Model B",
                                      interactive=False,
                                      elem_id=response_b_elem_id)
      response_b_textbox.change(fn=None,
                                js=scroll_to_bottom_js(response_b_elem_id))
      response_boxes[1] = response_b_textbox

    # Model names start hidden; `vote` returns a visible gr.Row update that
    # is routed to this row (see `common_outputs` below).
    # NOTE(review): these Textboxes are also event inputs, so Gradio likely
    # renders them editable; consider interactive=False.
    with gr.Row(visible=False) as model_name_row:
      model_names[0] = gr.Textbox(show_label=False)
      model_names[1] = gr.Textbox(show_label=False)

  # Vote buttons start hidden; presumably `get_responses` reveals this row,
  # since it is one of that handler's outputs. Each label is a VoteOptions
  # value.
  with gr.Row(visible=False) as vote_row:
    option_a = gr.Button(VoteOptions.MODEL_A.value)
    option_b = gr.Button(VoteOptions.MODEL_B.value)
    tie = gr.Button(VoteOptions.TIE.value)

  vote_buttons = [option_a, option_b, tie]
  # Instruction text set by `get_responses` and passed through to `vote`.
  instruction_state = gr.State("")

  # Assumes `get_responses` (imported from `response`) produces updates in
  # exactly this output order: 2 responses, 2 model names, 3 vote buttons,
  # the instruction, the model-name row and the vote row.
  submit.click(
      get_responses, [prompt, category_radio, source_language, target_language],
      response_boxes + model_names + vote_buttons +
      [instruction_state, model_name_row, vote_row])

  # Order must match `vote`'s parameters after `vote_button`: response_a,
  # response_b, model_a_name, model_b_name, user_prompt, instruction,
  # category, source_lang, target_lang.
  common_inputs = response_boxes + model_names + [
      prompt, instruction_state, category_radio, source_language,
      target_language
  ]
  # Matches `vote`'s return value: three button updates, then a row update.
  common_outputs = vote_buttons + [model_name_row]
  # Each button passes itself as the first input, so `vote` receives the
  # clicked button's label.
  option_a.click(vote, [option_a] + common_inputs, common_outputs)
  option_b.click(vote, [option_b] + common_inputs, common_outputs)
  tie.click(vote, [tie] + common_inputs, common_outputs)

  # Presumably adds the leaderboard UI (defined in the `leaderboard` module).
  build_leaderboard()

if __name__ == "__main__":
  # Generator-based handlers require the request queue, so enable it before
  # launching the server in debug mode.
  app.queue().launch(debug=True)