|
import gradio as gr |
|
import pandas as pd |
|
import json |
|
from constants import BANNER, INTRODUCTION_TEXT, CITATION_TEXT, METRICS_TAB_TEXT, DIR_OUTPUT_REQUESTS |
|
from init import is_model_on_hub, load_all_info_from_dataset_hub |
|
from utils_display import AutoEvalColumn, fields, make_clickable_model, styled_error, styled_message |
|
from datetime import datetime, timezone |
|
import torch |
|
|
|
# Date string interpolated into the "Last updated on" footer of the page.
LAST_UPDATED = "OCT 21st 2024"

# Maps column keys to the headers shown in the leaderboard table; it is applied
# to the results frame with DataFrame.rename. The down-arrow emoji in the
# Common Voice headers had been stored as mis-decoded bytes and is restored.
column_names = {
    "MODEL": "Model",
    "WER": "Common Voice WER ⬇️",
    "CER": "Common Voice CER ⬇️",
    "WER2": "persian-asr-test-set WER",
    "CER2": "persian-asr-test-set CER",
    "WER3": "asr-farsi-youtube WER",
    "CER3": "asr-farsi-youtube CER",
}
|
|
|
|
|
# load_all_info_from_dataset_hub() yields three values; only csv_results is
# used in this part of the file. It must support .exists() and be accepted by
# pandas.read_csv (presumably a pathlib.Path -- confirm in init.py).
eval_queue_repo, requested_models, csv_results = load_all_info_from_dataset_hub()

# Fail fast with a specific, still Exception-compatible, error type when the
# results file is missing instead of letting read_csv fail later.
if not csv_results.exists():
    raise FileNotFoundError(f"CSV file {csv_results} does not exist locally")

# Raw results table; rounding, renaming and sorting are applied further down.
original_df = pd.read_csv(csv_results)
|
|
|
|
|
def formatter(x):
    """Return a display-ready version of a single table cell.

    Args:
        x: A cell value. Strings (including ``str`` subclasses) are returned
            unchanged; anything else is passed to ``round``.

    Returns:
        ``x`` itself for strings, otherwise ``round(x, 2)``.

    Raises:
        TypeError: If ``x`` is neither a string nor something ``round``
            accepts.
    """
    # isinstance (rather than an exact type comparison) keeps str subclasses
    # from falling through to round(), which would raise TypeError.
    if isinstance(x, str):
        return x
    return round(x, 2)
|
|
|
|
|
# Round every column except the model-name column for display.
for column_label in original_df.columns:
    if column_label != 'Model':
        original_df[column_label] = original_df[column_label].apply(formatter)

# Switch to the display headers, then order rows by the asr-farsi-youtube WER
# column (ascending, the sort_values default).
original_df = original_df.rename(columns=column_names)
original_df = original_df.sort_values(by='asr-farsi-youtube WER')

# Names and types of the AutoEvalColumn fields; TYPES is handed to the
# Dataframe component as its datatype.
COLS = [field.name for field in fields(AutoEvalColumn)]
TYPES = [field.type for field in fields(AutoEvalColumn)]
|
|
|
def _csv_column(frame, raw_key):
    """Return the label under which *frame* stores the column for *raw_key*.

    A results column may be labelled either with its raw key (a key of
    ``column_names``) or with the matching display header. The display header
    wins when present and is also the fallback when neither label exists yet.
    """
    display_name = column_names[raw_key]
    if display_name in frame.columns or raw_key not in frame.columns:
        return display_name
    return raw_key


def _build_leaderboard_frame(df_results):
    """Return the display table (renamed, clickable, rounded, sorted) for raw results."""
    # rename() returns a new frame, so the raw results are left untouched.
    table = df_results.rename(columns=column_names)
    table['Model'] = table['Model'].apply(make_clickable_model)
    for col in table.columns:
        if col != 'Model':
            table[col] = table[col].apply(formatter)
    table.sort_values(by='asr-farsi-youtube WER', inplace=True)
    return table


def request_model(model_text):
    """Evaluate a model on a Common Voice (fa) sample and add it to the results.

    The model is run over a fixed shuffled sample (seed 42, 1/150 of the test
    split), WER and CER are computed as percentages, a row is appended to the
    results CSV and the module-level ``original_df`` is rebuilt from it.

    Args:
        model_text: Model identifier entered in the request form.

    Returns:
        The result of ``styled_error`` when the hub check fails, the model is
        already listed, or evaluation raises; otherwise the result of
        ``styled_message``.
    """
    global original_df

    base_model_on_hub, error_msg = is_model_on_hub(model_text)
    if not base_model_on_hub:
        return styled_error(f"Base model '{model_text}' {error_msg}")

    # Look for duplicates in the stored results rather than in the display
    # table: the display table's model column is rewritten by
    # make_clickable_model after the first successful request, so plain model
    # names would no longer match there.
    stored_results = pd.read_csv(csv_results)
    stored_model_col = _csv_column(stored_results, "MODEL")
    if (
        stored_model_col in stored_results.columns
        and model_text in stored_results[stored_model_col].values
    ):
        return styled_error(f"The model '{model_text}' is already in the leaderboard.")

    try:
        # Heavy dependencies are imported lazily so a missing package is
        # reported through the form instead of breaking module import.
        from transformers import pipeline
        from datasets import load_dataset
        from tqdm import tqdm
        from transformers.pipelines.pt_utils import KeyDataset
        from evaluate import load

        # Load the split once; its length is needed to size the sample.
        full_test_split = load_dataset(
            "mozilla-foundation/common_voice_17_0", "fa", split="test"
        )
        common_voice_test = full_test_split.shuffle(seed=42).select(
            range(len(full_test_split) // 150)
        )

        # Half precision is only requested on GPU; on CPU the pipeline falls
        # back to float32.
        use_cuda = torch.cuda.is_available()
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model_text,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device=0 if use_cuda else -1,
        )

        # NOTE(review): max_new_tokens / generate_kwargs are generation
        # options; confirm every model family accepted here supports them.
        all_predictions = [
            prediction["text"]
            for prediction in tqdm(
                pipe(
                    KeyDataset(common_voice_test, "audio"),
                    max_new_tokens=128,
                    chunk_length_s=30,
                    generate_kwargs={"task": "transcribe"},
                    batch_size=32,
                ),
                total=len(common_voice_test),
            )
        ]

        references = common_voice_test["sentence"]
        wer_result = 100 * load("wer").compute(
            references=references, predictions=all_predictions
        )
        cer_result = 100 * load("cer").compute(
            references=references, predictions=all_predictions
        )

        # Re-read the CSV so rows written while this evaluation was running
        # are kept.
        df_results = pd.read_csv(csv_results)
        new_row = {
            _csv_column(df_results, "MODEL"): model_text,
            _csv_column(df_results, "WER"): wer_result,
            _csv_column(df_results, "CER"): cer_result,
        }
        # DataFrame.append no longer exists in pandas >= 2.0; concat is the
        # replacement. Columns missing from new_row are filled with NaN.
        df_results = pd.concat(
            [df_results, pd.DataFrame([new_row])], ignore_index=True
        )
        df_results.to_csv(csv_results, index=False)

        original_df = _build_leaderboard_frame(df_results)

        # The former leaderboard_table.update(...) call was removed: invoked
        # inside a handler it does not refresh the table, and where the method
        # is unavailable it turned a successful run into an error message.
        # TODO: to refresh the table live, return the new frame and list
        # leaderboard_table among the click handler's outputs.

        return styled_message("🤗 Your model has been evaluated and added to the leaderboard!")

    except Exception as e:
        # UI boundary: surface any failure in the form instead of crashing.
        return styled_error(f"Error during evaluation: {e}")
|
|
|
# Page layout: banner and introduction, three tabs (leaderboard table, metrics
# description, model request form), a "last updated" footer and a citation box.
# The emoji in the labels had been stored as mis-decoded bytes (one of them
# splitting the first tab label across two lines); they are restored here.
with gr.Blocks() as demo:
    gr.HTML(BANNER, elem_id="banner")
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard", elem_id="od-benchmark-tab-table", id=0):
            # Read-only table showing the formatted results frame.
            leaderboard_table = gr.Dataframe(
                value=original_df,
                datatype=TYPES,
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
            )

        with gr.TabItem("📈 Metrics", elem_id="od-benchmark-tab-table", id=1):
            gr.Markdown(METRICS_TAB_TEXT, elem_classes="markdown-text")

        with gr.TabItem("✉️✨ Request a model here!", elem_id="od-benchmark-tab-table", id=2):
            with gr.Column():
                gr.Markdown("# ✉️✨ Request results for a new model here!", elem_classes="markdown-text")
                model_name_textbox = gr.Textbox(label="Model name (user_name/model_name)")
                mdw_submission_result = gr.Markdown()
                btn_submit = gr.Button(value="🚀 Request")
                # request_model's return value is rendered in the Markdown
                # component below the text box.
                btn_submit.click(request_model, [model_name_textbox], mdw_submission_result)

    gr.Markdown(f"Last updated on **{LAST_UPDATED}**", elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value=CITATION_TEXT, lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )

demo.launch()
|
|