# navidved's app.py, as shown in the Hugging Face file view ("Update app.py", commit 0ff8797, verified, 6.28 kB).
import gradio as gr
import pandas as pd
import json
from constants import BANNER, INTRODUCTION_TEXT, CITATION_TEXT, METRICS_TAB_TEXT, DIR_OUTPUT_REQUESTS
from init import is_model_on_hub, load_all_info_from_dataset_hub
from utils_display import AutoEvalColumn, fields, make_clickable_model, styled_error, styled_message
from datetime import datetime, timezone
import torch
LAST_UPDATED = "OCT 2nd 2024"

# Raw CSV header -> header displayed in the leaderboard table.
# NOTE(review): other code in this file reads a "Model" column straight from the
# raw CSV, so the "MODEL" key may never match; confirm the CSV's actual headers.
column_names = {
    "MODEL": "Model",
    "WER": "WER ⬇️",
    "CER": "CER ⬇️",
}

# Load evaluation results. Only csv_results is used in this file: it is read
# below and appended to when a new model is evaluated.
eval_queue_repo, requested_models, csv_results = load_all_info_from_dataset_hub()
if not csv_results.exists():
    # FileNotFoundError subclasses Exception, so any existing handler still catches it.
    raise FileNotFoundError(f"CSV file {csv_results} does not exist locally")

# Read CSV with data and parse columns
original_df = pd.read_csv(csv_results)


# Format the columns
def formatter(x, ndigits=2):
    """Round a numeric leaderboard cell; pass text and ``None`` through unchanged.

    Args:
        x: A single cell value from the results DataFrame.
        ndigits: Number of decimal places to keep (default 2).

    Returns:
        ``x`` unchanged if it is a string (including ``str`` subclasses such as
        ``numpy.str_``, which the old ``type(x) is str`` test sent to ``round``
        and crashed on) or ``None``; otherwise ``round(x, ndigits)``.
    """
    if isinstance(x, str) or x is None:
        return x
    return round(x, ndigits)


# Round every metric column; the "Model" identifier column stays as text.
for _metric_col in [name for name in original_df.columns if name != "Model"]:
    original_df[_metric_col] = original_df[_metric_col].apply(formatter)

# Replace raw CSV headers with display headers, then rank by WER (lower is better).
original_df = original_df.rename(columns=column_names)
original_df = original_df.sort_values(by="WER ⬇️")

# Column names and datatypes declared by the AutoEvalColumn schema.
COLS = [column.name for column in fields(AutoEvalColumn)]
TYPES = [column.type for column in fields(AutoEvalColumn)]


def _resolve_column(df, wanted):
    """Return the column label of *df* matching *wanted* case-insensitively, else *wanted*.

    The results CSV's exact header spelling is not fixed in this file (the
    module-level ``column_names`` map uses "MODEL"/"WER"/"CER" while other code
    reads "Model"). New rows are therefore keyed by whatever spelling the file
    actually uses, instead of landing in brand-new columns.
    """
    for col in df.columns:
        if str(col).lower() == wanted.lower():
            return col
    return wanted


def _format_leaderboard(df):
    """Return a display copy of raw results *df*: metrics rounded, headers renamed, sorted by WER.

    This mirrors the formatting applied to ``original_df`` at start-up. The table
    therefore looks the same before and after a submission, and its "Model"
    column keeps raw model IDs, which the duplicate check in ``request_model``
    compares against.
    """
    table = df.copy()
    for col in table.columns:
        if col != "Model":  # leave the model identifier column as text
            table[col] = table[col].apply(formatter)
    table = table.rename(columns=column_names)
    return table.sort_values(by="WER ⬇️")


def _evaluate_model(model_id):
    """Transcribe a fixed Common Voice 17.0 (fa) test subset with *model_id*.

    Returns:
        ``(wer, cer)`` as percentages (0-100).
    """
    # Heavy dependencies are imported lazily so the UI starts without them.
    from datasets import Audio, load_dataset
    from evaluate import load
    from tqdm import tqdm
    from transformers import pipeline
    from transformers.pipelines.pt_utils import KeyDataset

    use_cuda = torch.cuda.is_available()
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model_id,
        # float16 is only dependable on GPU; many CPU kernels lack half support.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device=0 if use_cuda else -1,
    )

    # Load the split once and keep a deterministic 1/15 sample (seed 42), so
    # every model is scored on the same utterances.
    full_test = load_dataset("mozilla-foundation/common_voice_17_0", "fa", split="test")
    test_subset = full_test.shuffle(seed=42).select(range(len(full_test) // 15))
    # Decode audio at the model's expected rate up front, rather than relying on
    # the pipeline's on-the-fly resampling (which requires torchaudio).
    test_subset = test_subset.cast_column(
        "audio", Audio(sampling_rate=pipe.feature_extractor.sampling_rate)
    )

    predictions = []
    for output in tqdm(
        pipe(
            KeyDataset(test_subset, "audio"),
            max_new_tokens=128,
            chunk_length_s=30,
            # NOTE(review): "task" is a Whisper generation option; confirm other
            # architectures accept it.
            generate_kwargs={"task": "transcribe"},
            batch_size=32,
        ),
        total=len(test_subset),
    ):
        predictions.append(output["text"])

    references = test_subset["sentence"]
    wer = 100 * load("wer").compute(references=references, predictions=predictions)
    cer = 100 * load("cer").compute(references=references, predictions=predictions)
    return wer, cer


def request_model(model_text):
    """Evaluate a Hub ASR model and add its WER/CER to the leaderboard.

    Args:
        model_text: Hub model ID entered by the user (``user_name/model_name``).

    Returns:
        Markup from ``styled_message`` on success or ``styled_error`` on failure.
        Errors are reported to the UI rather than raised.

    Side effects:
        Appends a row to the results CSV at ``csv_results`` and rebinds the
        module-level ``original_df`` to the refreshed, formatted table.
        NOTE(review): only the local CSV copy is written here; confirm whether it
        must also be pushed back to the dataset repo to survive restarts.
    """
    global original_df

    model_text = (model_text or "").strip()
    if not model_text:
        return styled_error("Please enter a model name in the form user_name/model_name.")

    # Check if the model exists on the Hub
    base_model_on_hub, error_msg = is_model_on_hub(model_text)
    if not base_model_on_hub:
        return styled_error(f"Base model '{model_text}' {error_msg}")

    # Reject models that already have a row; original_df["Model"] holds raw IDs.
    if model_text in original_df["Model"].values:
        return styled_error(f"The model '{model_text}' is already in the leaderboard.")

    try:
        wer_result, cer_result = _evaluate_model(model_text)

        df_results = pd.read_csv(csv_results)
        new_row = {
            _resolve_column(df_results, "Model"): model_text,
            _resolve_column(df_results, "WER"): wer_result,
            _resolve_column(df_results, "CER"): cer_result,
        }
        # DataFrame.append was removed in pandas 2.0; concat is the supported form.
        df_results = pd.concat([df_results, pd.DataFrame([new_row])], ignore_index=True)
        df_results.to_csv(csv_results, index=False)

        # The visible table is refreshed by the UI event that reads original_df;
        # calling leaderboard_table.update() here cannot push values to the page.
        original_df = _format_leaderboard(df_results)
    except Exception as e:  # UI boundary: report to the user, keep the app alive
        import logging

        logging.getLogger(__name__).exception("Evaluation of %s failed", model_text)
        return styled_error(f"Error during evaluation: {e}")

    return styled_message("🤗 Your model has been evaluated and added to the leaderboard!")


def _current_leaderboard():
    """Return the module-level leaderboard DataFrame.

    ``request_model`` rebinds the global ``original_df`` after a successful
    evaluation. Reading it at call time, rather than capturing it when the UI is
    built, lets the follow-up event push the refreshed table to the browser.
    """
    return original_df


with gr.Blocks() as demo:
    gr.HTML(BANNER, elem_id="banner")
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard", elem_id="od-benchmark-tab-table", id=0):
            leaderboard_table = gr.Dataframe(
                value=original_df,
                datatype=TYPES,
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
            )

        with gr.TabItem("📈 Metrics", elem_id="od-benchmark-tab-table", id=1):
            gr.Markdown(METRICS_TAB_TEXT, elem_classes="markdown-text")

        with gr.TabItem("✉️✨ Request a model here!", elem_id="od-benchmark-tab-table", id=2):
            with gr.Column():
                gr.Markdown("# ✉️✨ Request results for a new model here!", elem_classes="markdown-text")
                model_name_textbox = gr.Textbox(label="Model name (user_name/model_name)")
                mdw_submission_result = gr.Markdown()
                btn_submit = gr.Button(value="🚀 Request")
                # Show the submission result, then re-render the leaderboard from
                # the (possibly updated) global table.
                (
                    btn_submit.click(request_model, [model_name_textbox], mdw_submission_result)
                    .then(_current_leaderboard, None, leaderboard_table)
                )

    gr.Markdown(f"Last updated on **{LAST_UPDATED}**", elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value=CITATION_TEXT, lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )

# Evaluating a model can take minutes. The queue keeps such long-running events
# from hitting per-request timeouts; it is already the default in Gradio 4.
demo.queue().launch()