Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
fix some bugs
#24
by
hanhainebula
- opened
- .github/workflows/main.yaml +0 -20
- Makefile +0 -10
- README.md +31 -9
- app.py +363 -531
- pyproject.toml +3 -3
- requirements.txt +13 -14
- src/about.py +3 -19
- src/benchmarks.py +125 -44
- src/{css_html_js.py → display/css_html_js.py} +0 -0
- src/display/formatting.py +29 -0
- src/{components.py → display/gradio_formatting.py} +21 -26
- src/display/gradio_listener.py +53 -0
- src/{columns.py → display/utils.py} +40 -53
- src/envs.py +2 -45
- src/loaders.py +0 -88
- src/{models.py → read_evals.py} +119 -68
- src/utils.py +136 -267
- tests/src/display/test_utils.py +23 -0
- tests/src/test_benchmarks.py +5 -29
- tests/src/test_columns.py +0 -119
- tests/src/test_envs.py +0 -14
- tests/src/test_loaders.py +0 -46
- tests/src/test_models.py +0 -89
- tests/src/test_read_evals.py +68 -0
- tests/src/test_utils.py +0 -237
- tests/test_utils.py +115 -0
- tests/toydata/eval_results/AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json +0 -0
- tests/toydata/eval_results/AIR-Bench_24.05/bge-m3/NoReranker/results.json +0 -0
- tests/toydata/test_data.json +98 -0
- tests/toydata/test_results/bge-m3/NoReranker/results_2023-11-21T18-10-08.json +98 -0
- tests/toydata/test_results/bge-m3/bge-reranker-v2-m3/results_2023-11-21T18-10-08.json +98 -0
.github/workflows/main.yaml
DELETED
@@ -1,20 +0,0 @@
|
|
1 |
-
name: Sync to Hugging Face hub
|
2 |
-
on:
|
3 |
-
push:
|
4 |
-
branches: [main]
|
5 |
-
|
6 |
-
# to run this workflow manually from the Actions tab
|
7 |
-
workflow_dispatch:
|
8 |
-
|
9 |
-
jobs:
|
10 |
-
sync-to-hub:
|
11 |
-
runs-on: ubuntu-latest
|
12 |
-
steps:
|
13 |
-
- uses: actions/checkout@v3
|
14 |
-
with:
|
15 |
-
fetch-depth: 0
|
16 |
-
lfs: true
|
17 |
-
- name: Push to hub
|
18 |
-
env:
|
19 |
-
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
20 |
-
run: git push https://hanhainebula:$HF_TOKEN@huggingface.co/spaces/AIR-Bench/leaderboard main
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Makefile
CHANGED
@@ -3,21 +3,11 @@
|
|
3 |
|
4 |
style:
|
5 |
python -m black --line-length 119 .
|
6 |
-
python -m black --line-length 119 src
|
7 |
python -m isort .
|
8 |
-
python -m isort src
|
9 |
ruff check --fix .
|
10 |
-
ruff check --fix src
|
11 |
|
12 |
|
13 |
quality:
|
14 |
python -m black --check --line-length 119 .
|
15 |
-
python -m black --check --line-length 119 src
|
16 |
python -m isort --check-only .
|
17 |
-
python -m isort --check-only src
|
18 |
ruff check .
|
19 |
-
ruff check src
|
20 |
-
|
21 |
-
|
22 |
-
test:
|
23 |
-
python -m pytest tests
|
|
|
3 |
|
4 |
style:
|
5 |
python -m black --line-length 119 .
|
|
|
6 |
python -m isort .
|
|
|
7 |
ruff check --fix .
|
|
|
8 |
|
9 |
|
10 |
quality:
|
11 |
python -m black --check --line-length 119 .
|
|
|
12 |
python -m isort --check-only .
|
|
|
13 |
ruff check .
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -10,14 +10,36 @@ pinned: true
|
|
10 |
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
-
#
|
14 |
|
15 |
-
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
+
# Start the configuration
|
14 |
|
15 |
+
Most of the variables to change for a default leaderboard are in `src/env.py` (replace the path for your leaderboard) and `src/about.py` (for tasks).
|
16 |
|
17 |
+
Results files should have the following format and be stored as json files:
|
18 |
+
```json
|
19 |
+
{
|
20 |
+
"config": {
|
21 |
+
"model_dtype": "torch.float16", # or torch.bfloat16 or 8bit or 4bit
|
22 |
+
"model_name": "path of the model on the hub: org/model",
|
23 |
+
"model_sha": "revision on the hub",
|
24 |
+
},
|
25 |
+
"results": {
|
26 |
+
"task_name": {
|
27 |
+
"metric_name": score,
|
28 |
+
},
|
29 |
+
"task_name2": {
|
30 |
+
"metric_name": score,
|
31 |
+
}
|
32 |
+
}
|
33 |
+
}
|
34 |
+
```
|
35 |
+
|
36 |
+
Request files are created automatically by this tool.
|
37 |
+
|
38 |
+
If you encounter problem on the space, don't hesitate to restart it to remove the create eval-queue, eval-queue-bk, eval-results and eval-results-bk created folder.
|
39 |
+
|
40 |
+
# Code logic for more complex edits
|
41 |
+
|
42 |
+
You'll find
|
43 |
+
- the main table' columns names and properties in `src/display/utils.py`
|
44 |
+
- the logic to read all results and request files, then convert them in dataframe lines, in `src/leaderboard/read_evals.py`, and `src/populate.py`
|
45 |
+
- teh logic to allow or filter submissions in `src/submission/submit.py` and `src/submission/check_validity.py`
|
app.py
CHANGED
@@ -1,557 +1,391 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
import gradio as gr
|
4 |
from apscheduler.schedulers.background import BackgroundScheduler
|
5 |
from huggingface_hub import snapshot_download
|
6 |
|
7 |
from src.about import (
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
|
13 |
-
from src.components import (
|
14 |
-
get_anonymous_checkbox,
|
15 |
-
get_domain_dropdown,
|
16 |
-
get_language_dropdown,
|
17 |
-
get_leaderboard_table,
|
18 |
-
get_metric_dropdown,
|
19 |
-
get_noreranking_dropdown,
|
20 |
-
get_reranking_dropdown,
|
21 |
-
get_revision_and_ts_checkbox,
|
22 |
-
get_search_bar,
|
23 |
-
get_version_dropdown,
|
24 |
)
|
25 |
-
from src.
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
RESULTS_REPO,
|
36 |
-
TOKEN,
|
37 |
-
)
|
38 |
-
from src.loaders import load_eval_results
|
39 |
-
from src.models import TaskType, model_hyperlink
|
40 |
-
from src.utils import remove_html, reset_rank, set_listeners, submit_results, update_metric, upload_file
|
41 |
-
|
42 |
|
43 |
def restart_space():
|
44 |
API.restart_space(repo_id=REPO_ID)
|
45 |
|
46 |
|
47 |
try:
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
tqdm_class=None,
|
55 |
-
etag_timeout=30,
|
56 |
-
token=TOKEN,
|
57 |
-
)
|
58 |
-
else:
|
59 |
-
print("Running in local mode")
|
60 |
-
except Exception:
|
61 |
-
print("failed to download")
|
62 |
restart_space()
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
show_revision_and_timestamp,
|
90 |
-
)
|
91 |
-
|
92 |
-
|
93 |
-
def update_doc_metric(
|
94 |
-
metric: str,
|
95 |
-
domains: list,
|
96 |
-
langs: list,
|
97 |
-
reranking_model: list,
|
98 |
-
query: str,
|
99 |
-
show_anonymous: bool,
|
100 |
-
show_revision_and_timestamp,
|
101 |
):
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
show_anonymous,
|
112 |
show_revision_and_timestamp,
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
def update_datastore(version):
|
117 |
-
global datastore
|
118 |
-
global ds_dict
|
119 |
-
if datastore.version != version:
|
120 |
-
print(f"updated data version: {datastore.version} -> {version}")
|
121 |
-
datastore = ds_dict[version]
|
122 |
-
else:
|
123 |
-
print(f"current data version: {datastore.version}")
|
124 |
-
return datastore
|
125 |
-
|
126 |
-
|
127 |
-
def update_qa_domains(version):
|
128 |
-
datastore = update_datastore(version)
|
129 |
-
domain_elem = get_domain_dropdown(QABenchmarks[datastore.slug])
|
130 |
-
return domain_elem
|
131 |
-
|
132 |
-
|
133 |
-
def update_doc_domains(version):
|
134 |
-
datastore = update_datastore(version)
|
135 |
-
domain_elem = get_domain_dropdown(LongDocBenchmarks[datastore.slug])
|
136 |
-
return domain_elem
|
137 |
-
|
138 |
-
|
139 |
-
def update_qa_langs(version):
|
140 |
-
datastore = update_datastore(version)
|
141 |
-
lang_elem = get_language_dropdown(QABenchmarks[datastore.slug])
|
142 |
-
return lang_elem
|
143 |
-
|
144 |
-
|
145 |
-
def update_doc_langs(version):
|
146 |
-
datastore = update_datastore(version)
|
147 |
-
lang_elem = get_language_dropdown(LongDocBenchmarks[datastore.slug])
|
148 |
-
return lang_elem
|
149 |
-
|
150 |
-
|
151 |
-
def update_qa_models(version):
|
152 |
-
datastore = update_datastore(version)
|
153 |
-
model_elem = get_reranking_dropdown(datastore.reranking_models)
|
154 |
-
return model_elem
|
155 |
-
|
156 |
-
|
157 |
-
def update_qa_df_ret_rerank(version):
|
158 |
-
datastore = update_datastore(version)
|
159 |
-
return get_leaderboard_table(datastore.qa_fmt_df, datastore.qa_types)
|
160 |
-
|
161 |
-
|
162 |
-
def update_qa_hidden_df_ret_rerank(version):
|
163 |
-
datastore = update_datastore(version)
|
164 |
-
return get_leaderboard_table(datastore.qa_raw_df, datastore.qa_types, visible=False)
|
165 |
-
|
166 |
-
|
167 |
-
def update_doc_df_ret_rerank(version):
|
168 |
-
datastore = update_datastore(version)
|
169 |
-
return get_leaderboard_table(datastore.doc_fmt_df, datastore.doc_types)
|
170 |
-
|
171 |
-
|
172 |
-
def update_doc_hidden_df_ret_rerank(version):
|
173 |
-
datastore = update_datastore(version)
|
174 |
-
return get_leaderboard_table(datastore.doc_raw_df, datastore.doc_types, visible=False)
|
175 |
-
|
176 |
-
|
177 |
-
def filter_df_ret(df):
|
178 |
-
df_ret = df[df[COL_NAME_RERANKING_MODEL] == "NoReranker"]
|
179 |
-
df_ret = reset_rank(df_ret)
|
180 |
-
return df_ret
|
181 |
-
|
182 |
-
|
183 |
-
def update_qa_df_ret(version):
|
184 |
-
datastore = update_datastore(version)
|
185 |
-
df_ret = filter_df_ret(datastore.qa_fmt_df)
|
186 |
-
return get_leaderboard_table(df_ret, datastore.qa_types)
|
187 |
-
|
188 |
-
|
189 |
-
def update_qa_hidden_df_ret(version):
|
190 |
-
datastore = update_datastore(version)
|
191 |
-
df_ret_hidden = filter_df_ret(datastore.qa_raw_df)
|
192 |
-
return get_leaderboard_table(df_ret_hidden, datastore.qa_types, visible=False)
|
193 |
-
|
194 |
-
|
195 |
-
def update_doc_df_ret(version):
|
196 |
-
datastore = update_datastore(version)
|
197 |
-
df_ret = filter_df_ret(datastore.doc_fmt_df)
|
198 |
-
return get_leaderboard_table(df_ret, datastore.doc_types)
|
199 |
-
|
200 |
-
|
201 |
-
def update_doc_hidden_df_ret(version):
|
202 |
-
datastore = update_datastore(version)
|
203 |
-
df_ret_hidden = filter_df_ret(datastore.doc_raw_df)
|
204 |
-
return get_leaderboard_table(df_ret_hidden, datastore.doc_types, visible=False)
|
205 |
-
|
206 |
-
|
207 |
-
def filter_df_rerank(df):
|
208 |
-
df_rerank = df[df[COL_NAME_RETRIEVAL_MODEL] == BM25_LINK]
|
209 |
-
df_rerank = reset_rank(df_rerank)
|
210 |
-
return df_rerank
|
211 |
-
|
212 |
-
|
213 |
-
def update_qa_df_rerank(version):
|
214 |
-
datastore = update_datastore(version)
|
215 |
-
df_rerank = filter_df_rerank(datastore.qa_fmt_df)
|
216 |
-
return get_leaderboard_table(df_rerank, datastore.qa_types)
|
217 |
-
|
218 |
-
|
219 |
-
def update_qa_hidden_df_rerank(version):
|
220 |
-
datastore = update_datastore(version)
|
221 |
-
df_rerank_hidden = filter_df_rerank(datastore.qa_raw_df)
|
222 |
-
return get_leaderboard_table(df_rerank_hidden, datastore.qa_types, visible=False)
|
223 |
-
|
224 |
-
|
225 |
-
def update_doc_df_rerank(version):
|
226 |
-
datastore = update_datastore(version)
|
227 |
-
df_rerank = filter_df_rerank(datastore.doc_fmt_df)
|
228 |
-
return get_leaderboard_table(df_rerank, datastore.doc_types)
|
229 |
-
|
230 |
-
|
231 |
-
def update_doc_hidden_df_rerank(version):
|
232 |
-
datastore = update_datastore(version)
|
233 |
-
df_rerank_hidden = filter_df_rerank(datastore.doc_raw_df)
|
234 |
-
return get_leaderboard_table(df_rerank_hidden, datastore.doc_types, visible=False)
|
235 |
|
236 |
|
237 |
demo = gr.Blocks(css=custom_css)
|
238 |
|
239 |
-
BM25_LINK = model_hyperlink("https://github.com/castorini/pyserini", "BM25")
|
240 |
-
|
241 |
with demo:
|
242 |
gr.HTML(TITLE)
|
243 |
gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
244 |
|
245 |
with gr.Tabs(elem_classes="tab-buttons") as tabs:
|
246 |
-
with gr.TabItem("
|
247 |
with gr.Row():
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
with gr.
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
|
|
|
|
|
|
286 |
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
|
|
|
|
|
|
|
|
291 |
search_bar,
|
292 |
-
version,
|
293 |
-
domains,
|
294 |
-
langs,
|
295 |
-
models,
|
296 |
show_anonymous,
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
version.change(update_qa_df_ret, version, qa_df_elem_ret)
|
318 |
-
|
319 |
-
# Dummy leaderboard for handling the case when the user uses backspace key
|
320 |
-
_qa_df_ret_hidden = filter_df_ret(datastore.qa_raw_df)
|
321 |
-
qa_df_elem_ret_hidden = get_leaderboard_table(
|
322 |
-
_qa_df_ret_hidden, datastore.qa_types, visible=False
|
323 |
-
)
|
324 |
-
version.change(update_qa_hidden_df_ret, version, qa_df_elem_ret_hidden)
|
325 |
-
|
326 |
-
set_listeners(
|
327 |
-
TaskType.qa,
|
328 |
-
qa_df_elem_ret,
|
329 |
-
qa_df_elem_ret_hidden,
|
330 |
-
search_bar_ret,
|
331 |
-
version,
|
332 |
-
domains,
|
333 |
-
langs,
|
334 |
-
models_ret,
|
335 |
show_anonymous,
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
|
|
|
|
|
|
353 |
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
qa_df_elem_rerank,
|
375 |
-
qa_df_elem_rerank_hidden,
|
376 |
-
qa_search_bar_rerank,
|
377 |
-
version,
|
378 |
-
domains,
|
379 |
-
langs,
|
380 |
-
qa_models_rerank,
|
381 |
show_anonymous,
|
382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
domains,
|
390 |
-
langs,
|
391 |
-
qa_models_rerank,
|
392 |
-
qa_search_bar_rerank,
|
393 |
-
show_anonymous,
|
394 |
-
show_rev_ts,
|
395 |
-
],
|
396 |
-
qa_df_elem_rerank,
|
397 |
-
queue=True,
|
398 |
-
)
|
399 |
-
with gr.TabItem("Long Doc", elem_id="long-doc-benchmark-tab-table", id=1):
|
400 |
-
with gr.Row():
|
401 |
-
with gr.Column(min_width=320):
|
402 |
-
# select domain
|
403 |
-
with gr.Row():
|
404 |
-
domains = get_domain_dropdown(LongDocBenchmarks[datastore.slug])
|
405 |
-
version.change(update_doc_domains, version, domains)
|
406 |
-
# select language
|
407 |
-
with gr.Row():
|
408 |
-
langs = get_language_dropdown(LongDocBenchmarks[datastore.slug])
|
409 |
-
version.change(update_doc_langs, version, langs)
|
410 |
-
with gr.Column():
|
411 |
-
# select the metric
|
412 |
-
with gr.Row():
|
413 |
-
metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_LONG_DOC)
|
414 |
-
with gr.Row():
|
415 |
-
show_anonymous = get_anonymous_checkbox()
|
416 |
-
with gr.Row():
|
417 |
-
show_rev_ts = get_revision_and_ts_checkbox()
|
418 |
-
with gr.Tabs(elem_classes="tab-buttons"):
|
419 |
-
with gr.TabItem("Retrieval + Reranking", id=20):
|
420 |
-
with gr.Row():
|
421 |
-
with gr.Column():
|
422 |
-
search_bar = get_search_bar()
|
423 |
-
with gr.Column():
|
424 |
-
models = get_reranking_dropdown(datastore.reranking_models)
|
425 |
-
version.change(update_qa_models, version, models)
|
426 |
-
|
427 |
-
doc_df_elem_ret_rerank = get_leaderboard_table(datastore.doc_fmt_df, datastore.doc_types)
|
428 |
-
|
429 |
-
version.change(update_doc_df_ret_rerank, version, doc_df_elem_ret_rerank)
|
430 |
-
|
431 |
-
doc_df_elem_ret_rerank_hidden = get_leaderboard_table(
|
432 |
-
datastore.doc_raw_df, datastore.doc_types, visible=False
|
433 |
-
)
|
434 |
|
435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
|
|
|
|
|
|
|
|
441 |
search_bar,
|
442 |
-
version,
|
443 |
-
domains,
|
444 |
-
langs,
|
445 |
-
models,
|
446 |
show_anonymous,
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
_doc_df_ret = filter_df_ret(datastore.doc_fmt_df)
|
472 |
-
doc_df_elem_ret = get_leaderboard_table(_doc_df_ret, datastore.doc_types)
|
473 |
-
version.change(update_doc_df_ret, version, doc_df_elem_ret)
|
474 |
-
|
475 |
-
_doc_df_ret_hidden = filter_df_ret(datastore.doc_raw_df)
|
476 |
-
doc_df_elem_ret_hidden = get_leaderboard_table(
|
477 |
-
_doc_df_ret_hidden, datastore.doc_types, visible=False
|
478 |
-
)
|
479 |
-
version.change(update_doc_hidden_df_ret, version, doc_df_elem_ret_hidden)
|
480 |
-
|
481 |
-
set_listeners(
|
482 |
-
TaskType.long_doc,
|
483 |
-
doc_df_elem_ret,
|
484 |
-
doc_df_elem_ret_hidden,
|
485 |
-
search_bar_ret,
|
486 |
-
version,
|
487 |
-
domains,
|
488 |
-
langs,
|
489 |
-
models_ret,
|
490 |
-
show_anonymous,
|
491 |
-
show_rev_ts,
|
492 |
-
)
|
493 |
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
doc_df_elem_ret,
|
506 |
-
queue=True,
|
507 |
-
)
|
508 |
-
with gr.TabItem("Reranking Only", id=22):
|
509 |
-
_doc_df_rerank = filter_df_rerank(datastore.doc_fmt_df)
|
510 |
-
doc_rerank_models = (
|
511 |
-
_doc_df_rerank[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
|
512 |
-
)
|
513 |
-
with gr.Row():
|
514 |
-
with gr.Column(scale=1):
|
515 |
-
doc_models_rerank = get_reranking_dropdown(doc_rerank_models)
|
516 |
-
with gr.Column(scale=1):
|
517 |
-
doc_search_bar_rerank = gr.Textbox(show_label=False, visible=False)
|
518 |
-
doc_df_elem_rerank = get_leaderboard_table(_doc_df_rerank, datastore.doc_types)
|
519 |
-
version.change(update_doc_df_rerank, version, doc_df_elem_rerank)
|
520 |
-
|
521 |
-
_doc_df_rerank_hidden = filter_df_rerank(datastore.doc_raw_df)
|
522 |
-
doc_df_elem_rerank_hidden = get_leaderboard_table(
|
523 |
-
_doc_df_rerank_hidden, datastore.doc_types, visible=False
|
524 |
-
)
|
525 |
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
domains,
|
535 |
-
langs,
|
536 |
-
doc_models_rerank,
|
537 |
show_anonymous,
|
538 |
-
|
539 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
540 |
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
555 |
|
556 |
with gr.TabItem("🚀Submit here!", elem_id="submit-tab-table", id=2):
|
557 |
with gr.Column():
|
@@ -567,18 +401,23 @@ with demo:
|
|
567 |
with gr.Row():
|
568 |
with gr.Column():
|
569 |
reranking_model_name = gr.Textbox(
|
570 |
-
label="Reranking Model name",
|
|
|
|
|
571 |
)
|
572 |
with gr.Column():
|
573 |
-
reranking_model_url = gr.Textbox(
|
|
|
|
|
|
|
|
|
574 |
with gr.Row():
|
575 |
with gr.Column():
|
576 |
benchmark_version = gr.Dropdown(
|
577 |
-
|
578 |
-
value=
|
579 |
interactive=True,
|
580 |
-
label="AIR-Bench Version
|
581 |
-
)
|
582 |
with gr.Row():
|
583 |
upload_button = gr.UploadButton("Click to upload search results", file_count="single")
|
584 |
with gr.Row():
|
@@ -587,8 +426,7 @@ with demo:
|
|
587 |
is_anonymous = gr.Checkbox(
|
588 |
label="Nope. I want to submit anonymously 🥷",
|
589 |
value=False,
|
590 |
-
info="Do you want to shown on the leaderboard by default?"
|
591 |
-
)
|
592 |
with gr.Row():
|
593 |
submit_button = gr.Button("Submit")
|
594 |
with gr.Row():
|
@@ -598,8 +436,7 @@ with demo:
|
|
598 |
[
|
599 |
upload_button,
|
600 |
],
|
601 |
-
file_output
|
602 |
-
)
|
603 |
submit_button.click(
|
604 |
submit_results,
|
605 |
[
|
@@ -609,21 +446,16 @@ with demo:
|
|
609 |
reranking_model_name,
|
610 |
reranking_model_url,
|
611 |
benchmark_version,
|
612 |
-
is_anonymous
|
613 |
],
|
614 |
submission_result,
|
615 |
-
show_progress="hidden"
|
616 |
)
|
617 |
|
618 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=3):
|
619 |
gr.Markdown(BENCHMARKS_TEXT, elem_classes="markdown-text")
|
620 |
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
scheduler = BackgroundScheduler()
|
626 |
-
scheduler.add_job(restart_space, "interval", seconds=1800)
|
627 |
-
scheduler.start()
|
628 |
-
demo.queue(default_concurrency_limit=40)
|
629 |
-
demo.launch()
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
from apscheduler.schedulers.background import BackgroundScheduler
|
3 |
from huggingface_hub import snapshot_download
|
4 |
|
5 |
from src.about import (
|
6 |
+
INTRODUCTION_TEXT,
|
7 |
+
BENCHMARKS_TEXT,
|
8 |
+
TITLE,
|
9 |
+
EVALUATION_QUEUE_TEXT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
)
|
11 |
+
from src.benchmarks import DOMAIN_COLS_QA, LANG_COLS_QA, DOMAIN_COLS_LONG_DOC, LANG_COLS_LONG_DOC, METRIC_LIST, \
|
12 |
+
DEFAULT_METRIC_QA, DEFAULT_METRIC_LONG_DOC
|
13 |
+
from src.display.css_html_js import custom_css
|
14 |
+
from src.display.utils import COL_NAME_IS_ANONYMOUS, COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
|
15 |
+
from src.envs import API, EVAL_RESULTS_PATH, REPO_ID, RESULTS_REPO, TOKEN
|
16 |
+
from src.read_evals import get_raw_eval_results, get_leaderboard_df
|
17 |
+
from src.utils import update_metric, upload_file, get_default_cols, submit_results, reset_rank, remove_html
|
18 |
+
from src.display.gradio_formatting import get_version_dropdown, get_search_bar, get_reranking_dropdown, \
|
19 |
+
get_metric_dropdown, get_domain_dropdown, get_language_dropdown, get_anonymous_checkbox, get_revision_and_ts_checkbox, get_leaderboard_table, get_noreranking_dropdown
|
20 |
+
from src.display.gradio_listener import set_listeners
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def restart_space():
|
23 |
API.restart_space(repo_id=REPO_ID)
|
24 |
|
25 |
|
26 |
try:
|
27 |
+
snapshot_download(
|
28 |
+
repo_id=RESULTS_REPO, local_dir=EVAL_RESULTS_PATH, repo_type="dataset", tqdm_class=None, etag_timeout=30,
|
29 |
+
token=TOKEN
|
30 |
+
)
|
31 |
+
except Exception as e:
|
32 |
+
print(f'failed to download')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
restart_space()
|
34 |
|
35 |
+
raw_data = get_raw_eval_results(f"{EVAL_RESULTS_PATH}/AIR-Bench_24.04")
|
36 |
+
|
37 |
+
original_df_qa = get_leaderboard_df(
|
38 |
+
raw_data, task='qa', metric=DEFAULT_METRIC_QA)
|
39 |
+
original_df_long_doc = get_leaderboard_df(
|
40 |
+
raw_data, task='long-doc', metric=DEFAULT_METRIC_LONG_DOC)
|
41 |
+
print(f'raw data: {len(raw_data)}')
|
42 |
+
print(f'QA data loaded: {original_df_qa.shape}')
|
43 |
+
print(f'Long-Doc data loaded: {len(original_df_long_doc)}')
|
44 |
+
|
45 |
+
leaderboard_df_qa = original_df_qa.copy()
|
46 |
+
# leaderboard_df_qa = leaderboard_df_qa[has_no_nan_values(df, _benchmark_cols)]
|
47 |
+
shown_columns_qa, types_qa = get_default_cols(
|
48 |
+
'qa', leaderboard_df_qa.columns, add_fix_cols=True)
|
49 |
+
leaderboard_df_qa = leaderboard_df_qa[~leaderboard_df_qa[COL_NAME_IS_ANONYMOUS]][shown_columns_qa]
|
50 |
+
leaderboard_df_qa.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
51 |
+
|
52 |
+
leaderboard_df_long_doc = original_df_long_doc.copy()
|
53 |
+
shown_columns_long_doc, types_long_doc = get_default_cols(
|
54 |
+
'long-doc', leaderboard_df_long_doc.columns, add_fix_cols=True)
|
55 |
+
leaderboard_df_long_doc = leaderboard_df_long_doc[~leaderboard_df_long_doc[COL_NAME_IS_ANONYMOUS]][shown_columns_long_doc]
|
56 |
+
leaderboard_df_long_doc.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
57 |
+
|
58 |
+
# select reranking model
|
59 |
+
reranking_models = sorted(list(frozenset([eval_result.reranking_model for eval_result in raw_data])))
|
60 |
+
|
61 |
+
|
62 |
+
def update_metric_qa(
|
63 |
+
metric: str,
|
64 |
+
domains: list,
|
65 |
+
langs: list,
|
66 |
+
reranking_model: list,
|
67 |
+
query: str,
|
68 |
+
show_anonymous: bool,
|
69 |
show_revision_and_timestamp,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
):
|
71 |
+
return update_metric(raw_data, 'qa', metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp)
|
72 |
+
|
73 |
+
def update_metric_long_doc(
|
74 |
+
metric: str,
|
75 |
+
domains: list,
|
76 |
+
langs: list,
|
77 |
+
reranking_model: list,
|
78 |
+
query: str,
|
79 |
+
show_anonymous: bool,
|
|
|
80 |
show_revision_and_timestamp,
|
81 |
+
):
|
82 |
+
return update_metric(raw_data, "long-doc", metric, domains, langs, reranking_model, query, show_anonymous, show_revision_and_timestamp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
|
85 |
demo = gr.Blocks(css=custom_css)
|
86 |
|
|
|
|
|
87 |
with demo:
|
88 |
gr.HTML(TITLE)
|
89 |
gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
90 |
|
91 |
with gr.Tabs(elem_classes="tab-buttons") as tabs:
|
92 |
+
with gr.TabItem("QA", elem_id="qa-benchmark-tab-table", id=0):
|
93 |
with gr.Row():
|
94 |
+
with gr.Column(min_width=320):
|
95 |
+
# select domain
|
96 |
+
with gr.Row():
|
97 |
+
selected_domains = get_domain_dropdown(DOMAIN_COLS_QA, DOMAIN_COLS_QA)
|
98 |
+
# select language
|
99 |
+
with gr.Row():
|
100 |
+
selected_langs = get_language_dropdown(LANG_COLS_QA, LANG_COLS_QA)
|
101 |
+
|
102 |
+
with gr.Column():
|
103 |
+
with gr.Row():
|
104 |
+
selected_version = get_version_dropdown()
|
105 |
+
# select the metric
|
106 |
+
selected_metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_QA)
|
107 |
+
with gr.Row():
|
108 |
+
show_anonymous = get_anonymous_checkbox()
|
109 |
+
with gr.Row():
|
110 |
+
show_revision_and_timestamp = get_revision_and_ts_checkbox()
|
111 |
+
with gr.Tabs(elem_classes="tab-buttons") as sub_tabs:
|
112 |
+
with gr.TabItem("Retrieval + Reranking", id=10):
|
113 |
+
with gr.Row():
|
114 |
+
# search retrieval models
|
115 |
+
with gr.Column():
|
116 |
+
search_bar = get_search_bar()
|
117 |
+
# select reranking models
|
118 |
+
with gr.Column():
|
119 |
+
selected_rerankings = get_reranking_dropdown(reranking_models)
|
120 |
+
leaderboard_table = get_leaderboard_table(leaderboard_df_qa, types_qa)
|
121 |
+
# Dummy leaderboard for handling the case when the user uses backspace key
|
122 |
+
hidden_leaderboard_table_for_search = get_leaderboard_table(original_df_qa, types_qa, visible=False)
|
123 |
+
|
124 |
+
set_listeners(
|
125 |
+
"qa",
|
126 |
+
leaderboard_table,
|
127 |
+
hidden_leaderboard_table_for_search,
|
128 |
+
search_bar,
|
129 |
+
selected_domains,
|
130 |
+
selected_langs,
|
131 |
+
selected_rerankings,
|
132 |
+
show_anonymous,
|
133 |
+
show_revision_and_timestamp,
|
134 |
+
)
|
135 |
|
136 |
+
# set metric listener
|
137 |
+
selected_metric.change(
|
138 |
+
update_metric_qa,
|
139 |
+
[
|
140 |
+
selected_metric,
|
141 |
+
selected_domains,
|
142 |
+
selected_langs,
|
143 |
+
selected_rerankings,
|
144 |
search_bar,
|
|
|
|
|
|
|
|
|
145 |
show_anonymous,
|
146 |
+
show_revision_and_timestamp,
|
147 |
+
],
|
148 |
+
leaderboard_table,
|
149 |
+
queue=True
|
150 |
+
)
|
151 |
+
with gr.TabItem("Retrieval Only", id=11):
|
152 |
+
with gr.Row():
|
153 |
+
with gr.Column(scale=1):
|
154 |
+
search_bar_retriever = get_search_bar()
|
155 |
+
with gr.Column(scale=1):
|
156 |
+
selected_noreranker = get_noreranking_dropdown()
|
157 |
+
lb_df_retriever = leaderboard_df_qa[leaderboard_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"]
|
158 |
+
lb_df_retriever = reset_rank(lb_df_retriever)
|
159 |
+
lb_table_retriever = get_leaderboard_table(lb_df_retriever, types_qa)
|
160 |
+
# Dummy leaderboard for handling the case when the user uses backspace key
|
161 |
+
hidden_lb_df_retriever = original_df_qa[original_df_qa[COL_NAME_RERANKING_MODEL] == "NoReranker"]
|
162 |
+
hidden_lb_df_retriever = reset_rank(hidden_lb_df_retriever)
|
163 |
+
hidden_lb_table_retriever = get_leaderboard_table(hidden_lb_df_retriever, types_qa, visible=False)
|
164 |
+
|
165 |
+
set_listeners(
|
166 |
+
"qa",
|
167 |
+
lb_table_retriever,
|
168 |
+
hidden_lb_table_retriever,
|
169 |
+
search_bar_retriever,
|
170 |
+
selected_domains,
|
171 |
+
selected_langs,
|
172 |
+
selected_noreranker,
|
173 |
+
show_anonymous,
|
174 |
+
show_revision_and_timestamp,
|
175 |
+
)
|
176 |
|
177 |
+
# set metric listener
|
178 |
+
selected_metric.change(
|
179 |
+
update_metric_qa,
|
180 |
+
[
|
181 |
+
selected_metric,
|
182 |
+
selected_domains,
|
183 |
+
selected_langs,
|
184 |
+
selected_noreranker,
|
185 |
+
search_bar_retriever,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
show_anonymous,
|
187 |
+
show_revision_and_timestamp,
|
188 |
+
],
|
189 |
+
lb_table_retriever,
|
190 |
+
queue=True
|
191 |
+
)
|
192 |
+
with gr.TabItem("Reranking Only", id=12):
|
193 |
+
lb_df_reranker = leaderboard_df_qa[leaderboard_df_qa[COL_NAME_RETRIEVAL_MODEL] == "BM25"]
|
194 |
+
lb_df_reranker = reset_rank(lb_df_reranker)
|
195 |
+
reranking_models_reranker = lb_df_reranker[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
|
196 |
+
with gr.Row():
|
197 |
+
with gr.Column(scale=1):
|
198 |
+
selected_rerankings_reranker = get_reranking_dropdown(reranking_models_reranker)
|
199 |
+
with gr.Column(scale=1):
|
200 |
+
search_bar_reranker = gr.Textbox(show_label=False, visible=False)
|
201 |
+
lb_table_reranker = get_leaderboard_table(lb_df_reranker, types_qa)
|
202 |
+
hidden_lb_df_reranker = original_df_qa[original_df_qa[COL_NAME_RETRIEVAL_MODEL] == "BM25"]
|
203 |
+
hidden_lb_df_reranker = reset_rank(hidden_lb_df_reranker)
|
204 |
+
hidden_lb_table_reranker = get_leaderboard_table(
|
205 |
+
hidden_lb_df_reranker, types_qa, visible=False
|
206 |
+
)
|
207 |
|
208 |
+
set_listeners(
|
209 |
+
"qa",
|
210 |
+
lb_table_reranker,
|
211 |
+
hidden_lb_table_reranker,
|
212 |
+
search_bar_reranker,
|
213 |
+
selected_domains,
|
214 |
+
selected_langs,
|
215 |
+
selected_rerankings_reranker,
|
216 |
+
show_anonymous,
|
217 |
+
show_revision_and_timestamp,
|
218 |
+
)
|
219 |
+
# set metric listener
|
220 |
+
selected_metric.change(
|
221 |
+
update_metric_qa,
|
222 |
+
[
|
223 |
+
selected_metric,
|
224 |
+
selected_domains,
|
225 |
+
selected_langs,
|
226 |
+
selected_rerankings_reranker,
|
227 |
+
search_bar_reranker,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
show_anonymous,
|
229 |
+
show_revision_and_timestamp,
|
230 |
+
],
|
231 |
+
lb_table_reranker,
|
232 |
+
queue=True
|
233 |
+
)
|
234 |
+
with gr.TabItem("Long Doc", elem_id="long-doc-benchmark-tab-table", id=1):
|
235 |
+
with gr.Row():
|
236 |
+
with gr.Column(min_width=320):
|
237 |
+
# select domain
|
238 |
+
with gr.Row():
|
239 |
+
selected_domains = get_domain_dropdown(DOMAIN_COLS_LONG_DOC, DOMAIN_COLS_LONG_DOC)
|
240 |
+
# select language
|
241 |
+
with gr.Row():
|
242 |
+
selected_langs = get_language_dropdown(
|
243 |
+
LANG_COLS_LONG_DOC, LANG_COLS_LONG_DOC
|
244 |
)
|
245 |
+
with gr.Column():
|
246 |
+
with gr.Row():
|
247 |
+
selected_version = get_version_dropdown()
|
248 |
+
# select the metric
|
249 |
+
with gr.Row():
|
250 |
+
selected_metric = get_metric_dropdown(METRIC_LIST, DEFAULT_METRIC_LONG_DOC)
|
251 |
+
with gr.Row():
|
252 |
+
show_anonymous = get_anonymous_checkbox()
|
253 |
+
with gr.Row():
|
254 |
+
show_revision_and_timestamp = get_revision_and_ts_checkbox()
|
255 |
+
with gr.Tabs(elem_classes="tab-buttons") as sub_tabs:
|
256 |
+
with gr.TabItem("Retrieval + Reranking", id=20):
|
257 |
+
with gr.Row():
|
258 |
+
with gr.Column():
|
259 |
+
search_bar = get_search_bar()
|
260 |
+
# select reranking model
|
261 |
+
with gr.Column():
|
262 |
+
selected_rerankings = get_reranking_dropdown(reranking_models)
|
263 |
+
|
264 |
+
lb_table = get_leaderboard_table(
|
265 |
+
leaderboard_df_long_doc, types_long_doc
|
266 |
+
)
|
267 |
|
268 |
+
# Dummy leaderboard for handling the case when the user uses backspace key
|
269 |
+
hidden_lb_table_for_search = get_leaderboard_table(
|
270 |
+
original_df_long_doc, types_long_doc, visible=False
|
271 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
|
273 |
+
set_listeners(
|
274 |
+
"long-doc",
|
275 |
+
lb_table,
|
276 |
+
hidden_lb_table_for_search,
|
277 |
+
search_bar,
|
278 |
+
selected_domains,
|
279 |
+
selected_langs,
|
280 |
+
selected_rerankings,
|
281 |
+
show_anonymous,
|
282 |
+
show_revision_and_timestamp,
|
283 |
+
)
|
284 |
|
285 |
+
# set metric listener
|
286 |
+
selected_metric.change(
|
287 |
+
update_metric_long_doc,
|
288 |
+
[
|
289 |
+
selected_metric,
|
290 |
+
selected_domains,
|
291 |
+
selected_langs,
|
292 |
+
selected_rerankings,
|
293 |
search_bar,
|
|
|
|
|
|
|
|
|
294 |
show_anonymous,
|
295 |
+
show_revision_and_timestamp
|
296 |
+
],
|
297 |
+
lb_table,
|
298 |
+
queue=True
|
299 |
+
)
|
300 |
+
with gr.TabItem("Retrieval Only", id=21):
|
301 |
+
with gr.Row():
|
302 |
+
with gr.Column(scale=1):
|
303 |
+
search_bar_retriever = get_search_bar()
|
304 |
+
with gr.Column(scale=1):
|
305 |
+
selected_noreranker = get_noreranking_dropdown()
|
306 |
+
lb_df_retriever_long_doc = leaderboard_df_long_doc[
|
307 |
+
leaderboard_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
308 |
+
]
|
309 |
+
lb_df_retriever_long_doc = reset_rank(lb_df_retriever_long_doc)
|
310 |
+
hidden_lb_db_retriever_long_doc = original_df_long_doc[
|
311 |
+
original_df_long_doc[COL_NAME_RERANKING_MODEL] == "NoReranker"
|
312 |
+
]
|
313 |
+
hidden_lb_db_retriever_long_doc = reset_rank(hidden_lb_db_retriever_long_doc)
|
314 |
+
lb_table_retriever_long_doc = get_leaderboard_table(
|
315 |
+
lb_df_retriever_long_doc, types_long_doc)
|
316 |
+
hidden_lb_table_retriever_long_doc = get_leaderboard_table(
|
317 |
+
hidden_lb_db_retriever_long_doc, types_long_doc, visible=False
|
318 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
|
320 |
+
set_listeners(
|
321 |
+
"long-doc",
|
322 |
+
lb_table_retriever_long_doc,
|
323 |
+
hidden_lb_table_retriever_long_doc,
|
324 |
+
search_bar_retriever,
|
325 |
+
selected_domains,
|
326 |
+
selected_langs,
|
327 |
+
selected_noreranker,
|
328 |
+
show_anonymous,
|
329 |
+
show_revision_and_timestamp,
|
330 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
|
332 |
+
selected_metric.change(
|
333 |
+
update_metric_long_doc,
|
334 |
+
[
|
335 |
+
selected_metric,
|
336 |
+
selected_domains,
|
337 |
+
selected_langs,
|
338 |
+
selected_noreranker,
|
339 |
+
search_bar_retriever,
|
|
|
|
|
|
|
340 |
show_anonymous,
|
341 |
+
show_revision_and_timestamp,
|
342 |
+
],
|
343 |
+
lb_table_retriever_long_doc,
|
344 |
+
queue=True
|
345 |
+
)
|
346 |
+
with gr.TabItem("Reranking Only", id=22):
|
347 |
+
lb_df_reranker_ldoc = leaderboard_df_long_doc[
|
348 |
+
leaderboard_df_long_doc[COL_NAME_RETRIEVAL_MODEL] == "BM25"
|
349 |
+
]
|
350 |
+
lb_df_reranker_ldoc = reset_rank(lb_df_reranker_ldoc)
|
351 |
+
reranking_models_reranker_ldoc = lb_df_reranker_ldoc[COL_NAME_RERANKING_MODEL].apply(remove_html).unique().tolist()
|
352 |
+
with gr.Row():
|
353 |
+
with gr.Column(scale=1):
|
354 |
+
selected_rerankings_reranker_ldoc = get_reranking_dropdown(reranking_models_reranker_ldoc)
|
355 |
+
with gr.Column(scale=1):
|
356 |
+
search_bar_reranker_ldoc = gr.Textbox(show_label=False, visible=False)
|
357 |
+
lb_table_reranker_ldoc = get_leaderboard_table(lb_df_reranker_ldoc, types_long_doc)
|
358 |
+
hidden_lb_df_reranker_ldoc = original_df_long_doc[original_df_long_doc[COL_NAME_RETRIEVAL_MODEL] == "BM25"]
|
359 |
+
hidden_lb_df_reranker_ldoc = reset_rank(hidden_lb_df_reranker_ldoc)
|
360 |
+
hidden_lb_table_reranker_ldoc = get_leaderboard_table(
|
361 |
+
hidden_lb_df_reranker_ldoc, types_long_doc, visible=False
|
362 |
+
)
|
363 |
|
364 |
+
set_listeners(
|
365 |
+
"long-doc",
|
366 |
+
lb_table_reranker_ldoc,
|
367 |
+
hidden_lb_table_reranker_ldoc,
|
368 |
+
search_bar_reranker_ldoc,
|
369 |
+
selected_domains,
|
370 |
+
selected_langs,
|
371 |
+
selected_rerankings_reranker_ldoc,
|
372 |
+
show_anonymous,
|
373 |
+
show_revision_and_timestamp,
|
374 |
+
)
|
375 |
+
selected_metric.change(
|
376 |
+
update_metric_long_doc,
|
377 |
+
[
|
378 |
+
selected_metric,
|
379 |
+
selected_domains,
|
380 |
+
selected_langs,
|
381 |
+
selected_rerankings_reranker_ldoc,
|
382 |
+
search_bar_reranker_ldoc,
|
383 |
+
show_anonymous,
|
384 |
+
show_revision_and_timestamp,
|
385 |
+
],
|
386 |
+
lb_table_reranker_ldoc,
|
387 |
+
queue=True
|
388 |
+
)
|
389 |
|
390 |
with gr.TabItem("🚀Submit here!", elem_id="submit-tab-table", id=2):
|
391 |
with gr.Column():
|
|
|
401 |
with gr.Row():
|
402 |
with gr.Column():
|
403 |
reranking_model_name = gr.Textbox(
|
404 |
+
label="Reranking Model name",
|
405 |
+
info="Optional",
|
406 |
+
value="NoReranker"
|
407 |
)
|
408 |
with gr.Column():
|
409 |
+
reranking_model_url = gr.Textbox(
|
410 |
+
label="Reranking Model URL",
|
411 |
+
info="Optional",
|
412 |
+
value=""
|
413 |
+
)
|
414 |
with gr.Row():
|
415 |
with gr.Column():
|
416 |
benchmark_version = gr.Dropdown(
|
417 |
+
["AIR-Bench_24.04", ],
|
418 |
+
value="AIR-Bench_24.04",
|
419 |
interactive=True,
|
420 |
+
label="AIR-Bench Version")
|
|
|
421 |
with gr.Row():
|
422 |
upload_button = gr.UploadButton("Click to upload search results", file_count="single")
|
423 |
with gr.Row():
|
|
|
426 |
is_anonymous = gr.Checkbox(
|
427 |
label="Nope. I want to submit anonymously 🥷",
|
428 |
value=False,
|
429 |
+
info="Do you want to shown on the leaderboard by default?")
|
|
|
430 |
with gr.Row():
|
431 |
submit_button = gr.Button("Submit")
|
432 |
with gr.Row():
|
|
|
436 |
[
|
437 |
upload_button,
|
438 |
],
|
439 |
+
file_output)
|
|
|
440 |
submit_button.click(
|
441 |
submit_results,
|
442 |
[
|
|
|
446 |
reranking_model_name,
|
447 |
reranking_model_url,
|
448 |
benchmark_version,
|
449 |
+
is_anonymous
|
450 |
],
|
451 |
submission_result,
|
452 |
+
show_progress="hidden"
|
453 |
)
|
454 |
|
455 |
with gr.TabItem("📝 About", elem_id="llm-benchmark-tab-table", id=3):
|
456 |
gr.Markdown(BENCHMARKS_TEXT, elem_classes="markdown-text")
|
457 |
|
458 |
+
scheduler = BackgroundScheduler()
|
459 |
+
scheduler.add_job(restart_space, "interval", seconds=1800)
|
460 |
+
scheduler.start()
|
461 |
+
demo.queue(default_concurrency_limit=40).launch()
|
|
|
|
|
|
|
|
|
|
pyproject.toml
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
[tool.ruff]
|
2 |
# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
|
3 |
-
|
4 |
-
|
5 |
line-length = 119
|
6 |
-
|
7 |
|
8 |
[tool.isort]
|
9 |
profile = "black"
|
|
|
1 |
[tool.ruff]
|
2 |
# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
|
3 |
+
select = ["E", "F"]
|
4 |
+
ignore = ["E501"] # line too long (black is taking care of this)
|
5 |
line-length = 119
|
6 |
+
fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
|
7 |
|
8 |
[tool.isort]
|
9 |
profile = "black"
|
requirements.txt
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
-
APScheduler
|
2 |
-
black
|
3 |
-
click
|
4 |
-
datasets
|
5 |
-
gradio
|
6 |
-
gradio_client
|
7 |
huggingface-hub>=0.18.0
|
8 |
-
numpy
|
9 |
-
pandas
|
10 |
-
python-dateutil
|
11 |
-
requests
|
12 |
-
tqdm
|
13 |
-
accelerate
|
14 |
-
socksio
|
15 |
-
air-benchmark>=0.1.0
|
|
|
1 |
+
APScheduler==3.10.1
|
2 |
+
black==23.11.0
|
3 |
+
click==8.1.3
|
4 |
+
datasets==2.14.5
|
5 |
+
gradio==4.29.0
|
6 |
+
gradio_client==0.16.1
|
7 |
huggingface-hub>=0.18.0
|
8 |
+
numpy==1.24.2
|
9 |
+
pandas==2.0.0
|
10 |
+
python-dateutil==2.8.2
|
11 |
+
requests==2.31.0
|
12 |
+
tqdm==4.65.0
|
13 |
+
accelerate==0.24.1
|
14 |
+
socksio==1.0.0
|
|
src/about.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# Your leaderboard name
|
2 |
TITLE = """<h1 align="center" id="space-title">AIR-Bench: Automated Heterogeneous Information Retrieval Benchmark
|
3 |
-
(v0.
|
4 |
|
5 |
# What does your leaderboard evaluate?
|
6 |
INTRODUCTION_TEXT = """
|
@@ -8,7 +8,7 @@ INTRODUCTION_TEXT = """
|
|
8 |
"""
|
9 |
|
10 |
# Which evaluations are you running? how can people reproduce what you have?
|
11 |
-
BENCHMARKS_TEXT = """
|
12 |
## How the test data are generated?
|
13 |
### Find more information at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/data_generation.md)
|
14 |
|
@@ -17,30 +17,14 @@ BENCHMARKS_TEXT = """
|
|
17 |
- A: Yes, we plan to release new datasets on regular basis. However, the update frequency is to be decided.
|
18 |
|
19 |
- Q: As you are using models to do the quality control when generating the data, is it biased to the models that are used?
|
20 |
-
- A: Yes, the results is biased to the chosen models. However, we believe the datasets labeled by human are also biased to the human's preference. The key point to verify is whether the model's bias is consistent with the human's. We use our approach to generate test data using the well established MSMARCO datasets. We benchmark different models' performances using the generated dataset and the human-label DEV dataset. Comparing the ranking of different models on these two datasets, we observe the spearman correlation between them is 0.8211 (p-value=5e-5). This indicates that the models' perference is well aligned with the human. Please refer to [here](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/
|
21 |
|
22 |
"""
|
23 |
|
24 |
EVALUATION_QUEUE_TEXT = """
|
25 |
## Check out the submission steps at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/submit_to_leaderboard.md)
|
26 |
-
|
27 |
-
## You can find the **STATUS of Your Submission** at the [Backend Space](https://huggingface.co/spaces/AIR-Bench/leaderboard_backend)
|
28 |
-
|
29 |
-
- If the status is **✔️ Success**, then you can find your results at the [Leaderboard Space](https://huggingface.co/spaces/AIR-Bench/leaderboard) in no more than one hour.
|
30 |
-
- If the status is **❌ Failed**, please check your submission steps and try again. If you have any questions, please feel free to open an issue [here](https://github.com/AIR-Bench/AIR-Bench/issues/new).
|
31 |
"""
|
32 |
|
33 |
CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"
|
34 |
CITATION_BUTTON_TEXT = r"""
|
35 |
-
```bibtex
|
36 |
-
@misc{chen2024airbench,
|
37 |
-
title={AIR-Bench: Automated Heterogeneous Information Retrieval Benchmark},
|
38 |
-
author={Jianlyu Chen and Nan Wang and Chaofan Li and Bo Wang and Shitao Xiao and Han Xiao and Hao Liao and Defu Lian and Zheng Liu},
|
39 |
-
year={2024},
|
40 |
-
eprint={2412.13102},
|
41 |
-
archivePrefix={arXiv},
|
42 |
-
primaryClass={cs.IR},
|
43 |
-
url={https://arxiv.org/abs/2412.13102},
|
44 |
-
}
|
45 |
-
```
|
46 |
"""
|
|
|
1 |
# Your leaderboard name
|
2 |
TITLE = """<h1 align="center" id="space-title">AIR-Bench: Automated Heterogeneous Information Retrieval Benchmark
|
3 |
+
(v0.0.3) </h1>"""
|
4 |
|
5 |
# What does your leaderboard evaluate?
|
6 |
INTRODUCTION_TEXT = """
|
|
|
8 |
"""
|
9 |
|
10 |
# Which evaluations are you running? how can people reproduce what you have?
|
11 |
+
BENCHMARKS_TEXT = f"""
|
12 |
## How the test data are generated?
|
13 |
### Find more information at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/data_generation.md)
|
14 |
|
|
|
17 |
- A: Yes, we plan to release new datasets on regular basis. However, the update frequency is to be decided.
|
18 |
|
19 |
- Q: As you are using models to do the quality control when generating the data, is it biased to the models that are used?
|
20 |
+
- A: Yes, the results is biased to the chosen models. However, we believe the datasets labeled by human are also biased to the human's preference. The key point to verify is whether the model's bias is consistent with the human's. We use our approach to generate test data using the well established MSMARCO datasets. We benchmark different models' performances using the generated dataset and the human-label DEV dataset. Comparing the ranking of different models on these two datasets, we observe the spearman correlation between them is 0.8211 (p-value=5e-5). This indicates that the models' perference is well aligned with the human. Please refer to [here](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/available_evaluation_results.md#consistency-with-ms-marco) for details
|
21 |
|
22 |
"""
|
23 |
|
24 |
EVALUATION_QUEUE_TEXT = """
|
25 |
## Check out the submission steps at [our GitHub repo](https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/submit_to_leaderboard.md)
|
|
|
|
|
|
|
|
|
|
|
26 |
"""
|
27 |
|
28 |
CITATION_BUTTON_LABEL = "Copy the following snippet to cite these results"
|
29 |
CITATION_BUTTON_TEXT = r"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
"""
|
src/benchmarks.py
CHANGED
@@ -1,71 +1,152 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
from enum import Enum
|
3 |
|
4 |
-
from air_benchmark.tasks.tasks import BenchmarkTable
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
@dataclass
|
11 |
class Benchmark:
|
12 |
name: str # [domain]_[language]_[metric], task_key in the json file,
|
13 |
-
metric: str # metric_key in the json file
|
14 |
col_name: str # [domain]_[language], name to display in the leaderboard
|
15 |
domain: str
|
16 |
lang: str
|
17 |
task: str
|
18 |
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
for
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
benchmark_name = get_safe_name(f"{domain}_{lang}")
|
29 |
col_name = benchmark_name
|
30 |
for metric in dataset_list:
|
31 |
-
|
32 |
-
|
33 |
-
benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
|
34 |
-
return benchmark_dict
|
35 |
-
|
36 |
-
|
37 |
-
def get_doc_benchmarks_dict(version: str):
|
38 |
-
benchmark_dict = {}
|
39 |
-
for task, domain_dict in BenchmarkTable[version].items():
|
40 |
-
if task != TaskType.long_doc.value:
|
41 |
-
continue
|
42 |
-
for domain, lang_dict in domain_dict.items():
|
43 |
-
for lang, dataset_list in lang_dict.items():
|
44 |
for dataset in dataset_list:
|
45 |
benchmark_name = f"{domain}_{lang}_{dataset}"
|
46 |
benchmark_name = get_safe_name(benchmark_name)
|
47 |
col_name = benchmark_name
|
48 |
-
if "test" not in dataset_list[dataset]["splits"]:
|
49 |
-
continue
|
50 |
for metric in METRIC_LIST:
|
51 |
-
|
52 |
-
|
53 |
-
)
|
54 |
-
return benchmark_dict
|
55 |
|
|
|
|
|
56 |
|
57 |
-
|
58 |
-
for
|
59 |
-
safe_version_name = get_safe_name(version)
|
60 |
-
_qa_benchmark_dict[safe_version_name] = Enum(f"QABenchmarks_{safe_version_name}", get_qa_benchmarks_dict(version))
|
61 |
|
62 |
-
|
63 |
-
for
|
64 |
-
safe_version_name = get_safe_name(version)
|
65 |
-
_doc_benchmark_dict[safe_version_name] = Enum(
|
66 |
-
f"LongDocBenchmarks_{safe_version_name}", get_doc_benchmarks_dict(version)
|
67 |
-
)
|
68 |
|
|
|
|
|
69 |
|
70 |
-
|
71 |
-
|
|
|
1 |
from dataclasses import dataclass
|
2 |
from enum import Enum
|
3 |
|
|
|
4 |
|
5 |
+
def get_safe_name(name: str):
|
6 |
+
"""Get RFC 1123 compatible safe name"""
|
7 |
+
name = name.replace('-', '_')
|
8 |
+
return ''.join(
|
9 |
+
character.lower()
|
10 |
+
for character in name
|
11 |
+
if (character.isalnum() or character == '_'))
|
12 |
+
|
13 |
+
|
14 |
+
dataset_dict = {
|
15 |
+
"qa": {
|
16 |
+
"wiki": {
|
17 |
+
"en": ["wikipedia_20240101", ],
|
18 |
+
"zh": ["wikipedia_20240101", ]
|
19 |
+
},
|
20 |
+
"web": {
|
21 |
+
"en": ["mC4", ],
|
22 |
+
"zh": ["mC4", ]
|
23 |
+
},
|
24 |
+
"news": {
|
25 |
+
"en": ["CC-News", ],
|
26 |
+
"zh": ["CC-News", ]
|
27 |
+
},
|
28 |
+
"healthcare": {
|
29 |
+
"en": ["PubMedQA", ],
|
30 |
+
"zh": ["Huatuo-26M", ]
|
31 |
+
},
|
32 |
+
"law": {
|
33 |
+
"en": ["pile-of-law", ],
|
34 |
+
# "zh": ["flk_npc_gov_cn", ]
|
35 |
+
},
|
36 |
+
"finance": {
|
37 |
+
"en": ["Reuters-Financial", ],
|
38 |
+
"zh": ["FinCorpus", ]
|
39 |
+
},
|
40 |
+
"arxiv": {
|
41 |
+
"en": ["Arxiv", ]},
|
42 |
+
"msmarco": {
|
43 |
+
"en": ["MS MARCO", ]},
|
44 |
+
},
|
45 |
+
"long-doc": {
|
46 |
+
"arxiv": {
|
47 |
+
"en": ["gpt3", "llama2", "llm-survey", "gemini"],
|
48 |
+
},
|
49 |
+
"book": {
|
50 |
+
"en": [
|
51 |
+
"origin-of-species_darwin",
|
52 |
+
"a-brief-history-of-time_stephen-hawking"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
"healthcare": {
|
56 |
+
"en": [
|
57 |
+
"pubmed_100k-200k_1",
|
58 |
+
"pubmed_100k-200k_2",
|
59 |
+
"pubmed_100k-200k_3",
|
60 |
+
"pubmed_40k-50k_5-merged",
|
61 |
+
"pubmed_30k-40k_10-merged"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
"law": {
|
65 |
+
"en": [
|
66 |
+
"lex_files_300k-400k",
|
67 |
+
"lex_files_400k-500k",
|
68 |
+
"lex_files_500k-600k",
|
69 |
+
"lex_files_600k-700k"
|
70 |
+
]
|
71 |
+
}
|
72 |
+
}
|
73 |
+
}
|
74 |
+
|
75 |
+
METRIC_LIST = [
|
76 |
+
"ndcg_at_1",
|
77 |
+
"ndcg_at_3",
|
78 |
+
"ndcg_at_5",
|
79 |
+
"ndcg_at_10",
|
80 |
+
"ndcg_at_100",
|
81 |
+
"ndcg_at_1000",
|
82 |
+
"map_at_1",
|
83 |
+
"map_at_3",
|
84 |
+
"map_at_5",
|
85 |
+
"map_at_10",
|
86 |
+
"map_at_100",
|
87 |
+
"map_at_1000",
|
88 |
+
"recall_at_1",
|
89 |
+
"recall_at_3",
|
90 |
+
"recall_at_5",
|
91 |
+
"recall_at_10",
|
92 |
+
"recall_at_100",
|
93 |
+
"recall_at_1000",
|
94 |
+
"precision_at_1",
|
95 |
+
"precision_at_3",
|
96 |
+
"precision_at_5",
|
97 |
+
"precision_at_10",
|
98 |
+
"precision_at_100",
|
99 |
+
"precision_at_1000",
|
100 |
+
"mrr_at_1",
|
101 |
+
"mrr_at_3",
|
102 |
+
"mrr_at_5",
|
103 |
+
"mrr_at_10",
|
104 |
+
"mrr_at_100",
|
105 |
+
"mrr_at_1000"
|
106 |
+
]
|
107 |
|
108 |
|
109 |
@dataclass
|
110 |
class Benchmark:
|
111 |
name: str # [domain]_[language]_[metric], task_key in the json file,
|
112 |
+
metric: str # ndcg_at_1 ,metric_key in the json file
|
113 |
col_name: str # [domain]_[language], name to display in the leaderboard
|
114 |
domain: str
|
115 |
lang: str
|
116 |
task: str
|
117 |
|
118 |
|
119 |
+
qa_benchmark_dict = {}
|
120 |
+
long_doc_benchmark_dict = {}
|
121 |
+
for task, domain_dict in dataset_dict.items():
|
122 |
+
for domain, lang_dict in domain_dict.items():
|
123 |
+
for lang, dataset_list in lang_dict.items():
|
124 |
+
if task == "qa":
|
125 |
+
benchmark_name = f"{domain}_{lang}"
|
126 |
+
benchmark_name = get_safe_name(benchmark_name)
|
|
|
127 |
col_name = benchmark_name
|
128 |
for metric in dataset_list:
|
129 |
+
qa_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain, lang, task)
|
130 |
+
elif task == "long-doc":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
for dataset in dataset_list:
|
132 |
benchmark_name = f"{domain}_{lang}_{dataset}"
|
133 |
benchmark_name = get_safe_name(benchmark_name)
|
134 |
col_name = benchmark_name
|
|
|
|
|
135 |
for metric in METRIC_LIST:
|
136 |
+
long_doc_benchmark_dict[benchmark_name] = Benchmark(benchmark_name, metric, col_name, domain,
|
137 |
+
lang, task)
|
|
|
|
|
138 |
|
139 |
+
BenchmarksQA = Enum('BenchmarksQA', qa_benchmark_dict)
|
140 |
+
BenchmarksLongDoc = Enum('BenchmarksLongDoc', long_doc_benchmark_dict)
|
141 |
|
142 |
+
BENCHMARK_COLS_QA = [c.col_name for c in qa_benchmark_dict.values()]
|
143 |
+
BENCHMARK_COLS_LONG_DOC = [c.col_name for c in long_doc_benchmark_dict.values()]
|
|
|
|
|
144 |
|
145 |
+
DOMAIN_COLS_QA = list(frozenset([c.domain for c in qa_benchmark_dict.values()]))
|
146 |
+
LANG_COLS_QA = list(frozenset([c.lang for c in qa_benchmark_dict.values()]))
|
|
|
|
|
|
|
|
|
147 |
|
148 |
+
DOMAIN_COLS_LONG_DOC = list(frozenset([c.domain for c in long_doc_benchmark_dict.values()]))
|
149 |
+
LANG_COLS_LONG_DOC = list(frozenset([c.lang for c in long_doc_benchmark_dict.values()]))
|
150 |
|
151 |
+
DEFAULT_METRIC_QA = "ndcg_at_10"
|
152 |
+
DEFAULT_METRIC_LONG_DOC = "recall_at_10"
|
src/{css_html_js.py → display/css_html_js.py}
RENAMED
File without changes
|
src/display/formatting.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def model_hyperlink(link, model_name):
|
2 |
+
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
3 |
+
|
4 |
+
|
5 |
+
def make_clickable_model(model_name: str, model_link: str):
|
6 |
+
# link = f"https://huggingface.co/{model_name}"
|
7 |
+
if not model_link or not model_link.startswith("https://") or model_name == "BM25":
|
8 |
+
return model_name
|
9 |
+
return model_hyperlink(model_link, model_name)
|
10 |
+
|
11 |
+
|
12 |
+
def styled_error(error):
|
13 |
+
return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"
|
14 |
+
|
15 |
+
|
16 |
+
def styled_warning(warn):
|
17 |
+
return f"<p style='color: orange; font-size: 20px; text-align: center;'>{warn}</p>"
|
18 |
+
|
19 |
+
|
20 |
+
def styled_message(message):
|
21 |
+
return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
|
22 |
+
|
23 |
+
|
24 |
+
def has_no_nan_values(df, columns):
|
25 |
+
return df[columns].notna().all(axis=1)
|
26 |
+
|
27 |
+
|
28 |
+
def has_nan_values(df, columns):
|
29 |
+
return df[columns].isna().any(axis=1)
|
src/{components.py → display/gradio_formatting.py}
RENAMED
@@ -1,14 +1,12 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
from src.envs import BENCHMARK_VERSION_LIST, LATEST_BENCHMARK_VERSION
|
4 |
-
|
5 |
|
6 |
def get_version_dropdown():
|
7 |
return gr.Dropdown(
|
8 |
-
choices=
|
9 |
-
value=
|
10 |
label="Select the version of AIR-Bench",
|
11 |
-
interactive=True
|
12 |
)
|
13 |
|
14 |
|
@@ -16,25 +14,26 @@ def get_search_bar():
|
|
16 |
return gr.Textbox(
|
17 |
placeholder=" 🔍 Search for retrieval methods (separate multiple queries with `;`) and press ENTER...",
|
18 |
show_label=False,
|
19 |
-
info="Search the retrieval methods"
|
20 |
)
|
21 |
|
22 |
|
23 |
def get_reranking_dropdown(model_list):
|
24 |
-
return gr.Dropdown(
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
def get_noreranking_dropdown():
|
28 |
return gr.Dropdown(
|
29 |
-
choices=[
|
30 |
-
|
31 |
-
],
|
32 |
-
value=[
|
33 |
-
"NoReranker",
|
34 |
-
],
|
35 |
interactive=False,
|
36 |
multiselect=True,
|
37 |
-
visible=False
|
38 |
)
|
39 |
|
40 |
|
@@ -53,10 +52,7 @@ def get_metric_dropdown(metric_list, default_metrics):
|
|
53 |
)
|
54 |
|
55 |
|
56 |
-
def get_domain_dropdown(
|
57 |
-
domain_list = list(frozenset([c.value.domain for c in list(benchmarks.value)]))
|
58 |
-
if default_domains is None:
|
59 |
-
default_domains = domain_list
|
60 |
return gr.CheckboxGroup(
|
61 |
choices=domain_list,
|
62 |
value=default_domains,
|
@@ -65,16 +61,13 @@ def get_domain_dropdown(benchmarks, default_domains=None):
|
|
65 |
)
|
66 |
|
67 |
|
68 |
-
def get_language_dropdown(
|
69 |
-
language_list = list(frozenset([c.value.lang for c in list(benchmarks.value)]))
|
70 |
-
if default_languages is None:
|
71 |
-
default_languages = language_list
|
72 |
return gr.Dropdown(
|
73 |
choices=language_list,
|
74 |
-
value=
|
75 |
label="Select the languages",
|
76 |
multiselect=True,
|
77 |
-
interactive=True
|
78 |
)
|
79 |
|
80 |
|
@@ -82,13 +75,15 @@ def get_anonymous_checkbox():
|
|
82 |
return gr.Checkbox(
|
83 |
label="Show anonymous submissions",
|
84 |
value=False,
|
85 |
-
info="The anonymous submissions might have invalid model information."
|
86 |
)
|
87 |
|
88 |
|
89 |
def get_revision_and_ts_checkbox():
|
90 |
return gr.Checkbox(
|
91 |
-
label="Show submission details",
|
|
|
|
|
92 |
)
|
93 |
|
94 |
|
|
|
1 |
import gradio as gr
|
2 |
|
|
|
|
|
3 |
|
4 |
def get_version_dropdown():
|
5 |
return gr.Dropdown(
|
6 |
+
choices=["AIR-Bench_24.04", ],
|
7 |
+
value="AIR-Bench_24.04",
|
8 |
label="Select the version of AIR-Bench",
|
9 |
+
interactive=True
|
10 |
)
|
11 |
|
12 |
|
|
|
14 |
return gr.Textbox(
|
15 |
placeholder=" 🔍 Search for retrieval methods (separate multiple queries with `;`) and press ENTER...",
|
16 |
show_label=False,
|
17 |
+
info="Search the retrieval methods"
|
18 |
)
|
19 |
|
20 |
|
21 |
def get_reranking_dropdown(model_list):
|
22 |
+
return gr.Dropdown(
|
23 |
+
choices=model_list,
|
24 |
+
label="Select the reranking models",
|
25 |
+
interactive=True,
|
26 |
+
multiselect=True
|
27 |
+
)
|
28 |
|
29 |
|
30 |
def get_noreranking_dropdown():
|
31 |
return gr.Dropdown(
|
32 |
+
choices=["NoReranker", ],
|
33 |
+
value=["NoReranker", ],
|
|
|
|
|
|
|
|
|
34 |
interactive=False,
|
35 |
multiselect=True,
|
36 |
+
visible=False
|
37 |
)
|
38 |
|
39 |
|
|
|
52 |
)
|
53 |
|
54 |
|
55 |
+
def get_domain_dropdown(domain_list, default_domains):
|
|
|
|
|
|
|
56 |
return gr.CheckboxGroup(
|
57 |
choices=domain_list,
|
58 |
value=default_domains,
|
|
|
61 |
)
|
62 |
|
63 |
|
64 |
+
def get_language_dropdown(language_list, default_languages):
|
|
|
|
|
|
|
65 |
return gr.Dropdown(
|
66 |
choices=language_list,
|
67 |
+
value=language_list,
|
68 |
label="Select the languages",
|
69 |
multiselect=True,
|
70 |
+
interactive=True
|
71 |
)
|
72 |
|
73 |
|
|
|
75 |
return gr.Checkbox(
|
76 |
label="Show anonymous submissions",
|
77 |
value=False,
|
78 |
+
info="The anonymous submissions might have invalid model information."
|
79 |
)
|
80 |
|
81 |
|
82 |
def get_revision_and_ts_checkbox():
|
83 |
return gr.Checkbox(
|
84 |
+
label="Show submission details",
|
85 |
+
value=False,
|
86 |
+
info="Show the revision and timestamp information of submissions"
|
87 |
)
|
88 |
|
89 |
|
src/display/gradio_listener.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.utils import update_table, update_table_long_doc
|
2 |
+
|
3 |
+
|
4 |
+
def set_listeners(
|
5 |
+
task,
|
6 |
+
displayed_leaderboard,
|
7 |
+
hidden_leaderboard,
|
8 |
+
search_bar,
|
9 |
+
selected_domains,
|
10 |
+
selected_langs,
|
11 |
+
selected_rerankings,
|
12 |
+
show_anonymous,
|
13 |
+
show_revision_and_timestamp,
|
14 |
+
|
15 |
+
):
|
16 |
+
if task == "qa":
|
17 |
+
update_table_func = update_table
|
18 |
+
elif task == "long-doc":
|
19 |
+
update_table_func = update_table_long_doc
|
20 |
+
else:
|
21 |
+
raise NotImplementedError
|
22 |
+
# Set search_bar listener
|
23 |
+
search_bar.submit(
|
24 |
+
update_table_func,
|
25 |
+
[
|
26 |
+
hidden_leaderboard, # hidden_leaderboard_table_for_search,
|
27 |
+
selected_domains,
|
28 |
+
selected_langs,
|
29 |
+
selected_rerankings,
|
30 |
+
search_bar,
|
31 |
+
show_anonymous,
|
32 |
+
],
|
33 |
+
displayed_leaderboard
|
34 |
+
)
|
35 |
+
|
36 |
+
# Set column-wise listener
|
37 |
+
for selector in [
|
38 |
+
selected_domains, selected_langs, show_anonymous, show_revision_and_timestamp, selected_rerankings
|
39 |
+
]:
|
40 |
+
selector.change(
|
41 |
+
update_table_func,
|
42 |
+
[
|
43 |
+
hidden_leaderboard,
|
44 |
+
selected_domains,
|
45 |
+
selected_langs,
|
46 |
+
selected_rerankings,
|
47 |
+
search_bar,
|
48 |
+
show_anonymous,
|
49 |
+
show_revision_and_timestamp
|
50 |
+
],
|
51 |
+
displayed_leaderboard,
|
52 |
+
queue=True,
|
53 |
+
)
|
src/{columns.py → display/utils.py}
RENAMED
@@ -1,7 +1,9 @@
|
|
1 |
from dataclasses import dataclass, make_dataclass
|
2 |
|
|
|
3 |
|
4 |
-
|
|
|
5 |
return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]
|
6 |
|
7 |
|
@@ -17,22 +19,28 @@ class ColumnContent:
|
|
17 |
never_hidden: bool = False
|
18 |
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def get_default_auto_eval_column_dict():
|
21 |
auto_eval_column_dict = []
|
22 |
-
|
23 |
auto_eval_column_dict.append(
|
24 |
-
[
|
25 |
-
"retrieval_model",
|
26 |
-
ColumnContent,
|
27 |
-
ColumnContent(COL_NAME_RETRIEVAL_MODEL, "markdown", True, never_hidden=True),
|
28 |
-
]
|
29 |
)
|
30 |
auto_eval_column_dict.append(
|
31 |
-
[
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
]
|
36 |
)
|
37 |
auto_eval_column_dict.append(
|
38 |
["revision", ColumnContent, ColumnContent(COL_NAME_REVISION, "markdown", True, never_hidden=True)]
|
@@ -40,30 +48,14 @@ def get_default_auto_eval_column_dict():
|
|
40 |
auto_eval_column_dict.append(
|
41 |
["timestamp", ColumnContent, ColumnContent(COL_NAME_TIMESTAMP, "date", True, never_hidden=True)]
|
42 |
)
|
43 |
-
auto_eval_column_dict.append(["average", ColumnContent, ColumnContent(COL_NAME_AVG, "number", True)])
|
44 |
auto_eval_column_dict.append(
|
45 |
-
[
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
COL_NAME_RETRIEVAL_MODEL_LINK,
|
50 |
-
"markdown",
|
51 |
-
False,
|
52 |
-
hidden=True,
|
53 |
-
),
|
54 |
-
]
|
55 |
)
|
56 |
auto_eval_column_dict.append(
|
57 |
-
[
|
58 |
-
"reranking_model_link",
|
59 |
-
ColumnContent,
|
60 |
-
ColumnContent(
|
61 |
-
COL_NAME_RERANKING_MODEL_LINK,
|
62 |
-
"markdown",
|
63 |
-
False,
|
64 |
-
hidden=True,
|
65 |
-
),
|
66 |
-
]
|
67 |
)
|
68 |
auto_eval_column_dict.append(
|
69 |
["is_anonymous", ColumnContent, ColumnContent(COL_NAME_IS_ANONYMOUS, "bool", False, hidden=True)]
|
@@ -71,10 +63,10 @@ def get_default_auto_eval_column_dict():
|
|
71 |
return auto_eval_column_dict
|
72 |
|
73 |
|
74 |
-
def make_autoevalcolumn(cls_name, benchmarks):
|
75 |
auto_eval_column_dict = get_default_auto_eval_column_dict()
|
76 |
-
|
77 |
-
for benchmark in
|
78 |
auto_eval_column_dict.append(
|
79 |
[benchmark.name, ColumnContent, ColumnContent(benchmark.value.col_name, "number", True)]
|
80 |
)
|
@@ -83,24 +75,19 @@ def make_autoevalcolumn(cls_name, benchmarks):
|
|
83 |
return make_dataclass(cls_name, auto_eval_column_dict, frozen=True)
|
84 |
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
return col_names, col_types
|
91 |
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
96 |
|
|
|
97 |
|
98 |
-
|
99 |
-
COL_NAME_RETRIEVAL_MODEL = "Retrieval Method"
|
100 |
-
COL_NAME_RERANKING_MODEL = "Reranking Model"
|
101 |
-
COL_NAME_RETRIEVAL_MODEL_LINK = "Retrieval Model LINK"
|
102 |
-
COL_NAME_RERANKING_MODEL_LINK = "Reranking Model LINK"
|
103 |
-
COL_NAME_RANK = "Rank 🏆"
|
104 |
-
COL_NAME_REVISION = "Revision"
|
105 |
-
COL_NAME_TIMESTAMP = "Submission Date"
|
106 |
-
COL_NAME_IS_ANONYMOUS = "Anonymous Submission"
|
|
|
1 |
from dataclasses import dataclass, make_dataclass
|
2 |
|
3 |
+
from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
|
4 |
|
5 |
+
|
6 |
+
def fields(raw_class):
|
7 |
return [v for k, v in raw_class.__dict__.items() if k[:2] != "__" and k[-2:] != "__"]
|
8 |
|
9 |
|
|
|
19 |
never_hidden: bool = False
|
20 |
|
21 |
|
22 |
+
COL_NAME_AVG = "Average ⬆️"
|
23 |
+
COL_NAME_RETRIEVAL_MODEL = "Retrieval Method"
|
24 |
+
COL_NAME_RERANKING_MODEL = "Reranking Model"
|
25 |
+
COL_NAME_RETRIEVAL_MODEL_LINK = "Retrieval Model LINK"
|
26 |
+
COL_NAME_RERANKING_MODEL_LINK = "Reranking Model LINK"
|
27 |
+
COL_NAME_RANK = "Rank 🏆"
|
28 |
+
COL_NAME_REVISION = "Revision"
|
29 |
+
COL_NAME_TIMESTAMP = "Submission Date"
|
30 |
+
COL_NAME_IS_ANONYMOUS = "Anonymous Submission"
|
31 |
+
|
32 |
+
|
33 |
def get_default_auto_eval_column_dict():
|
34 |
auto_eval_column_dict = []
|
35 |
+
# Init
|
36 |
auto_eval_column_dict.append(
|
37 |
+
["rank", ColumnContent, ColumnContent(COL_NAME_RANK, "number", True)]
|
|
|
|
|
|
|
|
|
38 |
)
|
39 |
auto_eval_column_dict.append(
|
40 |
+
["retrieval_model", ColumnContent, ColumnContent(COL_NAME_RETRIEVAL_MODEL, "markdown", True, hidden=False, never_hidden=True)]
|
41 |
+
)
|
42 |
+
auto_eval_column_dict.append(
|
43 |
+
["reranking_model", ColumnContent, ColumnContent(COL_NAME_RERANKING_MODEL, "markdown", True, hidden=False, never_hidden=True)]
|
|
|
44 |
)
|
45 |
auto_eval_column_dict.append(
|
46 |
["revision", ColumnContent, ColumnContent(COL_NAME_REVISION, "markdown", True, never_hidden=True)]
|
|
|
48 |
auto_eval_column_dict.append(
|
49 |
["timestamp", ColumnContent, ColumnContent(COL_NAME_TIMESTAMP, "date", True, never_hidden=True)]
|
50 |
)
|
|
|
51 |
auto_eval_column_dict.append(
|
52 |
+
["average", ColumnContent, ColumnContent(COL_NAME_AVG, "number", True)]
|
53 |
+
)
|
54 |
+
auto_eval_column_dict.append(
|
55 |
+
["retrieval_model_link", ColumnContent, ColumnContent(COL_NAME_RETRIEVAL_MODEL_LINK, "markdown", False, hidden=True, never_hidden=False)]
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
)
|
57 |
auto_eval_column_dict.append(
|
58 |
+
["reranking_model_link", ColumnContent, ColumnContent(COL_NAME_RERANKING_MODEL_LINK, "markdown", False, hidden=True, never_hidden=False)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
)
|
60 |
auto_eval_column_dict.append(
|
61 |
["is_anonymous", ColumnContent, ColumnContent(COL_NAME_IS_ANONYMOUS, "bool", False, hidden=True)]
|
|
|
63 |
return auto_eval_column_dict
|
64 |
|
65 |
|
66 |
+
def make_autoevalcolumn(cls_name="BenchmarksQA", benchmarks=BenchmarksQA):
|
67 |
auto_eval_column_dict = get_default_auto_eval_column_dict()
|
68 |
+
## Leaderboard columns
|
69 |
+
for benchmark in benchmarks:
|
70 |
auto_eval_column_dict.append(
|
71 |
[benchmark.name, ColumnContent, ColumnContent(benchmark.value.col_name, "number", True)]
|
72 |
)
|
|
|
75 |
return make_dataclass(cls_name, auto_eval_column_dict, frozen=True)
|
76 |
|
77 |
|
78 |
+
AutoEvalColumnQA = make_autoevalcolumn(
|
79 |
+
"AutoEvalColumnQA", BenchmarksQA)
|
80 |
+
AutoEvalColumnLongDoc = make_autoevalcolumn(
|
81 |
+
"AutoEvalColumnLongDoc", BenchmarksLongDoc)
|
|
|
82 |
|
83 |
|
84 |
+
# Column selection
|
85 |
+
COLS_QA = [c.name for c in fields(AutoEvalColumnQA) if not c.hidden]
|
86 |
+
COLS_LONG_DOC = [c.name for c in fields(AutoEvalColumnLongDoc) if not c.hidden]
|
87 |
+
TYPES_QA = [c.type for c in fields(AutoEvalColumnQA) if not c.hidden]
|
88 |
+
TYPES_LONG_DOC = [c.type for c in fields(AutoEvalColumnLongDoc) if not c.hidden]
|
89 |
+
COLS_LITE = [c.name for c in fields(AutoEvalColumnQA) if c.displayed_by_default and not c.hidden]
|
90 |
|
91 |
+
QA_BENCHMARK_COLS = [t.value.col_name for t in BenchmarksQA]
|
92 |
|
93 |
+
LONG_DOC_BENCHMARK_COLS = [t.value.col_name for t in BenchmarksLongDoc]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/envs.py
CHANGED
@@ -6,9 +6,7 @@ from huggingface_hub import HfApi
|
|
6 |
# ----------------------------------
|
7 |
TOKEN = os.environ.get("TOKEN", "") # A read/write token for your org
|
8 |
|
9 |
-
OWNER =
|
10 |
-
"AIR-Bench" # Change to your org - don't forget to create a results and request dataset, with the correct format!
|
11 |
-
)
|
12 |
# ----------------------------------
|
13 |
|
14 |
REPO_ID = f"{OWNER}/leaderboard"
|
@@ -17,51 +15,10 @@ RESULTS_REPO = f"{OWNER}/eval_results"
|
|
17 |
# repo for submitting the evaluation
|
18 |
SEARCH_RESULTS_REPO = f"{OWNER}/search_results"
|
19 |
|
20 |
-
# If you
|
21 |
CACHE_PATH = os.getenv("HF_HOME", ".")
|
22 |
|
23 |
# Local caches
|
24 |
EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval_results")
|
25 |
|
26 |
API = HfApi(token=TOKEN)
|
27 |
-
|
28 |
-
BENCHMARK_VERSION_LIST = [
|
29 |
-
"AIR-Bench_24.04",
|
30 |
-
"AIR-Bench_24.05",
|
31 |
-
]
|
32 |
-
|
33 |
-
LATEST_BENCHMARK_VERSION = BENCHMARK_VERSION_LIST[-1] # Change to the latest benchmark version
|
34 |
-
DEFAULT_METRIC_QA = "ndcg_at_10"
|
35 |
-
DEFAULT_METRIC_LONG_DOC = "recall_at_10"
|
36 |
-
METRIC_LIST = [
|
37 |
-
"ndcg_at_1",
|
38 |
-
"ndcg_at_3",
|
39 |
-
"ndcg_at_5",
|
40 |
-
"ndcg_at_10",
|
41 |
-
"ndcg_at_100",
|
42 |
-
"ndcg_at_1000",
|
43 |
-
"map_at_1",
|
44 |
-
"map_at_3",
|
45 |
-
"map_at_5",
|
46 |
-
"map_at_10",
|
47 |
-
"map_at_100",
|
48 |
-
"map_at_1000",
|
49 |
-
"recall_at_1",
|
50 |
-
"recall_at_3",
|
51 |
-
"recall_at_5",
|
52 |
-
"recall_at_10",
|
53 |
-
"recall_at_100",
|
54 |
-
"recall_at_1000",
|
55 |
-
"precision_at_1",
|
56 |
-
"precision_at_3",
|
57 |
-
"precision_at_5",
|
58 |
-
"precision_at_10",
|
59 |
-
"precision_at_100",
|
60 |
-
"precision_at_1000",
|
61 |
-
"mrr_at_1",
|
62 |
-
"mrr_at_3",
|
63 |
-
"mrr_at_5",
|
64 |
-
"mrr_at_10",
|
65 |
-
"mrr_at_100",
|
66 |
-
"mrr_at_1000",
|
67 |
-
]
|
|
|
6 |
# ----------------------------------
|
7 |
TOKEN = os.environ.get("TOKEN", "") # A read/write token for your org
|
8 |
|
9 |
+
OWNER = "AIR-Bench" # "nan" # Change to your org - don't forget to create a results and request dataset, with the correct format!
|
|
|
|
|
10 |
# ----------------------------------
|
11 |
|
12 |
REPO_ID = f"{OWNER}/leaderboard"
|
|
|
15 |
# repo for submitting the evaluation
|
16 |
SEARCH_RESULTS_REPO = f"{OWNER}/search_results"
|
17 |
|
18 |
+
# If you setup a cache later, just change HF_HOME
|
19 |
CACHE_PATH = os.getenv("HF_HOME", ".")
|
20 |
|
21 |
# Local caches
|
22 |
EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval_results")
|
23 |
|
24 |
API = HfApi(token=TOKEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/loaders.py
DELETED
@@ -1,88 +0,0 @@
|
|
1 |
-
import os.path
|
2 |
-
from pathlib import Path
|
3 |
-
from typing import Dict, List, Union
|
4 |
-
|
5 |
-
import pandas as pd
|
6 |
-
|
7 |
-
from src.columns import COL_NAME_IS_ANONYMOUS, COL_NAME_REVISION, COL_NAME_TIMESTAMP
|
8 |
-
from src.envs import BENCHMARK_VERSION_LIST, DEFAULT_METRIC_LONG_DOC, DEFAULT_METRIC_QA
|
9 |
-
from src.models import FullEvalResult, LeaderboardDataStore, TaskType, get_safe_name
|
10 |
-
from src.utils import get_default_cols, get_leaderboard_df, reset_rank
|
11 |
-
|
12 |
-
pd.options.mode.copy_on_write = True
|
13 |
-
|
14 |
-
|
15 |
-
def load_raw_eval_results(results_path: Union[Path, str]) -> List[FullEvalResult]:
|
16 |
-
"""
|
17 |
-
Load the evaluation results from a json file
|
18 |
-
"""
|
19 |
-
model_result_filepaths = []
|
20 |
-
for root, dirs, files in os.walk(results_path):
|
21 |
-
if len(files) == 0:
|
22 |
-
continue
|
23 |
-
|
24 |
-
# select the latest results
|
25 |
-
for file in files:
|
26 |
-
if not (file.startswith("results") and file.endswith(".json")):
|
27 |
-
print(f"skip {file}")
|
28 |
-
continue
|
29 |
-
model_result_filepaths.append(os.path.join(root, file))
|
30 |
-
|
31 |
-
eval_results = {}
|
32 |
-
for model_result_filepath in model_result_filepaths:
|
33 |
-
# create evaluation results
|
34 |
-
try:
|
35 |
-
eval_result = FullEvalResult.init_from_json_file(model_result_filepath)
|
36 |
-
except UnicodeDecodeError:
|
37 |
-
print(f"loading file failed. {model_result_filepath}")
|
38 |
-
continue
|
39 |
-
print(f"file loaded: {model_result_filepath}")
|
40 |
-
timestamp = eval_result.timestamp
|
41 |
-
eval_results[timestamp] = eval_result
|
42 |
-
|
43 |
-
results = []
|
44 |
-
for k, v in eval_results.items():
|
45 |
-
try:
|
46 |
-
v.to_dict()
|
47 |
-
results.append(v)
|
48 |
-
except KeyError:
|
49 |
-
print(f"loading failed: {k}")
|
50 |
-
continue
|
51 |
-
return results
|
52 |
-
|
53 |
-
|
54 |
-
def load_leaderboard_datastore(file_path, version) -> LeaderboardDataStore:
|
55 |
-
ds = LeaderboardDataStore(version, get_safe_name(version))
|
56 |
-
ds.raw_data = load_raw_eval_results(file_path)
|
57 |
-
print(f"raw data: {len(ds.raw_data)}")
|
58 |
-
|
59 |
-
ds.qa_raw_df = get_leaderboard_df(ds, TaskType.qa, DEFAULT_METRIC_QA)
|
60 |
-
print(f"QA data loaded: {ds.qa_raw_df.shape}")
|
61 |
-
ds.qa_fmt_df = ds.qa_raw_df.copy()
|
62 |
-
qa_cols, ds.qa_types = get_default_cols(TaskType.qa, ds.slug, add_fix_cols=True)
|
63 |
-
# by default, drop the anonymous submissions
|
64 |
-
ds.qa_fmt_df = ds.qa_fmt_df[~ds.qa_fmt_df[COL_NAME_IS_ANONYMOUS]][qa_cols]
|
65 |
-
# reset the rank after dropping the anonymous submissions
|
66 |
-
ds.qa_fmt_df = reset_rank(ds.qa_fmt_df)
|
67 |
-
ds.qa_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
68 |
-
|
69 |
-
ds.doc_raw_df = get_leaderboard_df(ds, TaskType.long_doc, DEFAULT_METRIC_LONG_DOC)
|
70 |
-
print(f"Long-Doc data loaded: {len(ds.doc_raw_df)}")
|
71 |
-
ds.doc_fmt_df = ds.doc_raw_df.copy()
|
72 |
-
doc_cols, ds.doc_types = get_default_cols(TaskType.long_doc, ds.slug, add_fix_cols=True)
|
73 |
-
# by default, drop the anonymous submissions
|
74 |
-
ds.doc_fmt_df = ds.doc_fmt_df[~ds.doc_fmt_df[COL_NAME_IS_ANONYMOUS]][doc_cols]
|
75 |
-
# reset the rank after dropping the anonymous submissions
|
76 |
-
ds.doc_fmt_df = reset_rank(ds.doc_fmt_df)
|
77 |
-
ds.doc_fmt_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
78 |
-
|
79 |
-
ds.reranking_models = sorted(list(frozenset([eval_result.reranking_model for eval_result in ds.raw_data])))
|
80 |
-
return ds
|
81 |
-
|
82 |
-
|
83 |
-
def load_eval_results(file_path: Union[str, Path]) -> Dict[str, LeaderboardDataStore]:
|
84 |
-
output = {}
|
85 |
-
for version in BENCHMARK_VERSION_LIST:
|
86 |
-
fn = f"{file_path}/{version}"
|
87 |
-
output[version] = load_leaderboard_datastore(fn, version)
|
88 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/{models.py → read_evals.py}
RENAMED
@@ -1,21 +1,38 @@
|
|
1 |
import json
|
|
|
2 |
from collections import defaultdict
|
3 |
from dataclasses import dataclass
|
4 |
-
from enum import Enum
|
5 |
from typing import List
|
6 |
|
7 |
import pandas as pd
|
8 |
|
9 |
-
from src.
|
10 |
-
|
11 |
COL_NAME_RERANKING_MODEL,
|
12 |
-
COL_NAME_RERANKING_MODEL_LINK,
|
13 |
COL_NAME_RETRIEVAL_MODEL,
|
|
|
14 |
COL_NAME_RETRIEVAL_MODEL_LINK,
|
15 |
COL_NAME_REVISION,
|
16 |
COL_NAME_TIMESTAMP,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
)
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
@dataclass
|
21 |
class EvalResult:
|
@@ -23,7 +40,6 @@ class EvalResult:
|
|
23 |
Evaluation result of a single embedding model with a specific reranking model on benchmarks over different
|
24 |
domains, languages, and datasets
|
25 |
"""
|
26 |
-
|
27 |
eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]_[metric]
|
28 |
retrieval_model: str
|
29 |
reranking_model: str
|
@@ -40,7 +56,6 @@ class FullEvalResult:
|
|
40 |
"""
|
41 |
Evaluation result of a single embedding model with a specific reranking model on benchmarks over different tasks
|
42 |
"""
|
43 |
-
|
44 |
eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]
|
45 |
retrieval_model: str
|
46 |
reranking_model: str
|
@@ -64,6 +79,7 @@ class FullEvalResult:
|
|
64 |
result_list = []
|
65 |
retrieval_model_link = ""
|
66 |
reranking_model_link = ""
|
|
|
67 |
for item in model_data:
|
68 |
config = item.get("config", {})
|
69 |
# eval results for different metrics
|
@@ -82,26 +98,24 @@ class FullEvalResult:
|
|
82 |
metric=config["metric"],
|
83 |
timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
|
84 |
revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e"),
|
85 |
-
is_anonymous=config.get("is_anonymous", False)
|
86 |
)
|
87 |
result_list.append(eval_result)
|
88 |
-
eval_result = result_list[0]
|
89 |
return cls(
|
90 |
-
eval_name=f"{
|
91 |
-
retrieval_model=
|
92 |
-
reranking_model=
|
93 |
retrieval_model_link=retrieval_model_link,
|
94 |
reranking_model_link=reranking_model_link,
|
95 |
results=result_list,
|
96 |
-
timestamp=
|
97 |
-
revision=
|
98 |
-
is_anonymous=
|
99 |
)
|
100 |
|
101 |
-
def to_dict(self, task=
|
102 |
"""
|
103 |
-
Convert the results in all the EvalResults over different tasks and metrics.
|
104 |
-
The output is a list of dict compatible with the dataframe UI
|
105 |
"""
|
106 |
results = defaultdict(dict)
|
107 |
for eval_result in self.results:
|
@@ -109,66 +123,103 @@ class FullEvalResult:
|
|
109 |
continue
|
110 |
if eval_result.task != task:
|
111 |
continue
|
112 |
-
eval_name = eval_result.eval_name
|
113 |
-
results[eval_name][
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
results[eval_name][
|
118 |
-
|
119 |
-
|
120 |
-
results[eval_name][
|
121 |
-
results[eval_name][
|
122 |
-
|
123 |
-
|
124 |
-
results[eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
|
125 |
-
|
126 |
for result in eval_result.results:
|
127 |
# add result for each domain, language, and dataset
|
128 |
domain = result["domain"]
|
129 |
lang = result["lang"]
|
130 |
dataset = result["dataset"]
|
131 |
value = result["value"] * 100
|
132 |
-
if dataset ==
|
133 |
benchmark_name = f"{domain}_{lang}"
|
134 |
else:
|
135 |
benchmark_name = f"{domain}_{lang}_{dataset}"
|
136 |
-
results[eval_name][get_safe_name(benchmark_name)] = value
|
137 |
return [v for v in results.values()]
|
138 |
|
139 |
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
+
import os.path
|
3 |
from collections import defaultdict
|
4 |
from dataclasses import dataclass
|
|
|
5 |
from typing import List
|
6 |
|
7 |
import pandas as pd
|
8 |
|
9 |
+
from src.benchmarks import get_safe_name
|
10 |
+
from src.display.utils import (
|
11 |
COL_NAME_RERANKING_MODEL,
|
|
|
12 |
COL_NAME_RETRIEVAL_MODEL,
|
13 |
+
COL_NAME_RERANKING_MODEL_LINK,
|
14 |
COL_NAME_RETRIEVAL_MODEL_LINK,
|
15 |
COL_NAME_REVISION,
|
16 |
COL_NAME_TIMESTAMP,
|
17 |
+
COL_NAME_IS_ANONYMOUS,
|
18 |
+
COLS_QA,
|
19 |
+
QA_BENCHMARK_COLS,
|
20 |
+
COLS_LONG_DOC,
|
21 |
+
LONG_DOC_BENCHMARK_COLS,
|
22 |
+
COL_NAME_AVG,
|
23 |
+
COL_NAME_RANK
|
24 |
)
|
25 |
|
26 |
+
from src.display.formatting import make_clickable_model
|
27 |
+
|
28 |
+
pd.options.mode.copy_on_write = True
|
29 |
+
|
30 |
+
def calculate_mean(row):
|
31 |
+
if pd.isna(row).any():
|
32 |
+
return 0
|
33 |
+
else:
|
34 |
+
return row.mean()
|
35 |
+
|
36 |
|
37 |
@dataclass
|
38 |
class EvalResult:
|
|
|
40 |
Evaluation result of a single embedding model with a specific reranking model on benchmarks over different
|
41 |
domains, languages, and datasets
|
42 |
"""
|
|
|
43 |
eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]_[metric]
|
44 |
retrieval_model: str
|
45 |
reranking_model: str
|
|
|
56 |
"""
|
57 |
Evaluation result of a single embedding model with a specific reranking model on benchmarks over different tasks
|
58 |
"""
|
|
|
59 |
eval_name: str # name of the evaluation, [retrieval_model]_[reranking_model]
|
60 |
retrieval_model: str
|
61 |
reranking_model: str
|
|
|
79 |
result_list = []
|
80 |
retrieval_model_link = ""
|
81 |
reranking_model_link = ""
|
82 |
+
revision = ""
|
83 |
for item in model_data:
|
84 |
config = item.get("config", {})
|
85 |
# eval results for different metrics
|
|
|
98 |
metric=config["metric"],
|
99 |
timestamp=config.get("timestamp", "2024-05-12T12:24:02Z"),
|
100 |
revision=config.get("revision", "3a2ba9dcad796a48a02ca1147557724e"),
|
101 |
+
is_anonymous=config.get("is_anonymous", False)
|
102 |
)
|
103 |
result_list.append(eval_result)
|
|
|
104 |
return cls(
|
105 |
+
eval_name=f"{result_list[0].retrieval_model}_{result_list[0].reranking_model}",
|
106 |
+
retrieval_model=result_list[0].retrieval_model,
|
107 |
+
reranking_model=result_list[0].reranking_model,
|
108 |
retrieval_model_link=retrieval_model_link,
|
109 |
reranking_model_link=reranking_model_link,
|
110 |
results=result_list,
|
111 |
+
timestamp=result_list[0].timestamp,
|
112 |
+
revision=result_list[0].revision,
|
113 |
+
is_anonymous=result_list[0].is_anonymous
|
114 |
)
|
115 |
|
116 |
+
def to_dict(self, task='qa', metric='ndcg_at_3') -> List:
|
117 |
"""
|
118 |
+
Convert the results in all the EvalResults over different tasks and metrics. The output is a list of dict compatible with the dataframe UI
|
|
|
119 |
"""
|
120 |
results = defaultdict(dict)
|
121 |
for eval_result in self.results:
|
|
|
123 |
continue
|
124 |
if eval_result.task != task:
|
125 |
continue
|
126 |
+
results[eval_result.eval_name]["eval_name"] = eval_result.eval_name
|
127 |
+
results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL] = (
|
128 |
+
make_clickable_model(self.retrieval_model, self.retrieval_model_link))
|
129 |
+
results[eval_result.eval_name][COL_NAME_RERANKING_MODEL] = (
|
130 |
+
make_clickable_model(self.reranking_model, self.reranking_model_link))
|
131 |
+
results[eval_result.eval_name][COL_NAME_RETRIEVAL_MODEL_LINK] = self.retrieval_model_link
|
132 |
+
results[eval_result.eval_name][COL_NAME_RERANKING_MODEL_LINK] = self.reranking_model_link
|
133 |
+
results[eval_result.eval_name][COL_NAME_REVISION] = self.revision
|
134 |
+
results[eval_result.eval_name][COL_NAME_TIMESTAMP] = self.timestamp
|
135 |
+
results[eval_result.eval_name][COL_NAME_IS_ANONYMOUS] = self.is_anonymous
|
136 |
+
|
137 |
+
# print(f'result loaded: {eval_result.eval_name}')
|
|
|
|
|
138 |
for result in eval_result.results:
|
139 |
# add result for each domain, language, and dataset
|
140 |
domain = result["domain"]
|
141 |
lang = result["lang"]
|
142 |
dataset = result["dataset"]
|
143 |
value = result["value"] * 100
|
144 |
+
if dataset == 'default':
|
145 |
benchmark_name = f"{domain}_{lang}"
|
146 |
else:
|
147 |
benchmark_name = f"{domain}_{lang}_{dataset}"
|
148 |
+
results[eval_result.eval_name][get_safe_name(benchmark_name)] = value
|
149 |
return [v for v in results.values()]
|
150 |
|
151 |
|
152 |
+
def get_raw_eval_results(results_path: str) -> List[FullEvalResult]:
|
153 |
+
"""
|
154 |
+
Load the evaluation results from a json file
|
155 |
+
"""
|
156 |
+
model_result_filepaths = []
|
157 |
+
for root, dirs, files in os.walk(results_path):
|
158 |
+
if len(files) == 0:
|
159 |
+
continue
|
160 |
+
|
161 |
+
# select the latest results
|
162 |
+
for file in files:
|
163 |
+
if not (file.startswith("results") and file.endswith(".json")):
|
164 |
+
print(f'skip {file}')
|
165 |
+
continue
|
166 |
+
model_result_filepaths.append(os.path.join(root, file))
|
167 |
+
|
168 |
+
eval_results = {}
|
169 |
+
for model_result_filepath in model_result_filepaths:
|
170 |
+
# create evaluation results
|
171 |
+
try:
|
172 |
+
eval_result = FullEvalResult.init_from_json_file(model_result_filepath)
|
173 |
+
except UnicodeDecodeError as e:
|
174 |
+
print(f"loading file failed. {model_result_filepath}")
|
175 |
+
continue
|
176 |
+
print(f'file loaded: {model_result_filepath}')
|
177 |
+
eval_name = eval_result.eval_name
|
178 |
+
eval_results[eval_name] = eval_result
|
179 |
+
|
180 |
+
results = []
|
181 |
+
for k, v in eval_results.items():
|
182 |
+
try:
|
183 |
+
v.to_dict()
|
184 |
+
results.append(v)
|
185 |
+
except KeyError:
|
186 |
+
print(f"loading failed: {k}")
|
187 |
+
continue
|
188 |
+
return results
|
189 |
+
|
190 |
+
|
191 |
+
def get_leaderboard_df(raw_data: List[FullEvalResult], task: str, metric: str) -> pd.DataFrame:
|
192 |
+
"""
|
193 |
+
Creates a dataframe from all the individual experiment results
|
194 |
+
"""
|
195 |
+
cols = [COL_NAME_IS_ANONYMOUS, ]
|
196 |
+
if task == "qa":
|
197 |
+
cols += COLS_QA
|
198 |
+
benchmark_cols = QA_BENCHMARK_COLS
|
199 |
+
elif task == "long-doc":
|
200 |
+
cols += COLS_LONG_DOC
|
201 |
+
benchmark_cols = LONG_DOC_BENCHMARK_COLS
|
202 |
+
else:
|
203 |
+
raise NotImplemented
|
204 |
+
all_data_json = []
|
205 |
+
for v in raw_data:
|
206 |
+
all_data_json += v.to_dict(task=task, metric=metric)
|
207 |
+
df = pd.DataFrame.from_records(all_data_json)
|
208 |
+
# print(f'dataframe created: {df.shape}')
|
209 |
+
|
210 |
+
_benchmark_cols = frozenset(benchmark_cols).intersection(frozenset(df.columns.to_list()))
|
211 |
+
|
212 |
+
# calculate the average score for selected benchmarks
|
213 |
+
df[COL_NAME_AVG] = df[list(_benchmark_cols)].apply(calculate_mean, axis=1).round(decimals=2)
|
214 |
+
df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
|
215 |
+
df.reset_index(inplace=True, drop=True)
|
216 |
+
|
217 |
+
_cols = frozenset(cols).intersection(frozenset(df.columns.to_list()))
|
218 |
+
df = df[_cols].round(decimals=2)
|
219 |
+
|
220 |
+
# filter out if any of the benchmarks have not been produced
|
221 |
+
df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
|
222 |
+
|
223 |
+
# shorten the revision
|
224 |
+
df[COL_NAME_REVISION] = df[COL_NAME_REVISION].str[:6]
|
225 |
+
return df
|
src/utils.py
CHANGED
@@ -1,37 +1,24 @@
|
|
1 |
-
import hashlib
|
2 |
import json
|
3 |
-
import
|
4 |
from datetime import datetime, timezone
|
5 |
from pathlib import Path
|
|
|
6 |
|
7 |
import pandas as pd
|
8 |
|
9 |
-
from src.benchmarks import
|
10 |
-
from src.
|
11 |
-
|
12 |
-
COL_NAME_IS_ANONYMOUS,
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
COL_NAME_TIMESTAMP,
|
18 |
-
get_default_col_names_and_types,
|
19 |
-
get_fixed_col_names_and_types,
|
20 |
-
)
|
21 |
-
from src.envs import API, LATEST_BENCHMARK_VERSION, SEARCH_RESULTS_REPO
|
22 |
-
from src.models import TaskType, get_safe_name
|
23 |
-
|
24 |
-
|
25 |
-
def calculate_mean(row):
|
26 |
-
if pd.isna(row).any():
|
27 |
-
return -1
|
28 |
-
else:
|
29 |
-
return row.mean()
|
30 |
|
31 |
|
32 |
def remove_html(input_str):
|
33 |
# Regular expression for finding HTML tags
|
34 |
-
clean = re.sub(r
|
35 |
return clean
|
36 |
|
37 |
|
@@ -68,152 +55,160 @@ def search_table(df: pd.DataFrame, query: str) -> pd.DataFrame:
|
|
68 |
return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]
|
69 |
|
70 |
|
71 |
-
def get_default_cols(task:
|
72 |
cols = []
|
73 |
types = []
|
74 |
-
if task ==
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
78 |
else:
|
79 |
-
raise
|
80 |
-
cols_list, types_list = get_default_col_names_and_types(benchmarks)
|
81 |
-
benchmark_list = [c.value.col_name for c in list(benchmarks.value)]
|
82 |
for col_name, col_type in zip(cols_list, types_list):
|
83 |
if col_name not in benchmark_list:
|
84 |
continue
|
|
|
|
|
85 |
cols.append(col_name)
|
86 |
types.append(col_type)
|
|
|
87 |
if add_fix_cols:
|
88 |
_cols = []
|
89 |
_types = []
|
90 |
-
fixed_cols, fixed_cols_types = get_fixed_col_names_and_types()
|
91 |
for col_name, col_type in zip(cols, types):
|
92 |
-
if col_name in
|
93 |
continue
|
94 |
_cols.append(col_name)
|
95 |
_types.append(col_type)
|
96 |
-
cols =
|
97 |
-
types =
|
98 |
return cols, types
|
99 |
|
100 |
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
selected_cols = []
|
104 |
for c in cols:
|
105 |
-
if task ==
|
106 |
-
eval_col =
|
107 |
-
elif task ==
|
108 |
-
eval_col =
|
109 |
-
|
110 |
-
raise NotImplementedError
|
111 |
-
if eval_col.domain not in domains:
|
112 |
continue
|
113 |
-
if eval_col.lang not in
|
114 |
continue
|
115 |
selected_cols.append(c)
|
116 |
# We use COLS to maintain sorting
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
def select_columns(
|
121 |
-
df: pd.DataFrame,
|
122 |
-
domains: list,
|
123 |
-
languages: list,
|
124 |
-
task: TaskType = TaskType.qa,
|
125 |
-
reset_ranking: bool = True,
|
126 |
-
version_slug: str = None,
|
127 |
-
) -> pd.DataFrame:
|
128 |
-
selected_cols = get_selected_cols(task, version_slug, domains, languages)
|
129 |
-
fixed_cols, _ = get_fixed_col_names_and_types()
|
130 |
-
filtered_df = df[fixed_cols + selected_cols]
|
131 |
-
filtered_df.replace({"": pd.NA}, inplace=True)
|
132 |
if reset_ranking:
|
133 |
filtered_df[COL_NAME_AVG] = filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
|
134 |
filtered_df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
|
135 |
filtered_df.reset_index(inplace=True, drop=True)
|
136 |
filtered_df = reset_rank(filtered_df)
|
|
|
137 |
return filtered_df
|
138 |
|
139 |
|
140 |
-
def
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
show_revision_and_timestamp: bool = False,
|
151 |
):
|
152 |
-
filtered_df =
|
153 |
if not show_anonymous:
|
154 |
filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
|
155 |
filtered_df = filter_models(filtered_df, reranking_query)
|
156 |
filtered_df = filter_queries(query, filtered_df)
|
157 |
-
filtered_df = select_columns(filtered_df, domains, langs, task, reset_ranking
|
158 |
if not show_revision_and_timestamp:
|
159 |
filtered_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
160 |
return filtered_df
|
161 |
|
162 |
|
163 |
-
def
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
reset_ranking: bool = True,
|
173 |
):
|
174 |
-
return
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
|
188 |
def update_metric(
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
) -> pd.DataFrame:
|
199 |
-
if task ==
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
|
|
217 |
|
218 |
|
219 |
def upload_file(filepath: str):
|
@@ -223,6 +218,7 @@ def upload_file(filepath: str):
|
|
223 |
return filepath
|
224 |
|
225 |
|
|
|
226 |
def get_iso_format_timestamp():
|
227 |
# Get the current timestamp with UTC as the timezone
|
228 |
current_timestamp = datetime.now(timezone.utc)
|
@@ -231,15 +227,15 @@ def get_iso_format_timestamp():
|
|
231 |
current_timestamp = current_timestamp.replace(microsecond=0)
|
232 |
|
233 |
# Convert to ISO 8601 format and replace the offset with 'Z'
|
234 |
-
iso_format_timestamp = current_timestamp.isoformat().replace(
|
235 |
-
filename_friendly_timestamp = current_timestamp.strftime(
|
236 |
return iso_format_timestamp, filename_friendly_timestamp
|
237 |
|
238 |
|
239 |
def calculate_file_md5(file_path):
|
240 |
md5 = hashlib.md5()
|
241 |
|
242 |
-
with open(file_path,
|
243 |
while True:
|
244 |
data = f.read(4096)
|
245 |
if not data:
|
@@ -250,14 +246,13 @@ def calculate_file_md5(file_path):
|
|
250 |
|
251 |
|
252 |
def submit_results(
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
):
|
261 |
if not filepath.endswith(".zip"):
|
262 |
return styled_error(f"file uploading aborted. wrong file type: {filepath}")
|
263 |
|
@@ -270,13 +265,11 @@ def submit_results(
|
|
270 |
if not model_url.startswith("https://") and not model_url.startswith("http://"):
|
271 |
# TODO: retrieve the model page and find the model name on the page
|
272 |
return styled_error(
|
273 |
-
f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
|
274 |
-
)
|
275 |
if reranking_model != "NoReranker":
|
276 |
if not reranking_model_url.startswith("https://") and not reranking_model_url.startswith("http://"):
|
277 |
return styled_error(
|
278 |
-
f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}"
|
279 |
-
)
|
280 |
|
281 |
# rename the uploaded file
|
282 |
input_fp = Path(filepath)
|
@@ -286,15 +279,14 @@ def submit_results(
|
|
286 |
input_folder_path = input_fp.parent
|
287 |
|
288 |
if not reranking_model:
|
289 |
-
reranking_model =
|
290 |
-
|
291 |
API.upload_file(
|
292 |
path_or_fileobj=filepath,
|
293 |
path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
|
294 |
repo_id=SEARCH_RESULTS_REPO,
|
295 |
repo_type="dataset",
|
296 |
-
commit_message=f"feat: submit {model} to evaluate"
|
297 |
-
)
|
298 |
|
299 |
output_config_fn = f"{output_fn.removesuffix('.zip')}.json"
|
300 |
output_config = {
|
@@ -305,7 +297,7 @@ def submit_results(
|
|
305 |
"version": f"{version}",
|
306 |
"is_anonymous": is_anonymous,
|
307 |
"revision": f"{revision}",
|
308 |
-
"timestamp": f"{timestamp_config}"
|
309 |
}
|
310 |
with open(input_folder_path / output_config_fn, "w") as f:
|
311 |
json.dump(output_config, f, indent=4, ensure_ascii=False)
|
@@ -314,8 +306,7 @@ def submit_results(
|
|
314 |
path_in_repo=f"{version}/{model}/{reranking_model}/{output_config_fn}",
|
315 |
repo_id=SEARCH_RESULTS_REPO,
|
316 |
repo_type="dataset",
|
317 |
-
commit_message=f"feat: submit {model} + {reranking_model} config"
|
318 |
-
)
|
319 |
return styled_message(
|
320 |
f"Thanks for submission!\n"
|
321 |
f"Retrieval method: {model}\nReranking model: {reranking_model}\nSubmission revision: {revision}"
|
@@ -325,125 +316,3 @@ def submit_results(
|
|
325 |
def reset_rank(df):
|
326 |
df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
|
327 |
return df
|
328 |
-
|
329 |
-
|
330 |
-
def get_leaderboard_df(datastore, task: TaskType, metric: str) -> pd.DataFrame:
|
331 |
-
"""
|
332 |
-
Creates a dataframe from all the individual experiment results
|
333 |
-
"""
|
334 |
-
# load the selected metrics into a DataFrame from the raw json
|
335 |
-
all_data_json = []
|
336 |
-
for v in datastore.raw_data:
|
337 |
-
all_data_json += v.to_dict(task=task.value, metric=metric)
|
338 |
-
df = pd.DataFrame.from_records(all_data_json)
|
339 |
-
|
340 |
-
# calculate the average scores for selected task
|
341 |
-
if task == TaskType.qa:
|
342 |
-
benchmarks = QABenchmarks[datastore.slug]
|
343 |
-
elif task == TaskType.long_doc:
|
344 |
-
benchmarks = LongDocBenchmarks[datastore.slug]
|
345 |
-
else:
|
346 |
-
raise NotImplementedError
|
347 |
-
valid_cols = frozenset(df.columns.to_list())
|
348 |
-
benchmark_cols = []
|
349 |
-
for t in list(benchmarks.value):
|
350 |
-
if t.value.col_name not in valid_cols:
|
351 |
-
continue
|
352 |
-
benchmark_cols.append(t.value.col_name)
|
353 |
-
|
354 |
-
# filter out the columns that are not in the data
|
355 |
-
df[COL_NAME_AVG] = df[list(benchmark_cols)].apply(calculate_mean, axis=1).round(decimals=2)
|
356 |
-
df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
|
357 |
-
df.reset_index(inplace=True, drop=True)
|
358 |
-
|
359 |
-
# filter out columns that are not in the data
|
360 |
-
display_cols = [COL_NAME_IS_ANONYMOUS, COL_NAME_AVG]
|
361 |
-
default_cols, _ = get_default_col_names_and_types(benchmarks)
|
362 |
-
for col in default_cols:
|
363 |
-
if col in valid_cols:
|
364 |
-
display_cols.append(col)
|
365 |
-
df = df[display_cols].round(decimals=2)
|
366 |
-
|
367 |
-
# rank the scores
|
368 |
-
df = reset_rank(df)
|
369 |
-
|
370 |
-
# shorten the revision
|
371 |
-
df[COL_NAME_REVISION] = df[COL_NAME_REVISION].str[:6]
|
372 |
-
|
373 |
-
return df
|
374 |
-
|
375 |
-
|
376 |
-
def set_listeners(
|
377 |
-
task: TaskType,
|
378 |
-
target_df,
|
379 |
-
source_df,
|
380 |
-
search_bar,
|
381 |
-
version,
|
382 |
-
selected_domains,
|
383 |
-
selected_langs,
|
384 |
-
selected_rerankings,
|
385 |
-
show_anonymous,
|
386 |
-
show_revision_and_timestamp,
|
387 |
-
):
|
388 |
-
if task == TaskType.qa:
|
389 |
-
update_table_func = update_qa_df_elem
|
390 |
-
elif task == TaskType.long_doc:
|
391 |
-
update_table_func = update_doc_df_elem
|
392 |
-
else:
|
393 |
-
raise NotImplementedError
|
394 |
-
selector_list = [selected_domains, selected_langs, selected_rerankings, search_bar, show_anonymous]
|
395 |
-
search_bar_args = [
|
396 |
-
source_df,
|
397 |
-
version,
|
398 |
-
] + selector_list
|
399 |
-
selector_args = (
|
400 |
-
[version, source_df]
|
401 |
-
+ selector_list
|
402 |
-
+ [
|
403 |
-
show_revision_and_timestamp,
|
404 |
-
]
|
405 |
-
)
|
406 |
-
# Set search_bar listener
|
407 |
-
search_bar.submit(update_table_func, search_bar_args, target_df)
|
408 |
-
|
409 |
-
# Set column-wise listener
|
410 |
-
for selector in selector_list:
|
411 |
-
selector.change(
|
412 |
-
update_table_func,
|
413 |
-
selector_args,
|
414 |
-
target_df,
|
415 |
-
queue=True,
|
416 |
-
)
|
417 |
-
|
418 |
-
|
419 |
-
def update_qa_df_elem(
|
420 |
-
version: str,
|
421 |
-
hidden_df: pd.DataFrame,
|
422 |
-
domains: list,
|
423 |
-
langs: list,
|
424 |
-
reranking_query: list,
|
425 |
-
query: str,
|
426 |
-
show_anonymous: bool,
|
427 |
-
show_revision_and_timestamp: bool = False,
|
428 |
-
reset_ranking: bool = True,
|
429 |
-
):
|
430 |
-
return _update_df_elem(
|
431 |
-
TaskType.qa,
|
432 |
-
version,
|
433 |
-
hidden_df,
|
434 |
-
domains,
|
435 |
-
langs,
|
436 |
-
reranking_query,
|
437 |
-
query,
|
438 |
-
show_anonymous,
|
439 |
-
reset_ranking,
|
440 |
-
show_revision_and_timestamp,
|
441 |
-
)
|
442 |
-
|
443 |
-
|
444 |
-
def styled_error(error):
|
445 |
-
return f"<p style='color: red; font-size: 20px; text-align: center;'>{error}</p>"
|
446 |
-
|
447 |
-
|
448 |
-
def styled_message(message):
|
449 |
-
return f"<p style='color: green; font-size: 20px; text-align: center;'>{message}</p>"
|
|
|
|
|
1 |
import json
|
2 |
+
import hashlib
|
3 |
from datetime import datetime, timezone
|
4 |
from pathlib import Path
|
5 |
+
from typing import List
|
6 |
|
7 |
import pandas as pd
|
8 |
|
9 |
+
from src.benchmarks import BENCHMARK_COLS_QA, BENCHMARK_COLS_LONG_DOC, BenchmarksQA, BenchmarksLongDoc
|
10 |
+
from src.display.formatting import styled_message, styled_error
|
11 |
+
from src.display.utils import COLS_QA, TYPES_QA, COLS_LONG_DOC, TYPES_LONG_DOC, COL_NAME_RANK, COL_NAME_AVG, \
|
12 |
+
COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL, COL_NAME_IS_ANONYMOUS, COL_NAME_TIMESTAMP, COL_NAME_REVISION, get_default_auto_eval_column_dict
|
13 |
+
from src.envs import API, SEARCH_RESULTS_REPO
|
14 |
+
from src.read_evals import FullEvalResult, get_leaderboard_df, calculate_mean
|
15 |
+
|
16 |
+
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
def remove_html(input_str):
|
20 |
# Regular expression for finding HTML tags
|
21 |
+
clean = re.sub(r'<.*?>', '', input_str)
|
22 |
return clean
|
23 |
|
24 |
|
|
|
55 |
return df[(df[COL_NAME_RETRIEVAL_MODEL].str.contains(query, case=False))]
|
56 |
|
57 |
|
58 |
+
def get_default_cols(task: str, columns: list=[], add_fix_cols: bool=True) -> list:
|
59 |
cols = []
|
60 |
types = []
|
61 |
+
if task == "qa":
|
62 |
+
cols_list = COLS_QA
|
63 |
+
types_list = TYPES_QA
|
64 |
+
benchmark_list = BENCHMARK_COLS_QA
|
65 |
+
elif task == "long-doc":
|
66 |
+
cols_list = COLS_LONG_DOC
|
67 |
+
types_list = TYPES_LONG_DOC
|
68 |
+
benchmark_list = BENCHMARK_COLS_LONG_DOC
|
69 |
else:
|
70 |
+
raise NotImplemented
|
|
|
|
|
71 |
for col_name, col_type in zip(cols_list, types_list):
|
72 |
if col_name not in benchmark_list:
|
73 |
continue
|
74 |
+
if len(columns) > 0 and col_name not in columns:
|
75 |
+
continue
|
76 |
cols.append(col_name)
|
77 |
types.append(col_type)
|
78 |
+
|
79 |
if add_fix_cols:
|
80 |
_cols = []
|
81 |
_types = []
|
|
|
82 |
for col_name, col_type in zip(cols, types):
|
83 |
+
if col_name in FIXED_COLS:
|
84 |
continue
|
85 |
_cols.append(col_name)
|
86 |
_types.append(col_type)
|
87 |
+
cols = FIXED_COLS + _cols
|
88 |
+
types = FIXED_COLS_TYPES + _types
|
89 |
return cols, types
|
90 |
|
91 |
|
92 |
+
fixed_cols = get_default_auto_eval_column_dict()[:-3]
|
93 |
+
|
94 |
+
FIXED_COLS = [c.name for _, _, c in fixed_cols]
|
95 |
+
FIXED_COLS_TYPES = [c.type for _, _, c in fixed_cols]
|
96 |
+
|
97 |
+
|
98 |
+
def select_columns(
|
99 |
+
df: pd.DataFrame,
|
100 |
+
domain_query: list,
|
101 |
+
language_query: list,
|
102 |
+
task: str = "qa",
|
103 |
+
reset_ranking: bool = True
|
104 |
+
) -> pd.DataFrame:
|
105 |
+
cols, _ = get_default_cols(task=task, columns=df.columns, add_fix_cols=False)
|
106 |
selected_cols = []
|
107 |
for c in cols:
|
108 |
+
if task == "qa":
|
109 |
+
eval_col = BenchmarksQA[c].value
|
110 |
+
elif task == "long-doc":
|
111 |
+
eval_col = BenchmarksLongDoc[c].value
|
112 |
+
if eval_col.domain not in domain_query:
|
|
|
|
|
113 |
continue
|
114 |
+
if eval_col.lang not in language_query:
|
115 |
continue
|
116 |
selected_cols.append(c)
|
117 |
# We use COLS to maintain sorting
|
118 |
+
filtered_df = df[FIXED_COLS + selected_cols]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
if reset_ranking:
|
120 |
filtered_df[COL_NAME_AVG] = filtered_df[selected_cols].apply(calculate_mean, axis=1).round(decimals=2)
|
121 |
filtered_df.sort_values(by=[COL_NAME_AVG], ascending=False, inplace=True)
|
122 |
filtered_df.reset_index(inplace=True, drop=True)
|
123 |
filtered_df = reset_rank(filtered_df)
|
124 |
+
|
125 |
return filtered_df
|
126 |
|
127 |
|
128 |
+
def _update_table(
|
129 |
+
task: str,
|
130 |
+
hidden_df: pd.DataFrame,
|
131 |
+
domains: list,
|
132 |
+
langs: list,
|
133 |
+
reranking_query: list,
|
134 |
+
query: str,
|
135 |
+
show_anonymous: bool,
|
136 |
+
reset_ranking: bool = True,
|
137 |
+
show_revision_and_timestamp: bool = False
|
|
|
138 |
):
|
139 |
+
filtered_df = hidden_df.copy()
|
140 |
if not show_anonymous:
|
141 |
filtered_df = filtered_df[~filtered_df[COL_NAME_IS_ANONYMOUS]]
|
142 |
filtered_df = filter_models(filtered_df, reranking_query)
|
143 |
filtered_df = filter_queries(query, filtered_df)
|
144 |
+
filtered_df = select_columns(filtered_df, domains, langs, task, reset_ranking)
|
145 |
if not show_revision_and_timestamp:
|
146 |
filtered_df.drop([COL_NAME_REVISION, COL_NAME_TIMESTAMP], axis=1, inplace=True)
|
147 |
return filtered_df
|
148 |
|
149 |
|
150 |
+
def update_table(
|
151 |
+
hidden_df: pd.DataFrame,
|
152 |
+
domains: list,
|
153 |
+
langs: list,
|
154 |
+
reranking_query: list,
|
155 |
+
query: str,
|
156 |
+
show_anonymous: bool,
|
157 |
+
show_revision_and_timestamp: bool = False,
|
158 |
+
reset_ranking: bool = True
|
|
|
159 |
):
|
160 |
+
return _update_table(
|
161 |
+
"qa", hidden_df, domains, langs, reranking_query, query, show_anonymous, reset_ranking, show_revision_and_timestamp)
|
162 |
+
|
163 |
+
|
164 |
+
def update_table_long_doc(
|
165 |
+
hidden_df: pd.DataFrame,
|
166 |
+
domains: list,
|
167 |
+
langs: list,
|
168 |
+
reranking_query: list,
|
169 |
+
query: str,
|
170 |
+
show_anonymous: bool,
|
171 |
+
show_revision_and_timestamp: bool = False,
|
172 |
+
reset_ranking: bool = True
|
173 |
+
|
174 |
+
):
|
175 |
+
return _update_table(
|
176 |
+
"long-doc", hidden_df, domains, langs, reranking_query, query, show_anonymous, reset_ranking, show_revision_and_timestamp)
|
177 |
|
178 |
|
179 |
def update_metric(
|
180 |
+
raw_data: List[FullEvalResult],
|
181 |
+
task: str,
|
182 |
+
metric: str,
|
183 |
+
domains: list,
|
184 |
+
langs: list,
|
185 |
+
reranking_model: list,
|
186 |
+
query: str,
|
187 |
+
show_anonymous: bool = False,
|
188 |
+
show_revision_and_timestamp: bool = False,
|
189 |
) -> pd.DataFrame:
|
190 |
+
if task == 'qa':
|
191 |
+
leaderboard_df = get_leaderboard_df(raw_data, task=task, metric=metric)
|
192 |
+
return update_table(
|
193 |
+
leaderboard_df,
|
194 |
+
domains,
|
195 |
+
langs,
|
196 |
+
reranking_model,
|
197 |
+
query,
|
198 |
+
show_anonymous,
|
199 |
+
show_revision_and_timestamp
|
200 |
+
)
|
201 |
+
elif task == "long-doc":
|
202 |
+
leaderboard_df = get_leaderboard_df(raw_data, task=task, metric=metric)
|
203 |
+
return update_table_long_doc(
|
204 |
+
leaderboard_df,
|
205 |
+
domains,
|
206 |
+
langs,
|
207 |
+
reranking_model,
|
208 |
+
query,
|
209 |
+
show_anonymous,
|
210 |
+
show_revision_and_timestamp
|
211 |
+
)
|
212 |
|
213 |
|
214 |
def upload_file(filepath: str):
|
|
|
218 |
return filepath
|
219 |
|
220 |
|
221 |
+
|
222 |
def get_iso_format_timestamp():
|
223 |
# Get the current timestamp with UTC as the timezone
|
224 |
current_timestamp = datetime.now(timezone.utc)
|
|
|
227 |
current_timestamp = current_timestamp.replace(microsecond=0)
|
228 |
|
229 |
# Convert to ISO 8601 format and replace the offset with 'Z'
|
230 |
+
iso_format_timestamp = current_timestamp.isoformat().replace('+00:00', 'Z')
|
231 |
+
filename_friendly_timestamp = current_timestamp.strftime('%Y%m%d%H%M%S')
|
232 |
return iso_format_timestamp, filename_friendly_timestamp
|
233 |
|
234 |
|
235 |
def calculate_file_md5(file_path):
|
236 |
md5 = hashlib.md5()
|
237 |
|
238 |
+
with open(file_path, 'rb') as f:
|
239 |
while True:
|
240 |
data = f.read(4096)
|
241 |
if not data:
|
|
|
246 |
|
247 |
|
248 |
def submit_results(
|
249 |
+
filepath: str,
|
250 |
+
model: str,
|
251 |
+
model_url: str,
|
252 |
+
reranking_model: str="",
|
253 |
+
reranking_model_url: str="",
|
254 |
+
version: str="AIR-Bench_24.04",
|
255 |
+
is_anonymous=False):
|
|
|
256 |
if not filepath.endswith(".zip"):
|
257 |
return styled_error(f"file uploading aborted. wrong file type: {filepath}")
|
258 |
|
|
|
265 |
if not model_url.startswith("https://") and not model_url.startswith("http://"):
|
266 |
# TODO: retrieve the model page and find the model name on the page
|
267 |
return styled_error(
|
268 |
+
f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}")
|
|
|
269 |
if reranking_model != "NoReranker":
|
270 |
if not reranking_model_url.startswith("https://") and not reranking_model_url.startswith("http://"):
|
271 |
return styled_error(
|
272 |
+
f"failed to submit. Model url must start with `https://` or `http://`. Illegal model url: {model_url}")
|
|
|
273 |
|
274 |
# rename the uploaded file
|
275 |
input_fp = Path(filepath)
|
|
|
279 |
input_folder_path = input_fp.parent
|
280 |
|
281 |
if not reranking_model:
|
282 |
+
reranking_model = 'NoReranker'
|
283 |
+
|
284 |
API.upload_file(
|
285 |
path_or_fileobj=filepath,
|
286 |
path_in_repo=f"{version}/{model}/{reranking_model}/{output_fn}",
|
287 |
repo_id=SEARCH_RESULTS_REPO,
|
288 |
repo_type="dataset",
|
289 |
+
commit_message=f"feat: submit {model} to evaluate")
|
|
|
290 |
|
291 |
output_config_fn = f"{output_fn.removesuffix('.zip')}.json"
|
292 |
output_config = {
|
|
|
297 |
"version": f"{version}",
|
298 |
"is_anonymous": is_anonymous,
|
299 |
"revision": f"{revision}",
|
300 |
+
"timestamp": f"{timestamp_config}"
|
301 |
}
|
302 |
with open(input_folder_path / output_config_fn, "w") as f:
|
303 |
json.dump(output_config, f, indent=4, ensure_ascii=False)
|
|
|
306 |
path_in_repo=f"{version}/{model}/{reranking_model}/{output_config_fn}",
|
307 |
repo_id=SEARCH_RESULTS_REPO,
|
308 |
repo_type="dataset",
|
309 |
+
commit_message=f"feat: submit {model} + {reranking_model} config")
|
|
|
310 |
return styled_message(
|
311 |
f"Thanks for submission!\n"
|
312 |
f"Retrieval method: {model}\nReranking model: {reranking_model}\nSubmission revision: {revision}"
|
|
|
316 |
def reset_rank(df):
|
317 |
df[COL_NAME_RANK] = df[COL_NAME_AVG].rank(ascending=False, method="min")
|
318 |
return df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/src/display/test_utils.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
from src.display.utils import fields, AutoEvalColumnQA, COLS_QA, COLS_LONG_DOC, COLS_LITE, TYPES_QA, TYPES_LONG_DOC, QA_BENCHMARK_COLS, LONG_DOC_BENCHMARK_COLS, get_default_auto_eval_column_dict
|
3 |
+
|
4 |
+
|
5 |
+
def test_fields():
|
6 |
+
for c in fields(AutoEvalColumnQA):
|
7 |
+
print(c)
|
8 |
+
|
9 |
+
|
10 |
+
def test_macro_variables():
|
11 |
+
print(f'COLS_QA: {COLS_QA}')
|
12 |
+
print(f'COLS_LONG_DOC: {COLS_LONG_DOC}')
|
13 |
+
print(f'COLS_LITE: {COLS_LITE}')
|
14 |
+
print(f'TYPES_QA: {TYPES_QA}')
|
15 |
+
print(f'TYPES_LONG_DOC: {TYPES_LONG_DOC}')
|
16 |
+
print(f'QA_BENCHMARK_COLS: {QA_BENCHMARK_COLS}')
|
17 |
+
print(f'LONG_DOC_BENCHMARK_COLS: {LONG_DOC_BENCHMARK_COLS}')
|
18 |
+
|
19 |
+
|
20 |
+
def test_get_default_auto_eval_column_dict():
|
21 |
+
auto_eval_column_dict_list = get_default_auto_eval_column_dict()
|
22 |
+
assert len(auto_eval_column_dict_list) == 9
|
23 |
+
|
tests/src/test_benchmarks.py
CHANGED
@@ -1,33 +1,9 @@
|
|
1 |
-
import
|
2 |
|
3 |
-
from src.benchmarks import LongDocBenchmarks, QABenchmarks
|
4 |
-
from src.envs import BENCHMARK_VERSION_LIST
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
# | Task | dev | test |
|
9 |
-
# | ---- | --- | ---- |
|
10 |
-
# | Long-Doc | 4 | 11 |
|
11 |
-
# | QA | 54 | 53 |
|
12 |
-
#
|
13 |
-
# 24.04
|
14 |
-
# | Task | test |
|
15 |
-
# | ---- | ---- |
|
16 |
-
# | Long-Doc | 15 |
|
17 |
-
# | QA | 13 |
|
18 |
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
assert len(QABenchmarks) == len(BENCHMARK_VERSION_LIST)
|
23 |
-
for benchmark_list in list(QABenchmarks):
|
24 |
-
version_slug = benchmark_list.name
|
25 |
-
assert num_datasets_dict[version_slug] == len(benchmark_list.value)
|
26 |
-
|
27 |
-
|
28 |
-
@pytest.mark.parametrize("num_datasets_dict", [{"air_bench_2404": 15, "air_bench_2405": 11}])
|
29 |
-
def test_doc_benchmarks(num_datasets_dict):
|
30 |
-
assert len(LongDocBenchmarks) == len(BENCHMARK_VERSION_LIST)
|
31 |
-
for benchmark_list in list(LongDocBenchmarks):
|
32 |
-
version_slug = benchmark_list.name
|
33 |
-
assert num_datasets_dict[version_slug] == len(benchmark_list.value)
|
|
|
1 |
+
from src.benchmarks import BenchmarksQA, BenchmarksLongDoc
|
2 |
|
|
|
|
|
3 |
|
4 |
+
def test_qabenchmarks():
|
5 |
+
print(list(BenchmarksQA))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
+
def test_longdocbenchmarks():
|
9 |
+
print(list(BenchmarksLongDoc))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/src/test_columns.py
DELETED
@@ -1,119 +0,0 @@
|
|
1 |
-
import pytest
|
2 |
-
|
3 |
-
from src.benchmarks import LongDocBenchmarks, QABenchmarks
|
4 |
-
from src.columns import (
|
5 |
-
COL_NAME_AVG,
|
6 |
-
COL_NAME_RANK,
|
7 |
-
COL_NAME_RERANKING_MODEL,
|
8 |
-
COL_NAME_RETRIEVAL_MODEL,
|
9 |
-
COL_NAME_REVISION,
|
10 |
-
COL_NAME_TIMESTAMP,
|
11 |
-
get_default_auto_eval_column_dict,
|
12 |
-
get_default_col_names_and_types,
|
13 |
-
get_fixed_col_names_and_types,
|
14 |
-
make_autoevalcolumn,
|
15 |
-
)
|
16 |
-
|
17 |
-
# Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
|
18 |
-
# 24.05
|
19 |
-
# | Task | dev | test |
|
20 |
-
# | ---- | --- | ---- |
|
21 |
-
# | Long-Doc | 4 | 11 |
|
22 |
-
# | QA | 54 | 53 |
|
23 |
-
#
|
24 |
-
# 24.04
|
25 |
-
# | Task | test |
|
26 |
-
# | ---- | ---- |
|
27 |
-
# | Long-Doc | 15 |
|
28 |
-
# | QA | 13 |
|
29 |
-
|
30 |
-
|
31 |
-
@pytest.fixture()
|
32 |
-
def expected_col_names():
|
33 |
-
return [
|
34 |
-
"rank",
|
35 |
-
"retrieval_model",
|
36 |
-
"reranking_model",
|
37 |
-
"revision",
|
38 |
-
"timestamp",
|
39 |
-
"average",
|
40 |
-
"retrieval_model_link",
|
41 |
-
"reranking_model_link",
|
42 |
-
"is_anonymous",
|
43 |
-
]
|
44 |
-
|
45 |
-
|
46 |
-
@pytest.fixture()
|
47 |
-
def expected_hidden_col_names():
|
48 |
-
return [
|
49 |
-
"retrieval_model_link",
|
50 |
-
"reranking_model_link",
|
51 |
-
"is_anonymous",
|
52 |
-
]
|
53 |
-
|
54 |
-
|
55 |
-
def test_get_default_auto_eval_column_dict(expected_col_names, expected_hidden_col_names):
|
56 |
-
col_list = get_default_auto_eval_column_dict()
|
57 |
-
assert len(col_list) == 9
|
58 |
-
hidden_cols = []
|
59 |
-
for col_tuple, expected_col in zip(col_list, expected_col_names):
|
60 |
-
col, _, col_content = col_tuple
|
61 |
-
assert col == expected_col
|
62 |
-
if col_content.hidden:
|
63 |
-
hidden_cols.append(col)
|
64 |
-
assert hidden_cols == expected_hidden_col_names
|
65 |
-
|
66 |
-
|
67 |
-
def test_get_fixed_col_names_and_types():
|
68 |
-
col_names, col_types = get_fixed_col_names_and_types()
|
69 |
-
assert len(col_names) == 6
|
70 |
-
assert len(col_types) == 6
|
71 |
-
expected_col_and_type = [
|
72 |
-
(COL_NAME_RANK, "number"),
|
73 |
-
(COL_NAME_RETRIEVAL_MODEL, "markdown"),
|
74 |
-
(COL_NAME_RERANKING_MODEL, "markdown"),
|
75 |
-
(COL_NAME_REVISION, "markdown"),
|
76 |
-
(COL_NAME_TIMESTAMP, "date"),
|
77 |
-
(COL_NAME_AVG, "number"),
|
78 |
-
]
|
79 |
-
for col_name, col_type, (c_name, c_type) in zip(col_names, col_types, expected_col_and_type):
|
80 |
-
assert col_name == c_name
|
81 |
-
assert col_type == c_type
|
82 |
-
|
83 |
-
|
84 |
-
@pytest.mark.parametrize(
|
85 |
-
"benchmarks, expected_benchmark_len",
|
86 |
-
[
|
87 |
-
(QABenchmarks, {"air_bench_2404": 13, "air_bench_2405": 53}),
|
88 |
-
(LongDocBenchmarks, {"air_bench_2404": 15, "air_bench_2405": 11}),
|
89 |
-
],
|
90 |
-
)
|
91 |
-
def test_make_autoevalcolumn(benchmarks, expected_benchmark_len, expected_col_names):
|
92 |
-
expected_default_attrs = frozenset(expected_col_names)
|
93 |
-
for benchmark in benchmarks:
|
94 |
-
TestEvalColumn = make_autoevalcolumn("TestEvalColumn", benchmark)
|
95 |
-
attrs = []
|
96 |
-
for k, v in TestEvalColumn.__dict__.items():
|
97 |
-
if not k.startswith("__"):
|
98 |
-
attrs.append(k)
|
99 |
-
attrs = frozenset(attrs)
|
100 |
-
assert expected_default_attrs.issubset(attrs)
|
101 |
-
benchmark_attrs = attrs.difference(expected_default_attrs)
|
102 |
-
assert len(benchmark_attrs) == expected_benchmark_len[benchmark.name]
|
103 |
-
|
104 |
-
|
105 |
-
@pytest.mark.parametrize(
|
106 |
-
"benchmarks, expected_benchmark_len",
|
107 |
-
[
|
108 |
-
(QABenchmarks, {"air_bench_2404": 13, "air_bench_2405": 53}),
|
109 |
-
(LongDocBenchmarks, {"air_bench_2404": 15, "air_bench_2405": 11}),
|
110 |
-
],
|
111 |
-
)
|
112 |
-
def test_get_default_col_names_and_types(
|
113 |
-
benchmarks, expected_benchmark_len, expected_col_names, expected_hidden_col_names
|
114 |
-
):
|
115 |
-
default_col_len = len(expected_col_names)
|
116 |
-
hidden_col_len = len(expected_hidden_col_names)
|
117 |
-
for benchmark in benchmarks:
|
118 |
-
col_names, col_types = get_default_col_names_and_types(benchmark)
|
119 |
-
assert len(col_names) == expected_benchmark_len[benchmark.name] + default_col_len - hidden_col_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/src/test_envs.py
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
from air_benchmark.tasks import BenchmarkTable
|
2 |
-
|
3 |
-
from src.envs import BENCHMARK_VERSION_LIST, DEFAULT_METRIC_LONG_DOC, DEFAULT_METRIC_QA, METRIC_LIST
|
4 |
-
|
5 |
-
|
6 |
-
def test_benchmark_version_list():
|
7 |
-
leaderboard_versions = frozenset(BENCHMARK_VERSION_LIST)
|
8 |
-
available_versions = frozenset([k for k in BenchmarkTable.keys()])
|
9 |
-
assert leaderboard_versions.issubset(available_versions)
|
10 |
-
|
11 |
-
|
12 |
-
def test_default_metrics():
|
13 |
-
assert DEFAULT_METRIC_QA in METRIC_LIST
|
14 |
-
assert DEFAULT_METRIC_LONG_DOC in METRIC_LIST
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/src/test_loaders.py
DELETED
@@ -1,46 +0,0 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
|
3 |
-
import pandas as pd
|
4 |
-
import pytest
|
5 |
-
|
6 |
-
from src.loaders import load_eval_results, load_leaderboard_datastore, load_raw_eval_results
|
7 |
-
|
8 |
-
cur_fp = Path(__file__)
|
9 |
-
|
10 |
-
|
11 |
-
@pytest.mark.parametrize("version", ["AIR-Bench_24.04", "AIR-Bench_24.05"])
|
12 |
-
def test_load_raw_eval_results(version):
|
13 |
-
raw_data = load_raw_eval_results(cur_fp.parents[1] / f"toydata/eval_results/{version}")
|
14 |
-
assert len(raw_data) == 1
|
15 |
-
full_eval_result = raw_data[0]
|
16 |
-
expected_attr = [
|
17 |
-
"eval_name",
|
18 |
-
"retrieval_model",
|
19 |
-
"reranking_model",
|
20 |
-
"retrieval_model_link",
|
21 |
-
"reranking_model_link",
|
22 |
-
"results",
|
23 |
-
"timestamp",
|
24 |
-
"revision",
|
25 |
-
"is_anonymous",
|
26 |
-
]
|
27 |
-
result_attr = [k for k in full_eval_result.__dict__.keys() if k[:2] != "__" and k[-2:] != "__"]
|
28 |
-
assert sorted(expected_attr) == sorted(result_attr)
|
29 |
-
|
30 |
-
|
31 |
-
@pytest.mark.parametrize("version", ["AIR-Bench_24.04", "AIR-Bench_24.05"])
|
32 |
-
def test_load_leaderboard_datastore(version):
|
33 |
-
file_path = cur_fp.parents[1] / f"toydata/eval_results/{version}"
|
34 |
-
datastore = load_leaderboard_datastore(file_path, version)
|
35 |
-
for k, v in datastore.__dict__.items():
|
36 |
-
if k[:2] != "__" and k[-2:] != "__":
|
37 |
-
if isinstance(v, list):
|
38 |
-
assert v
|
39 |
-
elif isinstance(v, pd.DataFrame):
|
40 |
-
assert not v.empty
|
41 |
-
|
42 |
-
|
43 |
-
def test_load_eval_results():
|
44 |
-
file_path = cur_fp.parents[1] / "toydata/eval_results/"
|
45 |
-
datastore_dict = load_eval_results(file_path)
|
46 |
-
assert len(datastore_dict) == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/src/test_models.py
DELETED
@@ -1,89 +0,0 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
|
3 |
-
import pytest
|
4 |
-
|
5 |
-
from src.models import EvalResult, FullEvalResult
|
6 |
-
|
7 |
-
cur_fp = Path(__file__)
|
8 |
-
|
9 |
-
|
10 |
-
# Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
|
11 |
-
# 24.05
|
12 |
-
# | Task | dev | test |
|
13 |
-
# | ---- | --- | ---- |
|
14 |
-
# | Long-Doc | 4 | 11 |
|
15 |
-
# | QA | 54 | 53 |
|
16 |
-
#
|
17 |
-
# 24.04
|
18 |
-
# | Task | test |
|
19 |
-
# | ---- | ---- |
|
20 |
-
# | Long-Doc | 15 |
|
21 |
-
# | QA | 13 |
|
22 |
-
NUM_QA_BENCHMARKS_24_05 = 53
|
23 |
-
NUM_DOC_BENCHMARKS_24_05 = 11
|
24 |
-
NUM_QA_BENCHMARKS_24_04 = 13
|
25 |
-
NUM_DOC_BENCHMARKS_24_04 = 15
|
26 |
-
|
27 |
-
|
28 |
-
def test_eval_result():
|
29 |
-
EvalResult(
|
30 |
-
eval_name="eval_name",
|
31 |
-
retrieval_model="bge-m3",
|
32 |
-
reranking_model="NoReranking",
|
33 |
-
results=[{"domain": "law", "lang": "en", "dataset": "lex_files_500K-600K", "value": 0.45723}],
|
34 |
-
task="qa",
|
35 |
-
metric="ndcg_at_3",
|
36 |
-
timestamp="2024-05-14T03:09:08Z",
|
37 |
-
revision="1e243f14bd295ccdea7a118fe847399d",
|
38 |
-
is_anonymous=True,
|
39 |
-
)
|
40 |
-
|
41 |
-
|
42 |
-
@pytest.mark.parametrize(
|
43 |
-
"file_path",
|
44 |
-
[
|
45 |
-
"AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
|
46 |
-
"AIR-Bench_24.05/bge-m3/NoReranker/results.json",
|
47 |
-
],
|
48 |
-
)
|
49 |
-
def test_full_eval_result_init_from_json_file(file_path):
|
50 |
-
json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
|
51 |
-
full_eval_result = FullEvalResult.init_from_json_file(json_fp)
|
52 |
-
assert json_fp.parents[0].stem == full_eval_result.reranking_model
|
53 |
-
assert json_fp.parents[1].stem == full_eval_result.retrieval_model
|
54 |
-
assert len(full_eval_result.results) == 70
|
55 |
-
|
56 |
-
|
57 |
-
@pytest.mark.parametrize(
|
58 |
-
"file_path, task, expected_num_results",
|
59 |
-
[
|
60 |
-
("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "qa", NUM_QA_BENCHMARKS_24_04),
|
61 |
-
(
|
62 |
-
"AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
|
63 |
-
"long-doc",
|
64 |
-
NUM_DOC_BENCHMARKS_24_04,
|
65 |
-
),
|
66 |
-
("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "qa", NUM_QA_BENCHMARKS_24_05),
|
67 |
-
("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_05),
|
68 |
-
],
|
69 |
-
)
|
70 |
-
def test_full_eval_result_to_dict(file_path, task, expected_num_results):
|
71 |
-
json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
|
72 |
-
full_eval_result = FullEvalResult.init_from_json_file(json_fp)
|
73 |
-
result_dict_list = full_eval_result.to_dict(task)
|
74 |
-
assert len(result_dict_list) == 1
|
75 |
-
result = result_dict_list[0]
|
76 |
-
attr_list = frozenset(
|
77 |
-
[
|
78 |
-
"eval_name",
|
79 |
-
"Retrieval Method",
|
80 |
-
"Reranking Model",
|
81 |
-
"Retrieval Model LINK",
|
82 |
-
"Reranking Model LINK",
|
83 |
-
"Revision",
|
84 |
-
"Submission Date",
|
85 |
-
"Anonymous Submission",
|
86 |
-
]
|
87 |
-
)
|
88 |
-
result_cols = list(result.keys())
|
89 |
-
assert len(result_cols) == (expected_num_results + len(attr_list))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/src/test_read_evals.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
from src.read_evals import FullEvalResult, get_raw_eval_results, get_leaderboard_df
|
4 |
+
|
5 |
+
cur_fp = Path(__file__)
|
6 |
+
|
7 |
+
|
8 |
+
def test_init_from_json_file():
|
9 |
+
json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
|
10 |
+
full_eval_result = FullEvalResult.init_from_json_file(json_fp)
|
11 |
+
num_different_task_domain_lang_metric_dataset_combination = 6
|
12 |
+
assert len(full_eval_result.results) == \
|
13 |
+
num_different_task_domain_lang_metric_dataset_combination
|
14 |
+
assert full_eval_result.retrieval_model == "bge-m3"
|
15 |
+
assert full_eval_result.reranking_model == "bge-reranker-v2-m3"
|
16 |
+
|
17 |
+
|
18 |
+
def test_to_dict():
|
19 |
+
json_fp = cur_fp.parents[2] / "toydata" / "test_data.json"
|
20 |
+
full_eval_result = FullEvalResult.init_from_json_file(json_fp)
|
21 |
+
result_list = full_eval_result.to_dict(task='qa', metric='ndcg_at_1')
|
22 |
+
assert len(result_list) == 1
|
23 |
+
result_dict = result_list[0]
|
24 |
+
assert result_dict["Retrieval Model"] == "bge-m3"
|
25 |
+
assert result_dict["Reranking Model"] == "bge-reranker-v2-m3"
|
26 |
+
assert result_dict["wiki_en"] is not None
|
27 |
+
assert result_dict["wiki_zh"] is not None
|
28 |
+
|
29 |
+
|
30 |
+
def test_get_raw_eval_results():
|
31 |
+
results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
|
32 |
+
results = get_raw_eval_results(results_path)
|
33 |
+
# only load the latest results
|
34 |
+
assert len(results) == 4
|
35 |
+
assert results[0].eval_name == "bge-base-en-v1.5_NoReranker"
|
36 |
+
assert len(results[0].results) == 70
|
37 |
+
assert results[0].eval_name == "bge-base-en-v1.5_bge-reranker-v2-m3"
|
38 |
+
assert len(results[1].results) == 70
|
39 |
+
|
40 |
+
|
41 |
+
def test_get_leaderboard_df():
|
42 |
+
results_path = cur_fp.parents[2] / "toydata" / "eval_results" / "AIR-Bench_24.04"
|
43 |
+
raw_data = get_raw_eval_results(results_path)
|
44 |
+
df = get_leaderboard_df(raw_data, 'qa', 'ndcg_at_10')
|
45 |
+
assert df.shape[0] == 4
|
46 |
+
# the results contain only one embedding model
|
47 |
+
# for i in range(4):
|
48 |
+
# assert df["Retrieval Model"][i] == "bge-m3"
|
49 |
+
# # the results contain only two reranking model
|
50 |
+
# assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
|
51 |
+
# assert df["Reranking Model"][1] == "NoReranker"
|
52 |
+
# assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
|
53 |
+
# assert not df[['Average ⬆️', 'wiki_en', 'wiki_zh', ]].isnull().values.any()
|
54 |
+
|
55 |
+
|
56 |
+
def test_get_leaderboard_df_long_doc():
|
57 |
+
results_path = cur_fp.parents[2] / "toydata" / "test_results"
|
58 |
+
raw_data = get_raw_eval_results(results_path)
|
59 |
+
df = get_leaderboard_df(raw_data, 'long-doc', 'ndcg_at_1')
|
60 |
+
assert df.shape[0] == 2
|
61 |
+
# the results contain only one embedding model
|
62 |
+
for i in range(2):
|
63 |
+
assert df["Retrieval Model"][i] == "bge-m3"
|
64 |
+
# the results contains only two reranking model
|
65 |
+
assert df["Reranking Model"][0] == "bge-reranker-v2-m3"
|
66 |
+
assert df["Reranking Model"][1] == "NoReranker"
|
67 |
+
assert df["Average ⬆️"][0] > df["Average ⬆️"][1]
|
68 |
+
assert not df[['Average ⬆️', 'law_en_lex_files_500k_600k', ]].isnull().values.any()
|
tests/src/test_utils.py
DELETED
@@ -1,237 +0,0 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
|
3 |
-
import pandas as pd
|
4 |
-
import pytest
|
5 |
-
|
6 |
-
from src.columns import COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL
|
7 |
-
from src.models import TaskType, model_hyperlink
|
8 |
-
from src.utils import (
|
9 |
-
_update_df_elem,
|
10 |
-
calculate_mean,
|
11 |
-
filter_models,
|
12 |
-
filter_queries,
|
13 |
-
get_default_cols,
|
14 |
-
get_leaderboard_df,
|
15 |
-
get_selected_cols,
|
16 |
-
remove_html,
|
17 |
-
select_columns,
|
18 |
-
)
|
19 |
-
|
20 |
-
cur_fp = Path(__file__)
|
21 |
-
|
22 |
-
NUM_QA_BENCHMARKS_24_05 = 53
|
23 |
-
NUM_DOC_BENCHMARKS_24_05 = 11
|
24 |
-
NUM_QA_BENCHMARKS_24_04 = 13
|
25 |
-
NUM_DOC_BENCHMARKS_24_04 = 15
|
26 |
-
|
27 |
-
|
28 |
-
@pytest.fixture
|
29 |
-
def toy_df():
|
30 |
-
return pd.DataFrame(
|
31 |
-
{
|
32 |
-
"Retrieval Method": ["bge-m3", "bge-m3", "jina-embeddings-v2-base", "jina-embeddings-v2-base"],
|
33 |
-
"Reranking Model": ["bge-reranker-v2-m3", "NoReranker", "bge-reranker-v2-m3", "NoReranker"],
|
34 |
-
"Rank 🏆": [1, 2, 3, 4],
|
35 |
-
"Revision": ["123", "234", "345", "456"],
|
36 |
-
"Submission Date": ["", "", "", ""],
|
37 |
-
"Average ⬆️": [0.6, 0.4, 0.3, 0.2],
|
38 |
-
"wiki_en": [0.8, 0.7, 0.2, 0.1],
|
39 |
-
"wiki_zh": [0.4, 0.1, 0.4, 0.3],
|
40 |
-
"news_en": [0.8, 0.7, 0.2, 0.1],
|
41 |
-
"news_zh": [0.4, 0.1, 0.2, 0.3],
|
42 |
-
"Anonymous Submission": [False, False, False, True],
|
43 |
-
}
|
44 |
-
)
|
45 |
-
|
46 |
-
|
47 |
-
def test_remove_html():
|
48 |
-
model_name = "jina-embeddings-v3"
|
49 |
-
html_str = model_hyperlink("https://jina.ai", model_name)
|
50 |
-
output_str = remove_html(html_str)
|
51 |
-
assert output_str == model_name
|
52 |
-
|
53 |
-
|
54 |
-
def test_calculate_mean():
|
55 |
-
valid_row = [1, 3]
|
56 |
-
invalid_row = [2, pd.NA]
|
57 |
-
df = pd.DataFrame([valid_row, invalid_row], columns=["a", "b"])
|
58 |
-
result = list(df.apply(calculate_mean, axis=1))
|
59 |
-
assert result[0] == sum(valid_row) / 2
|
60 |
-
assert result[1] == -1
|
61 |
-
|
62 |
-
|
63 |
-
@pytest.mark.parametrize(
|
64 |
-
"models, expected",
|
65 |
-
[
|
66 |
-
(["model1", "model3"], 2),
|
67 |
-
(["model1", "model_missing"], 1),
|
68 |
-
(["model1", "model2", "model3"], 3),
|
69 |
-
(
|
70 |
-
[
|
71 |
-
"model1",
|
72 |
-
],
|
73 |
-
1,
|
74 |
-
),
|
75 |
-
([], 3),
|
76 |
-
],
|
77 |
-
)
|
78 |
-
def test_filter_models(models, expected):
|
79 |
-
df = pd.DataFrame(
|
80 |
-
{
|
81 |
-
COL_NAME_RERANKING_MODEL: [
|
82 |
-
"model1",
|
83 |
-
"model2",
|
84 |
-
"model3",
|
85 |
-
],
|
86 |
-
"col2": [1, 2, 3],
|
87 |
-
}
|
88 |
-
)
|
89 |
-
output_df = filter_models(df, models)
|
90 |
-
assert len(output_df) == expected
|
91 |
-
|
92 |
-
|
93 |
-
@pytest.mark.parametrize(
|
94 |
-
"query, expected",
|
95 |
-
[
|
96 |
-
("model1;model3", 2),
|
97 |
-
("model1;model4", 1),
|
98 |
-
("model1;model2;model3", 3),
|
99 |
-
("model1", 1),
|
100 |
-
("", 3),
|
101 |
-
],
|
102 |
-
)
|
103 |
-
def test_filter_queries(query, expected):
|
104 |
-
df = pd.DataFrame(
|
105 |
-
{
|
106 |
-
COL_NAME_RETRIEVAL_MODEL: [
|
107 |
-
"model1",
|
108 |
-
"model2",
|
109 |
-
"model3",
|
110 |
-
],
|
111 |
-
COL_NAME_RERANKING_MODEL: [
|
112 |
-
"model4",
|
113 |
-
"model5",
|
114 |
-
"model6",
|
115 |
-
],
|
116 |
-
}
|
117 |
-
)
|
118 |
-
output_df = filter_queries(query, df)
|
119 |
-
assert len(output_df) == expected
|
120 |
-
|
121 |
-
|
122 |
-
@pytest.mark.parametrize(
|
123 |
-
"task_type, slug, add_fix_cols, expected",
|
124 |
-
[
|
125 |
-
(TaskType.qa, "air_bench_2404", True, NUM_QA_BENCHMARKS_24_04),
|
126 |
-
(TaskType.long_doc, "air_bench_2404", True, NUM_DOC_BENCHMARKS_24_04),
|
127 |
-
(TaskType.qa, "air_bench_2405", False, NUM_QA_BENCHMARKS_24_05),
|
128 |
-
(TaskType.long_doc, "air_bench_2405", False, NUM_DOC_BENCHMARKS_24_05),
|
129 |
-
],
|
130 |
-
)
|
131 |
-
def test_get_default_cols(task_type, slug, add_fix_cols, expected):
|
132 |
-
attr_cols = ["Rank 🏆", "Retrieval Method", "Reranking Model", "Revision", "Submission Date", "Average ⬆️"]
|
133 |
-
cols, types = get_default_cols(task_type, slug)
|
134 |
-
cols_set = frozenset(cols)
|
135 |
-
attrs_set = frozenset(attr_cols)
|
136 |
-
if add_fix_cols:
|
137 |
-
assert attrs_set.issubset(cols_set)
|
138 |
-
benchmark_cols = list(cols_set.difference(attrs_set))
|
139 |
-
assert len(benchmark_cols) == expected
|
140 |
-
|
141 |
-
|
142 |
-
@pytest.mark.parametrize(
|
143 |
-
"task_type, domains, languages, expected",
|
144 |
-
[
|
145 |
-
(
|
146 |
-
TaskType.qa,
|
147 |
-
["wiki", "news"],
|
148 |
-
[
|
149 |
-
"zh",
|
150 |
-
],
|
151 |
-
["wiki_zh", "news_zh"],
|
152 |
-
),
|
153 |
-
(
|
154 |
-
TaskType.qa,
|
155 |
-
[
|
156 |
-
"law",
|
157 |
-
],
|
158 |
-
["zh", "en"],
|
159 |
-
["law_en"],
|
160 |
-
),
|
161 |
-
(
|
162 |
-
TaskType.long_doc,
|
163 |
-
["healthcare"],
|
164 |
-
["zh", "en"],
|
165 |
-
[
|
166 |
-
"healthcare_en_pubmed_100k_200k_1",
|
167 |
-
"healthcare_en_pubmed_100k_200k_2",
|
168 |
-
"healthcare_en_pubmed_100k_200k_3",
|
169 |
-
"healthcare_en_pubmed_40k_50k_5_merged",
|
170 |
-
"healthcare_en_pubmed_30k_40k_10_merged",
|
171 |
-
],
|
172 |
-
),
|
173 |
-
],
|
174 |
-
)
|
175 |
-
def test_get_selected_cols(task_type, domains, languages, expected):
|
176 |
-
slug = "air_bench_2404"
|
177 |
-
cols = get_selected_cols(task_type, slug, domains, languages)
|
178 |
-
assert sorted(cols) == sorted(expected)
|
179 |
-
|
180 |
-
|
181 |
-
@pytest.mark.parametrize("reset_rank", [False])
|
182 |
-
def test_select_columns(toy_df, reset_rank):
|
183 |
-
expected = [
|
184 |
-
"Rank 🏆",
|
185 |
-
"Retrieval Method",
|
186 |
-
"Reranking Model",
|
187 |
-
"Revision",
|
188 |
-
"Submission Date",
|
189 |
-
"Average ⬆️",
|
190 |
-
"news_zh",
|
191 |
-
]
|
192 |
-
df_result = select_columns(toy_df, ["news"], ["zh"], version_slug="air_bench_2404", reset_ranking=reset_rank)
|
193 |
-
assert len(df_result.columns) == len(expected)
|
194 |
-
if reset_rank:
|
195 |
-
assert df_result["Average ⬆️"].equals(df_result["news_zh"])
|
196 |
-
else:
|
197 |
-
assert df_result["Average ⬆️"].equals(toy_df["Average ⬆️"])
|
198 |
-
|
199 |
-
|
200 |
-
@pytest.mark.parametrize(
|
201 |
-
"reset_rank, show_anony",
|
202 |
-
[
|
203 |
-
(False, True),
|
204 |
-
(True, True),
|
205 |
-
(True, False),
|
206 |
-
],
|
207 |
-
)
|
208 |
-
def test__update_df_elem(toy_df, reset_rank, show_anony):
|
209 |
-
df = _update_df_elem(TaskType.qa, "AIR-Bench_24.04", toy_df, ["news"], ["zh"], [], "", show_anony, reset_rank)
|
210 |
-
if show_anony:
|
211 |
-
assert df.shape[0] == 4
|
212 |
-
else:
|
213 |
-
assert df.shape[0] == 3
|
214 |
-
if show_anony:
|
215 |
-
if reset_rank:
|
216 |
-
assert df["Average ⬆️"].equals(df["news_zh"])
|
217 |
-
else:
|
218 |
-
assert df["Average ⬆️"].equals(toy_df["Average ⬆️"])
|
219 |
-
|
220 |
-
|
221 |
-
@pytest.mark.parametrize(
|
222 |
-
"version, task_type",
|
223 |
-
[
|
224 |
-
("AIR-Bench_24.04", TaskType.qa),
|
225 |
-
("AIR-Bench_24.04", TaskType.long_doc),
|
226 |
-
("AIR-Bench_24.05", TaskType.qa),
|
227 |
-
("AIR-Bench_24.05", TaskType.long_doc),
|
228 |
-
],
|
229 |
-
)
|
230 |
-
def test_get_leaderboard_df(version, task_type):
|
231 |
-
from src.loaders import load_raw_eval_results
|
232 |
-
from src.models import LeaderboardDataStore, get_safe_name
|
233 |
-
|
234 |
-
raw_data = load_raw_eval_results(cur_fp.parents[1] / f"toydata/eval_results/{version}")
|
235 |
-
ds = LeaderboardDataStore(version, get_safe_name(version), raw_data=raw_data)
|
236 |
-
df = get_leaderboard_df(ds, task_type, "ndcg_at_10")
|
237 |
-
assert df.shape[0] == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_utils.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import pytest
|
3 |
+
|
4 |
+
from src.utils import filter_models, search_table, filter_queries, select_columns, update_table_long_doc, get_iso_format_timestamp, get_default_cols, update_table
|
5 |
+
from src.display.utils import COL_NAME_IS_ANONYMOUS, COL_NAME_REVISION, COL_NAME_TIMESTAMP, COL_NAME_RERANKING_MODEL, COL_NAME_RETRIEVAL_MODEL, COL_NAME_RANK, COL_NAME_AVG
|
6 |
+
|
7 |
+
|
8 |
+
@pytest.fixture
|
9 |
+
def toy_df():
|
10 |
+
return pd.DataFrame(
|
11 |
+
{
|
12 |
+
"Retrieval Model": [
|
13 |
+
"bge-m3",
|
14 |
+
"bge-m3",
|
15 |
+
"jina-embeddings-v2-base",
|
16 |
+
"jina-embeddings-v2-base"
|
17 |
+
],
|
18 |
+
"Reranking Model": [
|
19 |
+
"bge-reranker-v2-m3",
|
20 |
+
"NoReranker",
|
21 |
+
"bge-reranker-v2-m3",
|
22 |
+
"NoReranker"
|
23 |
+
],
|
24 |
+
"Average ⬆️": [0.6, 0.4, 0.3, 0.2],
|
25 |
+
"wiki_en": [0.8, 0.7, 0.2, 0.1],
|
26 |
+
"wiki_zh": [0.4, 0.1, 0.4, 0.3],
|
27 |
+
"news_en": [0.8, 0.7, 0.2, 0.1],
|
28 |
+
"news_zh": [0.4, 0.1, 0.4, 0.3],
|
29 |
+
}
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
@pytest.fixture
|
34 |
+
def toy_df_long_doc():
|
35 |
+
return pd.DataFrame(
|
36 |
+
{
|
37 |
+
"Retrieval Model": [
|
38 |
+
"bge-m3",
|
39 |
+
"bge-m3",
|
40 |
+
"jina-embeddings-v2-base",
|
41 |
+
"jina-embeddings-v2-base"
|
42 |
+
],
|
43 |
+
"Reranking Model": [
|
44 |
+
"bge-reranker-v2-m3",
|
45 |
+
"NoReranker",
|
46 |
+
"bge-reranker-v2-m3",
|
47 |
+
"NoReranker"
|
48 |
+
],
|
49 |
+
"Average ⬆️": [0.6, 0.4, 0.3, 0.2],
|
50 |
+
"law_en_lex_files_300k_400k": [0.4, 0.1, 0.4, 0.3],
|
51 |
+
"law_en_lex_files_400k_500k": [0.8, 0.7, 0.2, 0.1],
|
52 |
+
"law_en_lex_files_500k_600k": [0.8, 0.7, 0.2, 0.1],
|
53 |
+
"law_en_lex_files_600k_700k": [0.4, 0.1, 0.4, 0.3],
|
54 |
+
}
|
55 |
+
)
|
56 |
+
def test_filter_models(toy_df):
|
57 |
+
df_result = filter_models(toy_df, ["bge-reranker-v2-m3", ])
|
58 |
+
assert len(df_result) == 2
|
59 |
+
assert df_result.iloc[0]["Reranking Model"] == "bge-reranker-v2-m3"
|
60 |
+
|
61 |
+
|
62 |
+
def test_search_table(toy_df):
|
63 |
+
df_result = search_table(toy_df, "jina")
|
64 |
+
assert len(df_result) == 2
|
65 |
+
assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
|
66 |
+
|
67 |
+
|
68 |
+
def test_filter_queries(toy_df):
|
69 |
+
df_result = filter_queries("jina", toy_df)
|
70 |
+
assert len(df_result) == 2
|
71 |
+
assert df_result.iloc[0]["Retrieval Model"] == "jina-embeddings-v2-base"
|
72 |
+
|
73 |
+
|
74 |
+
def test_select_columns(toy_df):
|
75 |
+
df_result = select_columns(toy_df, ['news',], ['zh',])
|
76 |
+
assert len(df_result.columns) == 4
|
77 |
+
assert df_result['Average ⬆️'].equals(df_result['news_zh'])
|
78 |
+
|
79 |
+
|
80 |
+
def test_update_table_long_doc(toy_df_long_doc):
|
81 |
+
df_result = update_table_long_doc(toy_df_long_doc, ['law',], ['en',], ["bge-reranker-v2-m3", ], "jina")
|
82 |
+
print(df_result)
|
83 |
+
|
84 |
+
|
85 |
+
def test_get_iso_format_timestamp():
|
86 |
+
timestamp_config, timestamp_fn = get_iso_format_timestamp()
|
87 |
+
assert len(timestamp_fn) == 14
|
88 |
+
assert len(timestamp_config) == 20
|
89 |
+
assert timestamp_config[-1] == "Z"
|
90 |
+
|
91 |
+
|
92 |
+
def test_get_default_cols():
|
93 |
+
cols, types = get_default_cols("qa")
|
94 |
+
for c, t in zip(cols, types):
|
95 |
+
print(f"type({c}): {t}")
|
96 |
+
assert len(frozenset(cols)) == len(cols)
|
97 |
+
|
98 |
+
|
99 |
+
def test_update_table():
|
100 |
+
df = pd.DataFrame(
|
101 |
+
{
|
102 |
+
COL_NAME_IS_ANONYMOUS: [False, False, False],
|
103 |
+
COL_NAME_REVISION: ["a1", "a2", "a3"],
|
104 |
+
COL_NAME_TIMESTAMP: ["2024-05-12T12:24:02Z"] * 3,
|
105 |
+
COL_NAME_RERANKING_MODEL: ["NoReranker"] * 3,
|
106 |
+
COL_NAME_RETRIEVAL_MODEL: ["Foo"] * 3,
|
107 |
+
COL_NAME_RANK: [1, 2, 3],
|
108 |
+
COL_NAME_AVG: [0.1, 0.2, 0.3], # unsorted values
|
109 |
+
"wiki_en": [0.1, 0.2, 0.3]
|
110 |
+
}
|
111 |
+
)
|
112 |
+
results = update_table(df, "wiki", "en", ["NoReranker"], "", show_anonymous=False, reset_ranking=False, show_revision_and_timestamp=False)
|
113 |
+
# keep the RANK as the same regardless of the unsorted averages
|
114 |
+
assert results[COL_NAME_RANK].to_list() == [1, 2, 3]
|
115 |
+
|
tests/toydata/eval_results/AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json
DELETED
The diff for this file is too large to render.
See raw diff
|
|
tests/toydata/eval_results/AIR-Bench_24.05/bge-m3/NoReranker/results.json
DELETED
The diff for this file is too large to render.
See raw diff
|
|
tests/toydata/test_data.json
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"config": {
|
4 |
+
"retrieval_model": "bge-m3",
|
5 |
+
"reranking_model": "bge-reranker-v2-m3",
|
6 |
+
"task": "long_doc",
|
7 |
+
"metric": "ndcg_at_1"
|
8 |
+
},
|
9 |
+
"results": [
|
10 |
+
{
|
11 |
+
"domain": "law",
|
12 |
+
"lang": "en",
|
13 |
+
"dataset": "lex_files_500K-600K",
|
14 |
+
"value": 0.75723
|
15 |
+
}
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"config": {
|
20 |
+
"retrieval_model": "bge-m3",
|
21 |
+
"reranking_model": "bge-reranker-v2-m3",
|
22 |
+
"task": "long_doc",
|
23 |
+
"metric": "ndcg_at_3"
|
24 |
+
},
|
25 |
+
"results": [
|
26 |
+
{
|
27 |
+
"domain": "law",
|
28 |
+
"lang": "en",
|
29 |
+
"dataset": "lex_files_500K-600K",
|
30 |
+
"value": 0.69909
|
31 |
+
}
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"config": {
|
36 |
+
"retrieval_model": "bge-m3",
|
37 |
+
"reranking_model": "bge-reranker-v2-m3",
|
38 |
+
"task": "qa",
|
39 |
+
"metric": "ndcg_at_1"
|
40 |
+
},
|
41 |
+
"results": [
|
42 |
+
{
|
43 |
+
"domain": "wiki",
|
44 |
+
"lang": "en",
|
45 |
+
"dataset": "unknown",
|
46 |
+
"value": 0.69083
|
47 |
+
}
|
48 |
+
]
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"config": {
|
52 |
+
"retrieval_model": "bge-m3",
|
53 |
+
"reranking_model": "bge-reranker-v2-m3",
|
54 |
+
"task": "qa",
|
55 |
+
"metric": "ndcg_at_3"
|
56 |
+
},
|
57 |
+
"results": [
|
58 |
+
{
|
59 |
+
"domain": "wiki",
|
60 |
+
"lang": "en",
|
61 |
+
"dataset": "unknown",
|
62 |
+
"value": 0.73359
|
63 |
+
}
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"config": {
|
68 |
+
"retrieval_model": "bge-m3",
|
69 |
+
"reranking_model": "bge-reranker-v2-m3",
|
70 |
+
"task": "qa",
|
71 |
+
"metric": "ndcg_at_1"
|
72 |
+
},
|
73 |
+
"results": [
|
74 |
+
{
|
75 |
+
"domain": "wiki",
|
76 |
+
"lang": "zh",
|
77 |
+
"dataset": "unknown",
|
78 |
+
"value": 0.78358
|
79 |
+
}
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"config": {
|
84 |
+
"retrieval_model": "bge-m3",
|
85 |
+
"reranking_model": "bge-reranker-v2-m3",
|
86 |
+
"task": "qa",
|
87 |
+
"metric": "ndcg_at_3"
|
88 |
+
},
|
89 |
+
"results": [
|
90 |
+
{
|
91 |
+
"domain": "wiki",
|
92 |
+
"lang": "zh",
|
93 |
+
"dataset": "unknown",
|
94 |
+
"value": 0.78358
|
95 |
+
}
|
96 |
+
]
|
97 |
+
}
|
98 |
+
]
|
tests/toydata/test_results/bge-m3/NoReranker/results_2023-11-21T18-10-08.json
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"config": {
|
4 |
+
"retrieval_model": "bge-m3",
|
5 |
+
"reranking_model": "NoReranker",
|
6 |
+
"task": "long_doc",
|
7 |
+
"metric": "ndcg_at_1"
|
8 |
+
},
|
9 |
+
"results": [
|
10 |
+
{
|
11 |
+
"domain": "law",
|
12 |
+
"lang": "en",
|
13 |
+
"dataset": "lex_files_500K-600K",
|
14 |
+
"value": 0.45723
|
15 |
+
}
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"config": {
|
20 |
+
"retrieval_model": "bge-m3",
|
21 |
+
"reranking_model": "NoReranker",
|
22 |
+
"task": "long_doc",
|
23 |
+
"metric": "ndcg_at_3"
|
24 |
+
},
|
25 |
+
"results": [
|
26 |
+
{
|
27 |
+
"domain": "law",
|
28 |
+
"lang": "en",
|
29 |
+
"dataset": "lex_files_500K-600K",
|
30 |
+
"value": 0.49909
|
31 |
+
}
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"config": {
|
36 |
+
"retrieval_model": "bge-m3",
|
37 |
+
"reranking_model": "NoReranker",
|
38 |
+
"task": "qa",
|
39 |
+
"metric": "ndcg_at_1"
|
40 |
+
},
|
41 |
+
"results": [
|
42 |
+
{
|
43 |
+
"domain": "wiki",
|
44 |
+
"lang": "en",
|
45 |
+
"dataset": "unknown",
|
46 |
+
"value": 0.49083
|
47 |
+
}
|
48 |
+
]
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"config": {
|
52 |
+
"retrieval_model": "bge-m3",
|
53 |
+
"reranking_model": "NoReranker",
|
54 |
+
"task": "qa",
|
55 |
+
"metric": "ndcg_at_3"
|
56 |
+
},
|
57 |
+
"results": [
|
58 |
+
{
|
59 |
+
"domain": "wiki",
|
60 |
+
"lang": "en",
|
61 |
+
"dataset": "unknown",
|
62 |
+
"value": 0.43359
|
63 |
+
}
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"config": {
|
68 |
+
"retrieval_model": "bge-m3",
|
69 |
+
"reranking_model": "NoReranker",
|
70 |
+
"task": "qa",
|
71 |
+
"metric": "ndcg_at_1"
|
72 |
+
},
|
73 |
+
"results": [
|
74 |
+
{
|
75 |
+
"domain": "wiki",
|
76 |
+
"lang": "zh",
|
77 |
+
"dataset": "unknown",
|
78 |
+
"value": 0.78358
|
79 |
+
}
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"config": {
|
84 |
+
"retrieval_model": "bge-m3",
|
85 |
+
"reranking_model": "NoReranker",
|
86 |
+
"task": "qa",
|
87 |
+
"metric": "ndcg_at_3"
|
88 |
+
},
|
89 |
+
"results": [
|
90 |
+
{
|
91 |
+
"domain": "wiki",
|
92 |
+
"lang": "zh",
|
93 |
+
"dataset": "unknown",
|
94 |
+
"value": 0.78358
|
95 |
+
}
|
96 |
+
]
|
97 |
+
}
|
98 |
+
]
|
tests/toydata/test_results/bge-m3/bge-reranker-v2-m3/results_2023-11-21T18-10-08.json
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"config": {
|
4 |
+
"retrieval_model": "bge-m3",
|
5 |
+
"reranking_model": "bge-reranker-v2-m3",
|
6 |
+
"task": "long_doc",
|
7 |
+
"metric": "ndcg_at_1"
|
8 |
+
},
|
9 |
+
"results": [
|
10 |
+
{
|
11 |
+
"domain": "law",
|
12 |
+
"lang": "en",
|
13 |
+
"dataset": "lex_files_500K-600K",
|
14 |
+
"value": 0.75723
|
15 |
+
}
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"config": {
|
20 |
+
"retrieval_model": "bge-m3",
|
21 |
+
"reranking_model": "bge-reranker-v2-m3",
|
22 |
+
"task": "long_doc",
|
23 |
+
"metric": "ndcg_at_3"
|
24 |
+
},
|
25 |
+
"results": [
|
26 |
+
{
|
27 |
+
"domain": "law",
|
28 |
+
"lang": "en",
|
29 |
+
"dataset": "lex_files_500K-600K",
|
30 |
+
"value": 0.69909
|
31 |
+
}
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"config": {
|
36 |
+
"retrieval_model": "bge-m3",
|
37 |
+
"reranking_model": "bge-reranker-v2-m3",
|
38 |
+
"task": "qa",
|
39 |
+
"metric": "ndcg_at_1"
|
40 |
+
},
|
41 |
+
"results": [
|
42 |
+
{
|
43 |
+
"domain": "wiki",
|
44 |
+
"lang": "en",
|
45 |
+
"dataset": "unknown",
|
46 |
+
"value": 0.69083
|
47 |
+
}
|
48 |
+
]
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"config": {
|
52 |
+
"retrieval_model": "bge-m3",
|
53 |
+
"reranking_model": "bge-reranker-v2-m3",
|
54 |
+
"task": "qa",
|
55 |
+
"metric": "ndcg_at_3"
|
56 |
+
},
|
57 |
+
"results": [
|
58 |
+
{
|
59 |
+
"domain": "wiki",
|
60 |
+
"lang": "en",
|
61 |
+
"dataset": "unknown",
|
62 |
+
"value": 0.73359
|
63 |
+
}
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"config": {
|
68 |
+
"retrieval_model": "bge-m3",
|
69 |
+
"reranking_model": "bge-reranker-v2-m3",
|
70 |
+
"task": "qa",
|
71 |
+
"metric": "ndcg_at_1"
|
72 |
+
},
|
73 |
+
"results": [
|
74 |
+
{
|
75 |
+
"domain": "wiki",
|
76 |
+
"lang": "zh",
|
77 |
+
"dataset": "unknown",
|
78 |
+
"value": 0.78358
|
79 |
+
}
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"config": {
|
84 |
+
"retrieval_model": "bge-m3",
|
85 |
+
"reranking_model": "bge-reranker-v2-m3",
|
86 |
+
"task": "qa",
|
87 |
+
"metric": "ndcg_at_3"
|
88 |
+
},
|
89 |
+
"results": [
|
90 |
+
{
|
91 |
+
"domain": "wiki",
|
92 |
+
"lang": "zh",
|
93 |
+
"dataset": "unknown",
|
94 |
+
"value": 0.78358
|
95 |
+
}
|
96 |
+
]
|
97 |
+
}
|
98 |
+
]
|