suhyun.kang commited on
Commit
cc5a628
2 Parent(s): a19f11e 486e533

Merge branch 'main' of https://github.com/Y-IAB/arena into 1-elo

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. README.md +4 -2
  3. app.py +52 -72
  4. requirments.txt +22 -41
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  venv
2
  *.log
 
 
1
  venv
2
  *.log
3
+ __pycache__
README.md CHANGED
@@ -19,7 +19,9 @@
19
  Set your OpenAI API key as an environment variable and start the application:
20
 
21
  ```shell
22
- GCP_PROJECT_ID=<your project id> OPENAI_API_KEY=<your key> python3 app.py
23
  ```
24
 
25
- Replace <your project id> and <your key> with your GCP project ID and OpenAI API key respectively.
 
 
 
19
  Set your OpenAI API key as an environment variable and start the application:
20
 
21
  ```shell
22
+ OPENAI_API_KEY=<your key> python3 app.py
23
  ```
24
 
25
+ Replace `<your key>` with your OpenAI API key.
26
+
27
+ > To run the app with [auto-reloading](https://www.gradio.app/guides/developing-faster-with-reload-mode), use `gradio app.py --demo-name app` instead of `python3 app.py`.
app.py CHANGED
@@ -3,23 +3,24 @@ It provides a platform for comparing the responses of two LLMs.
3
  """
4
 
5
  import enum
6
- import json
7
  from random import sample
8
  from uuid import uuid4
9
 
10
- from fastchat.serve import gradio_web_server
11
- from fastchat.serve.gradio_web_server import bot_response
12
  import firebase_admin
13
  from firebase_admin import firestore
14
  import gradio as gr
 
15
 
16
  from leaderboard import build_leaderboard
17
 
 
18
  db_app = firebase_admin.initialize_app()
19
  db = firestore.client()
20
 
21
  # TODO(#1): Add more models.
22
- SUPPORTED_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gemini-pro"]
 
 
23
 
24
  # TODO(#4): Add more languages.
25
  SUPPORTED_TRANSLATION_LANGUAGES = ["Korean", "English"]
@@ -36,23 +37,20 @@ class VoteOptions(enum.Enum):
36
  TIE = "Tie"
37
 
38
 
39
- def vote(state_a, state_b, vote_button, res_type, source_lang, target_lang):
 
40
  doc_id = uuid4().hex
41
  winner = VoteOptions(vote_button).name.lower()
42
 
43
- # The 'messages' field in the state is an array of arrays, which is
44
- # not supported by Firestore. Therefore, we convert it to a JSON string.
45
- model_a_conv = json.dumps(state_a.dict())
46
- model_b_conv = json.dumps(state_b.dict())
47
-
48
  if res_type == ResponseType.SUMMARIZE.value:
49
  doc_ref = db.collection("arena-summarizations").document(doc_id)
50
  doc_ref.set({
51
  "id": doc_id,
52
- "model_a": state_a.model_name,
53
- "model_b": state_b.model_name,
54
- "model_a_conv": model_a_conv,
55
- "model_b_conv": model_b_conv,
 
56
  "winner": winner,
57
  "timestamp": firestore.SERVER_TIMESTAMP
58
  })
@@ -62,10 +60,11 @@ def vote(state_a, state_b, vote_button, res_type, source_lang, target_lang):
62
  doc_ref = db.collection("arena-translations").document(doc_id)
63
  doc_ref.set({
64
  "id": doc_id,
65
- "model_a": state_a.model_name,
66
- "model_b": state_b.model_name,
67
- "model_a_conv": model_a_conv,
68
- "model_b_conv": model_b_conv,
 
69
  "source_language": source_lang.lower(),
70
  "target_language": target_lang.lower(),
71
  "winner": winner,
@@ -73,42 +72,38 @@ def vote(state_a, state_b, vote_button, res_type, source_lang, target_lang):
73
  })
74
 
75
 
76
- def user(user_prompt):
77
- model_pair = sample(SUPPORTED_MODELS, 2)
78
- new_state_a = gradio_web_server.State(model_pair[0])
79
- new_state_b = gradio_web_server.State(model_pair[1])
80
-
81
- for state in [new_state_a, new_state_b]:
82
- state.conv.append_message(state.conv.roles[0], user_prompt)
83
- state.conv.append_message(state.conv.roles[1], None)
84
- state.skip_next = False
85
 
86
- return [
87
- new_state_a, new_state_b, new_state_a.model_name, new_state_b.model_name
88
- ]
89
 
90
 
91
- def bot(state_a, state_b, request: gr.Request):
92
- new_states = [state_a, state_b]
93
 
94
  generators = []
95
- for state in new_states:
96
  try:
97
  # TODO(#1): Allow user to set configuration.
98
- # bot_response returns a generator yielding states.
99
- generator = bot_response(state,
100
- temperature=0.9,
101
- top_p=0.9,
102
- max_new_tokens=100,
103
- request=request)
104
- generators.append(generator)
105
 
106
  # TODO(#1): Narrow down the exception type.
107
  except Exception as e: # pylint: disable=broad-except
108
  print(f"Error in bot_response: {e}")
109
  raise e
110
 
111
- new_responses = [None, None]
112
 
113
  # It simulates concurrent response generation from two models.
114
  while True:
@@ -118,19 +113,14 @@ def bot(state_a, state_b, request: gr.Request):
118
  try:
119
  yielded = next(generators[i])
120
 
121
- # The generator yields a tuple, with the new state as the first item.
122
- new_state = yielded[0]
123
- new_states[i] = new_state
124
-
125
- # The last item from 'messages' represents the response to the prompt.
126
- bot_message = new_state.conv.messages[-1]
127
-
128
- # Each message in conv.messages is structured as [role, message],
129
- # so we extract the last message component.
130
- new_responses[i] = bot_message[-1]
131
 
 
132
  stop = False
133
 
 
 
134
  except StopIteration:
135
  pass
136
 
@@ -139,8 +129,6 @@ def bot(state_a, state_b, request: gr.Request):
139
  print(f"Error in generator: {e}")
140
  raise e
141
 
142
- yield new_states + new_responses
143
-
144
  if stop:
145
  break
146
 
@@ -176,36 +164,22 @@ with gr.Blocks() as app:
176
  [source_language, target_language])
177
 
178
  model_names = [gr.State(None), gr.State(None)]
179
- responses = [gr.State(None), gr.State(None)]
180
-
181
- # states stores FastChat-specific conversation states.
182
- states = [gr.State(None), gr.State(None)]
183
 
184
  prompt = gr.TextArea(label="Prompt", lines=4)
185
  submit = gr.Button()
186
 
187
  with gr.Row():
188
- responses[0] = gr.Textbox(label="Model A", interactive=False)
189
- responses[1] = gr.Textbox(label="Model B", interactive=False)
190
 
191
  # TODO(#5): Display it only after the user submits the prompt.
192
  # TODO(#6): Block voting if the response_type is not set.
193
  # TODO(#6): Block voting if the user already voted.
194
  with gr.Row():
195
  option_a = gr.Button(VoteOptions.MODEL_A.value)
196
- option_a.click(
197
- vote, states +
198
- [option_a, response_type_radio, source_language, target_language])
199
-
200
  option_b = gr.Button("Model B is better")
201
- option_b.click(
202
- vote, states +
203
- [option_b, response_type_radio, source_language, target_language])
204
-
205
  tie = gr.Button("Tie")
206
- tie.click(
207
- vote,
208
- states + [tie, response_type_radio, source_language, target_language])
209
 
210
  # TODO(#7): Hide it until the user votes.
211
  with gr.Accordion("Show models", open=False):
@@ -213,8 +187,14 @@ with gr.Blocks() as app:
213
  model_names[0] = gr.Textbox(label="Model A", interactive=False)
214
  model_names[1] = gr.Textbox(label="Model B", interactive=False)
215
 
216
- submit.click(user, prompt, states + model_names,
217
- queue=False).then(bot, states, states + responses)
 
 
 
 
 
 
218
 
219
  build_leaderboard(db)
220
 
 
3
  """
4
 
5
  import enum
 
6
  from random import sample
7
  from uuid import uuid4
8
 
 
 
9
  import firebase_admin
10
  from firebase_admin import firestore
11
  import gradio as gr
12
+ from litellm import completion
13
 
14
  from leaderboard import build_leaderboard
15
 
16
+ # TODO(#21): Fix auto-reload issue related to the initialization of Firebase.
17
  db_app = firebase_admin.initialize_app()
18
  db = firestore.client()
19
 
20
  # TODO(#1): Add more models.
21
+ SUPPORTED_MODELS = [
22
+ "gpt-4", "gpt-4-0125-preview", "gpt-3.5-turbo", "gemini-pro"
23
+ ]
24
 
25
  # TODO(#4): Add more languages.
26
  SUPPORTED_TRANSLATION_LANGUAGES = ["Korean", "English"]
 
37
  TIE = "Tie"
38
 
39
 
40
+ def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
41
+ user_prompt, res_type, source_lang, target_lang):
42
  doc_id = uuid4().hex
43
  winner = VoteOptions(vote_button).name.lower()
44
 
 
 
 
 
 
45
  if res_type == ResponseType.SUMMARIZE.value:
46
  doc_ref = db.collection("arena-summarizations").document(doc_id)
47
  doc_ref.set({
48
  "id": doc_id,
49
+ "prompt": user_prompt,
50
+ "model_a": model_a_name,
51
+ "model_b": model_b_name,
52
+ "model_a_response": response_a,
53
+ "model_b_response": response_b,
54
  "winner": winner,
55
  "timestamp": firestore.SERVER_TIMESTAMP
56
  })
 
60
  doc_ref = db.collection("arena-translations").document(doc_id)
61
  doc_ref.set({
62
  "id": doc_id,
63
+ "prompt": user_prompt,
64
+ "model_a": model_a_name,
65
+ "model_b": model_b_name,
66
+ "model_a_response": response_a,
67
+ "model_b_response": response_b,
68
  "source_language": source_lang.lower(),
69
  "target_language": target_lang.lower(),
70
  "winner": winner,
 
72
  })
73
 
74
 
75
+ def response_generator(response: str):
76
+ for part in response:
77
+ content = part.choices[0].delta.content
78
+ if content is None:
79
+ continue
 
 
 
 
80
 
81
+ # To simulate a stream, we yield each character of the response.
82
+ for character in content:
83
+ yield character
84
 
85
 
86
+ def get_responses(user_prompt):
87
+ models = sample(SUPPORTED_MODELS, 2)
88
 
89
  generators = []
90
+ for model in models:
91
  try:
92
  # TODO(#1): Allow user to set configuration.
93
+ response = completion(model=model,
94
+ messages=[{
95
+ "content": user_prompt,
96
+ "role": "user"
97
+ }],
98
+ stream=True)
99
+ generators.append(response_generator(response))
100
 
101
  # TODO(#1): Narrow down the exception type.
102
  except Exception as e: # pylint: disable=broad-except
103
  print(f"Error in bot_response: {e}")
104
  raise e
105
 
106
+ responses = ["", ""]
107
 
108
  # It simulates concurrent response generation from two models.
109
  while True:
 
113
  try:
114
  yielded = next(generators[i])
115
 
116
+ if yielded is None:
117
+ continue
 
 
 
 
 
 
 
 
118
 
119
+ responses[i] += yielded
120
  stop = False
121
 
122
+ yield responses + models
123
+
124
  except StopIteration:
125
  pass
126
 
 
129
  print(f"Error in generator: {e}")
130
  raise e
131
 
 
 
132
  if stop:
133
  break
134
 
 
164
  [source_language, target_language])
165
 
166
  model_names = [gr.State(None), gr.State(None)]
167
+ response_boxes = [gr.State(None), gr.State(None)]
 
 
 
168
 
169
  prompt = gr.TextArea(label="Prompt", lines=4)
170
  submit = gr.Button()
171
 
172
  with gr.Row():
173
+ response_boxes[0] = gr.Textbox(label="Model A", interactive=False)
174
+ response_boxes[1] = gr.Textbox(label="Model B", interactive=False)
175
 
176
  # TODO(#5): Display it only after the user submits the prompt.
177
  # TODO(#6): Block voting if the response_type is not set.
178
  # TODO(#6): Block voting if the user already voted.
179
  with gr.Row():
180
  option_a = gr.Button(VoteOptions.MODEL_A.value)
 
 
 
 
181
  option_b = gr.Button("Model B is better")
 
 
 
 
182
  tie = gr.Button("Tie")
 
 
 
183
 
184
  # TODO(#7): Hide it until the user votes.
185
  with gr.Accordion("Show models", open=False):
 
187
  model_names[0] = gr.Textbox(label="Model A", interactive=False)
188
  model_names[1] = gr.Textbox(label="Model B", interactive=False)
189
 
190
+ submit.click(get_responses, prompt, response_boxes + model_names)
191
+
192
+ common_inputs = response_boxes + model_names + [
193
+ prompt, response_type_radio, source_language, target_language
194
+ ]
195
+ option_a.click(vote, [option_a] + common_inputs)
196
+ option_b.click(vote, [option_b] + common_inputs)
197
+ tie.click(vote, [tie] + common_inputs)
198
 
199
  build_leaderboard(db)
200
 
requirments.txt CHANGED
@@ -1,4 +1,3 @@
1
- accelerate==0.26.1
2
  aiofiles==23.2.1
3
  aiohttp==3.9.3
4
  aiosignal==1.3.1
@@ -6,9 +5,9 @@ altair==5.2.0
6
  annotated-types==0.6.0
7
  anyio==4.2.0
8
  attrs==23.2.0
9
- CacheControl==0.13.1
10
  cachetools==5.3.2
11
- certifi==2023.11.17
12
  cffi==1.16.0
13
  charset-normalizer==3.3.2
14
  click==8.1.7
@@ -17,76 +16,67 @@ contourpy==1.2.0
17
  cryptography==42.0.2
18
  cycler==0.12.1
19
  distro==1.9.0
20
- fastapi==0.109.0
21
  ffmpy==0.3.1
22
  filelock==3.13.1
23
  firebase-admin==6.4.0
24
  fonttools==4.47.2
25
  frozenlist==1.4.1
26
- fschat==0.2.35
27
- fsspec==2023.12.2
28
- google-api-core==2.16.1
29
  google-api-python-client==2.116.0
30
  google-auth==2.27.0
31
  google-auth-httplib2==0.2.0
32
- google-cloud-aiplatform==1.40.0
33
- google-cloud-bigquery==3.17.1
34
  google-cloud-core==2.4.1
35
  google-cloud-firestore==2.14.0
36
- google-cloud-resource-manager==1.11.0
37
  google-cloud-storage==2.14.0
38
  google-crc32c==1.5.0
39
  google-resumable-media==2.7.0
40
  googleapis-common-protos==1.62.0
41
- gradio==3.50.2
42
- gradio_client==0.6.1
43
- grpc-google-iam-v1==0.13.0
44
- grpcio==1.60.0
45
- grpcio-status==1.60.0
46
  h11==0.14.0
47
  httpcore==1.0.2
48
  httplib2==0.22.0
49
  httpx==0.26.0
50
  huggingface-hub==0.20.3
51
  idna==3.6
 
52
  importlib-resources==6.1.1
53
  Jinja2==3.1.3
54
  jsonschema==4.21.1
55
  jsonschema-specifications==2023.12.1
56
  kiwisolver==1.4.5
 
57
  markdown-it-py==3.0.0
58
- markdown2==2.4.12
59
- MarkupSafe==2.1.4
60
  matplotlib==3.8.2
61
  mdurl==0.1.2
62
- mpmath==1.3.0
63
  msgpack==1.0.7
64
- multidict==6.0.4
65
- networkx==3.2.1
66
- nh3==0.2.15
67
  numpy==1.26.3
68
- openai==0.28.0
69
- orjson==3.9.12
70
  packaging==23.2
71
  pandas==2.2.0
72
- peft==0.8.1
73
  pillow==10.2.0
74
- prompt-toolkit==3.0.43
75
  proto-plus==1.23.0
76
  protobuf==4.25.2
77
- psutil==5.9.8
78
  pyasn1==0.5.1
79
  pyasn1-modules==0.3.0
80
  pycparser==2.21
81
- pydantic==1.10.14
82
  pydantic_core==2.16.1
83
  pydub==0.25.1
84
  Pygments==2.17.2
85
  PyJWT==2.8.0
86
  pyparsing==3.1.1
87
  python-dateutil==2.8.2
88
- python-multipart==0.0.6
89
- pytz==2023.4
 
90
  PyYAML==6.0.1
91
  referencing==0.33.0
92
  regex==2023.12.25
@@ -94,32 +84,23 @@ requests==2.31.0
94
  rich==13.7.0
95
  rpds-py==0.17.1
96
  rsa==4.9
97
- ruff==0.1.15
98
- safetensors==0.4.2
99
  semantic-version==2.10.0
100
- sentencepiece==0.1.99
101
- shapely==2.0.2
102
  shellingham==1.5.4
103
- shortuuid==1.0.11
104
  six==1.16.0
105
  sniffio==1.3.0
106
- starlette==0.35.1
107
- svgwrite==1.4.3
108
- sympy==1.12
109
  tiktoken==0.5.2
110
  tokenizers==0.15.1
111
  tomlkit==0.12.0
112
  toolz==0.12.1
113
- torch==2.2.0
114
  tqdm==4.66.1
115
- transformers==4.37.2
116
  typer==0.9.0
117
  typing_extensions==4.9.0
118
  tzdata==2023.4
119
  uritemplate==4.1.1
120
  urllib3==2.2.0
121
  uvicorn==0.27.0.post1
122
- wavedrom==2.0.3.post3
123
- wcwidth==0.2.13
124
  websockets==11.0.3
125
  yarl==1.9.4
 
 
 
1
  aiofiles==23.2.1
2
  aiohttp==3.9.3
3
  aiosignal==1.3.1
 
5
  annotated-types==0.6.0
6
  anyio==4.2.0
7
  attrs==23.2.0
8
+ CacheControl==0.14.0
9
  cachetools==5.3.2
10
+ certifi==2024.2.2
11
  cffi==1.16.0
12
  charset-normalizer==3.3.2
13
  click==8.1.7
 
16
  cryptography==42.0.2
17
  cycler==0.12.1
18
  distro==1.9.0
19
+ fastapi==0.109.2
20
  ffmpy==0.3.1
21
  filelock==3.13.1
22
  firebase-admin==6.4.0
23
  fonttools==4.47.2
24
  frozenlist==1.4.1
25
+ fsspec==2024.2.0
26
+ google-api-core==2.16.2
 
27
  google-api-python-client==2.116.0
28
  google-auth==2.27.0
29
  google-auth-httplib2==0.2.0
 
 
30
  google-cloud-core==2.4.1
31
  google-cloud-firestore==2.14.0
 
32
  google-cloud-storage==2.14.0
33
  google-crc32c==1.5.0
34
  google-resumable-media==2.7.0
35
  googleapis-common-protos==1.62.0
36
+ gradio==4.16.0
37
+ gradio_client==0.8.1
38
+ grpcio==1.60.1
39
+ grpcio-status==1.60.1
 
40
  h11==0.14.0
41
  httpcore==1.0.2
42
  httplib2==0.22.0
43
  httpx==0.26.0
44
  huggingface-hub==0.20.3
45
  idna==3.6
46
+ importlib-metadata==7.0.1
47
  importlib-resources==6.1.1
48
  Jinja2==3.1.3
49
  jsonschema==4.21.1
50
  jsonschema-specifications==2023.12.1
51
  kiwisolver==1.4.5
52
+ litellm==1.22.3
53
  markdown-it-py==3.0.0
54
+ MarkupSafe==2.1.5
 
55
  matplotlib==3.8.2
56
  mdurl==0.1.2
 
57
  msgpack==1.0.7
58
+ multidict==6.0.5
 
 
59
  numpy==1.26.3
60
+ openai==1.11.1
61
+ orjson==3.9.13
62
  packaging==23.2
63
  pandas==2.2.0
 
64
  pillow==10.2.0
 
65
  proto-plus==1.23.0
66
  protobuf==4.25.2
 
67
  pyasn1==0.5.1
68
  pyasn1-modules==0.3.0
69
  pycparser==2.21
70
+ pydantic==2.6.0
71
  pydantic_core==2.16.1
72
  pydub==0.25.1
73
  Pygments==2.17.2
74
  PyJWT==2.8.0
75
  pyparsing==3.1.1
76
  python-dateutil==2.8.2
77
+ python-dotenv==1.0.1
78
+ python-multipart==0.0.7
79
+ pytz==2024.1
80
  PyYAML==6.0.1
81
  referencing==0.33.0
82
  regex==2023.12.25
 
84
  rich==13.7.0
85
  rpds-py==0.17.1
86
  rsa==4.9
87
+ ruff==0.2.0
 
88
  semantic-version==2.10.0
 
 
89
  shellingham==1.5.4
 
90
  six==1.16.0
91
  sniffio==1.3.0
92
+ starlette==0.36.3
 
 
93
  tiktoken==0.5.2
94
  tokenizers==0.15.1
95
  tomlkit==0.12.0
96
  toolz==0.12.1
 
97
  tqdm==4.66.1
 
98
  typer==0.9.0
99
  typing_extensions==4.9.0
100
  tzdata==2023.4
101
  uritemplate==4.1.1
102
  urllib3==2.2.0
103
  uvicorn==0.27.0.post1
 
 
104
  websockets==11.0.3
105
  yarl==1.9.4
106
+ zipp==3.17.0