File size: 5,958 Bytes
0dd5c06
 
 
 
cf196b3
73e8b86
cf196b3
73e8b86
cf196b3
 
67812d2
000d4f2
73e8b86
a19f11e
 
000d4f2
cf196b3
 
 
73e8b86
000d4f2
 
 
73e8b86
71d0339
cf196b3
 
 
 
 
 
 
 
 
 
 
 
 
 
000d4f2
 
47db0c3
cf196b3
 
 
47db0c3
cf196b3
47db0c3
000d4f2
 
 
 
 
cf196b3
 
 
 
 
 
47db0c3
cf196b3
47db0c3
000d4f2
 
 
 
 
cf196b3
 
 
 
 
 
73e8b86
000d4f2
 
 
 
 
73e8b86
000d4f2
 
 
73e8b86
 
000d4f2
 
6b89337
73e8b86
000d4f2
73e8b86
 
000d4f2
 
 
 
 
 
 
73e8b86
 
 
 
 
 
000d4f2
6b89337
 
73e8b86
 
 
44ad98f
73e8b86
c73f9e9
 
000d4f2
 
c73f9e9
000d4f2
73e8b86
c73f9e9
000d4f2
 
73e8b86
 
 
2e9ad55
 
 
 
 
73e8b86
 
 
 
 
cf196b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
871741c
cf196b3
871741c
 
cf196b3
 
 
 
 
6b89337
000d4f2
6b89337
 
 
 
73e8b86
000d4f2
 
73e8b86
71d0339
 
 
cf196b3
 
 
 
 
71d0339
73e8b86
 
 
 
 
000d4f2
 
 
 
 
 
 
 
73e8b86
a19f11e
 
73e8b86
 
 
6b89337
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
Provides a Gradio app for comparing the responses of two LLMs side by side.
"""

import enum
from random import sample
from typing import Any, Iterable, Iterator
from uuid import uuid4

import firebase_admin
from firebase_admin import firestore
import gradio as gr
from litellm import completion

from leaderboard import build_leaderboard

# Module-level Firestore client: vote() writes vote documents through `db`, and
# `db` is also handed to build_leaderboard() when the UI is built.
# TODO(#21): Fix auto-reload issue related to the initialization of Firebase.
# NOTE(review): initialize_app() is called without credentials or options, so
# it presumably relies on the environment's default credentials — confirm the
# deployment config. It presumably also refuses to create the default app a
# second time in one process, which is likely the auto-reload issue above —
# verify. `db_app` is not referenced elsewhere in the visible code.
db_app = firebase_admin.initialize_app()
# Must come after initialize_app(): client() is called without an app argument,
# so it presumably binds to the default app created on the previous line.
db = firestore.client()

# Model identifiers passed to litellm's `completion`; get_responses() samples
# two distinct entries from this list for every prompt.
# TODO(#1): Add more models.
SUPPORTED_MODELS = [
    "gpt-4",
    "gpt-4-0125-preview",
    "gpt-3.5-turbo",
    "gemini-pro",
]

# Choices offered in both the source- and target-language dropdowns; vote()
# stores the selected entries lowercased.
# TODO(#4): Add more languages.
SUPPORTED_TRANSLATION_LANGUAGES = [
    "Korean",
    "English",
]


class ResponseType(enum.Enum):
  """Kind of response the user asks both models for.

  Each value is the label shown in the "Response type" radio button, and
  vote() compares the submitted label against these values to decide which
  Firestore collection receives the vote.
  """

  # Votes are stored in the "arena-summarizations" collection.
  SUMMARIZE = "Summarize"
  # Votes are stored in "arena-translations" together with the languages.
  TRANSLATE = "Translate"


class VoteOptions(enum.Enum):
  """Outcomes a user can choose when comparing the two responses.

  The values are the vote button labels; vote() maps the clicked label back
  to a member and records the lowercased member name (for example "model_a")
  as the winner.
  """

  # Recorded as winner "model_a".
  MODEL_A = "Model A is better"
  # Recorded as winner "model_b".
  MODEL_B = "Model B is better"
  # Recorded as winner "tie".
  TIE = "Tie"


def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
         user_prompt, res_type, source_lang, target_lang):
  """Records a user's comparison of two model responses in Firestore.

  Args:
    vote_button: Label of the clicked vote button; must be one of the
      `VoteOptions` values.
    response_a: Response text shown as "Model A".
    response_b: Response text shown as "Model B".
    model_a_name: Name of the model that produced `response_a`.
    model_b_name: Name of the model that produced `response_b`.
    user_prompt: Prompt the user submitted.
    res_type: Selected `ResponseType` value. If it matches neither known
      type (e.g. no response type was chosen), nothing is recorded.
    source_lang: Source language label; required for translation votes.
    target_lang: Target language label; required for translation votes.

  Raises:
    ValueError: If `vote_button` is not a `VoteOptions` value.
    gr.Error: If a translation vote is cast without both languages selected.
  """
  winner = VoteOptions(vote_button).name.lower()

  if res_type == ResponseType.SUMMARIZE.value:
    collection = "arena-summarizations"
    extra_fields = {}
  elif res_type == ResponseType.TRANSLATE.value:
    # An unselected dropdown may deliver None; calling .lower() on it would
    # crash with an AttributeError, so report the problem to the user instead.
    if not source_lang or not target_lang:
      raise gr.Error("Please select both the source and target languages.")
    collection = "arena-translations"
    extra_fields = {
        "source_language": source_lang.lower(),
        "target_language": target_lang.lower(),
    }
  else:
    # TODO(#6): Block voting if the response_type is not set. Until then the
    # vote is dropped without being stored.
    return

  doc_id = uuid4().hex
  db.collection(collection).document(doc_id).set({
      "id": doc_id,
      "prompt": user_prompt,
      "model_a": model_a_name,
      "model_b": model_b_name,
      "model_a_response": response_a,
      "model_b_response": response_b,
      **extra_fields,
      "winner": winner,
      "timestamp": firestore.SERVER_TIMESTAMP,
  })


def response_generator(response: Iterable[Any]) -> Iterator[str]:
  """Yields the text of a streamed completion one character at a time.

  Args:
    response: The chunk stream returned by `completion(..., stream=True)`.
      Each chunk is expected to expose `choices[0].delta.content`; chunks
      whose content is None are skipped.

  Yields:
    Every character of the streamed content, in order.
  """
  for part in response:
    content = part.choices[0].delta.content
    if content is None:
      continue

    # Emit single characters so get_responses() can interleave the two
    # models' outputs finely, simulating concurrent streaming.
    yield from content


def get_responses(user_prompt):
  """Streams answers to `user_prompt` from two randomly sampled models.

  Two distinct models are drawn from `SUPPORTED_MODELS` and queried with
  streaming enabled. Their outputs are advanced in turns, one character per
  model per turn, to simulate the models answering concurrently.

  Args:
    user_prompt: Text sent to both models as a single user message.

  Yields:
    `[response_a, response_b, model_a, model_b]` after every character
    received from either model: the text accumulated so far for each model,
    followed by the two model names.

  Raises:
    Exception: Any error raised while starting or reading either stream is
      printed and then re-raised unchanged.
  """
  models = sample(SUPPORTED_MODELS, 2)

  generators = []
  for model in models:
    try:
      # TODO(#1): Allow user to set configuration.
      response = completion(model=model,
                            messages=[{
                                "content": user_prompt,
                                "role": "user"
                            }],
                            stream=True)
    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
      print(f"Error in bot_response: {e}")
      raise  # Bare raise keeps the original traceback intact.
    generators.append(response_generator(response))

  responses = ["", ""]

  # Round-robin over the streams until a full pass makes no progress, i.e.
  # every stream is exhausted. An exhausted generator keeps raising
  # StopIteration, which is simply skipped.
  while True:
    progressed = False

    for i, generator in enumerate(generators):
      try:
        character = next(generator)
      except StopIteration:
        continue
      # TODO(#1): Narrow down the exception type.
      except Exception as e:  # pylint: disable=broad-except
        print(f"Error in generator: {e}")
        raise

      responses[i] += character
      progressed = True
      # `+` builds a fresh list, so previously yielded values stay unchanged.
      yield responses + models

    if not progressed:
      break


with gr.Blocks() as app:
  with gr.Row():
    response_type_radio = gr.Radio(
        [response_type.value for response_type in ResponseType],
        label="Response type",
        info="Choose the type of response you want from the model.")

    # The language pickers only matter for translations, so they start hidden
    # and are toggled by update_language_visibility below.
    source_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Source language",
        info="Choose the source language for translation.",
        interactive=True,
        visible=False)
    target_language = gr.Dropdown(
        choices=SUPPORTED_TRANSLATION_LANGUAGES,
        label="Target language",
        info="Choose the target language for translation.",
        interactive=True,
        visible=False)

    def update_language_visibility(response_type):
      """Shows the language dropdowns only when translation is selected."""
      visible = response_type == ResponseType.TRANSLATE.value
      return {
          source_language: gr.Dropdown(visible=visible),
          target_language: gr.Dropdown(visible=visible)
      }

    response_type_radio.change(update_language_visibility, response_type_radio,
                               [source_language, target_language])

  prompt = gr.TextArea(label="Prompt", lines=4)
  submit = gr.Button()

  with gr.Row():
    response_boxes = [
        gr.Textbox(label="Model A", interactive=False),
        gr.Textbox(label="Model B", interactive=False),
    ]

  # TODO(#5): Display it only after the user submits the prompt.
  # TODO(#6): Block voting if the response_type is not set.
  # TODO(#6): Block voting if the user already voted.
  with gr.Row():
    # Button labels come from VoteOptions because vote() maps the clicked
    # label back to a VoteOptions member.
    option_a = gr.Button(VoteOptions.MODEL_A.value)
    option_b = gr.Button(VoteOptions.MODEL_B.value)
    tie = gr.Button(VoteOptions.TIE.value)

  # TODO(#7): Hide it until the user votes.
  with gr.Accordion("Show models", open=False):
    with gr.Row():
      model_names = [
          gr.Textbox(label="Model A", interactive=False),
          gr.Textbox(label="Model B", interactive=False),
      ]

  # get_responses yields [response_a, response_b, model_a, model_b], which
  # matches the order of these output components.
  submit.click(get_responses, prompt, response_boxes + model_names)

  # Order must match vote()'s parameters after `vote_button`.
  common_inputs = response_boxes + model_names + [
      prompt, response_type_radio, source_language, target_language
  ]
  # Each button is its own first input, so vote() receives the clicked label.
  option_a.click(vote, [option_a] + common_inputs)
  option_b.click(vote, [option_b] + common_inputs)
  tie.click(vote, [tie] + common_inputs)

  build_leaderboard(db)

if __name__ == "__main__":
  # get_responses is a generator handler (it streams partial output), and
  # such handlers require the request queue to be enabled before launching.
  app.queue().launch(debug=True)