|
import gradio as gr |
|
import pandas as pd |
|
import json |
|
from constants import BANNER, INTRODUCTION_TEXT, CITATION_TEXT, METRICS_TAB_TEXT, DIR_OUTPUT_REQUESTS |
|
from init import is_model_on_hub, load_all_info_from_dataset_hub |
|
from utils_display import AutoEvalColumn, fields, make_clickable_model, styled_error, styled_message |
|
from datetime import datetime, timezone |
|
import torch |
|
|
|
# Rendered in the UI as "Last updated on **...**".
LAST_UPDATED = "OCT 2nd 2024"

# Display labels for the results columns, keyed by upper-case metric id.
# The ⬇️ marks "lower is better"; the table is sorted ascending by WER.
column_names = dict(
    MODEL="Model",
    WER="WER ⬇️",
    CER="CER ⬇️",
)
|
|
|
|
|
# Fetch the eval-queue repo, the requested-model list and the results CSV location.
# NOTE(review): csv_results is used like a pathlib.Path (.exists()) — confirm in
# init.load_all_info_from_dataset_hub.
eval_queue_repo, requested_models, csv_results = load_all_info_from_dataset_hub()

if not csv_results.exists():
    # FileNotFoundError (a subclass of Exception) says precisely what went wrong.
    raise FileNotFoundError(f"CSV file {csv_results} does not exist locally")

# Raw results table; display formatting and column renaming are applied afterwards.
original_df = pd.read_csv(csv_results)
|
|
|
|
|
def formatter(x):
    """Format one leaderboard cell for display.

    Strings (including str subclasses) are returned unchanged; any other value
    is rounded to two decimal places with ``round``.
    """
    if isinstance(x, str):
        return x
    return round(x, 2)
|
|
|
# Turn model ids into hub links and round every other column for display.
for col in original_df.columns:
    if str(col).lower() == "model":
        original_df[col] = original_df[col].apply(make_clickable_model)
    else:
        original_df[col] = original_df[col].apply(formatter)

# column_names is keyed in upper case ("MODEL", "WER", "CER") while the CSV headers
# are lower case (the same 'model'/'wer'/'cer' keys request_model writes), so match
# case-insensitively; otherwise nothing is renamed and the sort below raises KeyError.
_display_names = {key.lower(): label for key, label in column_names.items()}
original_df.rename(columns=lambda c: _display_names.get(str(c).lower(), c), inplace=True)
original_df.sort_values(by=column_names["WER"], inplace=True)

# Column names and Gradio datatypes declared by AutoEvalColumn.
COLS = [c.name for c in fields(AutoEvalColumn)]
TYPES = [c.type for c in fields(AutoEvalColumn)]
|
|
|
def _hub_id_from_link(cell):
    """Extract the hub repo id from a leaderboard ``Model`` cell.

    Cells are presumably the HTML anchors built by ``make_clickable_model``
    (an ``href="https://huggingface.co/<repo>"`` attribute). A cell without an
    ``href`` is returned as plain text instead of raising ``IndexError``.
    """
    text = str(cell)
    if 'href="' not in text:
        return text
    url = text.split('href="', 1)[1].split('"', 1)[0]
    return url.replace("https://huggingface.co/", "")


def request_model(model_text):
    """Evaluate a hub ASR model on Common Voice 17 ``fa`` and add it to the leaderboard.

    The model runs on a fixed 1/15 sample of the ``fa`` test split (shuffled with
    seed 42). It is scored with WER and CER (as percentages), the scores are
    appended to the results CSV, and the module-level ``original_df`` display
    table is rebuilt from that CSV.

    Args:
        model_text: Hub repo id entered by the user (``user_name/model_name``).

    Returns:
        A styled message for the submission-result Markdown. It is an error when
        the id is empty, not on the hub, already listed, or evaluation fails, and
        a success message otherwise.
    """
    global original_df

    # Tolerate stray whitespace pasted around the repo id.
    model_text = (model_text or "").strip()
    if not model_text:
        return styled_error("Please enter a model name (user_name/model_name).")

    base_model_on_hub, error_msg = is_model_on_hub(model_text)
    if not base_model_on_hub:
        return styled_error(f"Base model '{model_text}' {error_msg}")

    listed_models = {_hub_id_from_link(cell) for cell in original_df[column_names["MODEL"]]}
    if model_text in listed_models:
        return styled_error(f"The model '{model_text}' is already in the leaderboard.")

    try:
        # Imported lazily: these are only needed when a model is evaluated.
        from datasets import load_dataset
        from evaluate import load
        from tqdm import tqdm
        from transformers import pipeline
        from transformers.pipelines.pt_utils import KeyDataset

        # Load the split once, then take a deterministic 1/15 sample of it.
        full_test = load_dataset("mozilla-foundation/common_voice_17_0", "fa", split="test")
        common_voice_test = full_test.shuffle(seed=42).select(range(len(full_test) // 15))

        pipe = pipeline(
            "automatic-speech-recognition",
            model=model_text,
            torch_dtype=torch.float32,
            device=0 if torch.cuda.is_available() else -1,  # first GPU, else CPU
        )

        all_predictions = [
            prediction["text"]
            for prediction in tqdm(
                pipe(
                    KeyDataset(common_voice_test, "audio"),
                    max_new_tokens=128,
                    chunk_length_s=30,
                    generate_kwargs={"task": "transcribe"},
                    batch_size=32,
                ),
                total=len(common_voice_test),
            )
        ]

        references = common_voice_test["sentence"]
        wer_result = 100 * load("wer").compute(references=references, predictions=all_predictions)
        cer_result = 100 * load("cer").compute(references=references, predictions=all_predictions)

        # Persist full-precision scores; the CSV uses lower-case headers.
        new_row = {"model": model_text, "wer": wer_result, "cer": cer_result}
        df_results = pd.read_csv(csv_results)
        # DataFrame.append was removed in pandas 2.0; pd.concat works on all versions.
        df_results = pd.concat([df_results, pd.DataFrame([new_row])], ignore_index=True)
        df_results.to_csv(csv_results, index=False)

        # Rebuild the display table: linked model, metrics rounded to 2 dp, best WER first.
        # This handler only updates the module-level variable. Calling a component's
        # .update() inside a handler does not change what the browser shows, and that
        # method no longer exists in Gradio 4. To refresh the table, an event must
        # output original_df to the Dataframe component.
        original_df = pd.DataFrame(
            {
                column_names["MODEL"]: df_results["model"].apply(make_clickable_model),
                column_names["WER"]: df_results["wer"].apply(lambda x: round(x, 2)),
                column_names["CER"]: df_results["cer"].apply(lambda x: round(x, 2)),
            }
        ).sort_values(by=column_names["WER"])

        return styled_message("🤗 Your model has been evaluated and added to the leaderboard!")

    except Exception as e:  # UI boundary: report any failure to the user instead of crashing
        return styled_error(f"Error during evaluation: {e}")
|
|
|
with gr.Blocks() as demo:
    gr.HTML(BANNER, elem_id="banner")
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard", elem_id="od-benchmark-tab-table", id=0):
            leaderboard_table = gr.Dataframe(
                value=original_df,
                datatype=TYPES,
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
            )

        with gr.TabItem("📈 Metrics", elem_id="od-benchmark-tab-table", id=1):
            gr.Markdown(METRICS_TAB_TEXT, elem_classes="markdown-text")

        with gr.TabItem("✉️✨ Request a model here!", elem_id="od-benchmark-tab-table", id=2):
            with gr.Column():
                gr.Markdown("# ✉️✨ Request results for a new model here!", elem_classes="markdown-text")
                model_name_textbox = gr.Textbox(label="Model name (user_name/model_name)")
                mdw_submission_result = gr.Markdown()
                btn_submit = gr.Button(value="🚀 Request")
                # request_model only outputs to mdw_submission_result, but it rebinds
                # the module-level original_df. Read that global afterwards (looked up
                # at call time by the lambda) and send it to the table, so a newly
                # evaluated model appears without restarting the app.
                btn_submit.click(
                    request_model, [model_name_textbox], mdw_submission_result
                ).then(lambda: original_df, None, leaderboard_table)

    gr.Markdown(f"Last updated on **{LAST_UPDATED}**", elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value=CITATION_TEXT,
                lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )

demo.launch()
|
|