ardaatahan's picture
change multilingual wer copy logic
8f1c387
raw
history blame
58 kB
"""
Main module for the WhisperKit Evaluation Dashboard.
This module sets up and runs the Gradio interface for the WhisperKit Evaluation Dashboard,
allowing users to explore and compare speech recognition model performance across different
devices, operating systems, and datasets.
"""
import json
import os
import re
from math import ceil, floor
import gradio as gr
import pandas as pd
from argmax_gradio_components import RangeSlider
from dotenv import load_dotenv
from huggingface_hub import login
# Import custom constants and utility functions
from constants import (
BANNER_TEXT,
CITATION_BUTTON_LABEL,
CITATION_BUTTON_TEXT,
COL_NAMES,
HEADER,
LANGUAGE_MAP,
METHODOLOGY_TEXT,
PERFORMANCE_TEXT,
QUALITY_TEXT,
)
from utils import (
add_datasets_to_performance_columns,
add_datasets_to_quality_columns,
create_confusion_matrix_plot,
create_initial_performance_column_dict,
create_initial_quality_column_dict,
css,
fields,
get_os_name_and_version,
make_dataset_wer_clickable_link,
make_model_name_clickable_link,
make_multilingual_model_clickable_link,
plot_metric,
read_json_line_by_line,
)
# Load environment variables (e.g. HF_TOKEN) from a local .env file, if present.
load_dotenv()

# Get the Hugging Face token from the environment variable
HF_TOKEN = os.getenv("HF_TOKEN")

# Log in only when a token is available: ``login(token=None)`` falls back to an
# interactive prompt, which cannot be answered in a headless deployment.
if HF_TOKEN:
    login(token=HF_TOKEN, add_to_git_credential=True)
else:
    print("HF_TOKEN is not set; skipping Hugging Face Hub login.")

# Define repository and directory information.
# NOTE(review): these names are not referenced elsewhere in the visible file.
repo_id = "argmaxinc/whisperkit-evals-dataset"
directory = "xcresults/benchmark_results"
local_dir = ""
# Load benchmark data from JSON files (presumably one JSON record per line,
# given the reader's name — confirm against utils.read_json_line_by_line).
PERFORMANCE_DATA = read_json_line_by_line("dashboard_data/performance_data.json")
QUALITY_DATA = read_json_line_by_line("dashboard_data/quality_data.json")

# Convert JSON data to pandas DataFrames; nested objects are flattened into
# dotted column names such as "dataset_wer.<dataset>".
quality_df = pd.json_normalize(QUALITY_DATA)
benchmark_df = pd.json_normalize(PERFORMANCE_DATA)

# Parse performance timestamps and drop their timezone information so they can
# be compared and sorted as naive datetimes. A single pass is enough.
# NOTE(review): quality_df["timestamp"] is not parsed and is sorted as-is below;
# confirm that its string format sorts chronologically.
benchmark_df["timestamp"] = pd.to_datetime(benchmark_df["timestamp"]).dt.tz_localize(
    None
)
# Keep the newest quality result per model. Rows are ordered by model-name
# length (shortest first), then by name, then newest timestamp first, so
# keeping the first occurrence of each model keeps its latest run.
_quality_ranked = quality_df.assign(model_len=quality_df["model"].str.len())
_quality_ranked = _quality_ranked.sort_values(
    by=["model_len", "model", "timestamp"],
    ascending=[True, True, False],
)
_quality_ranked = _quality_ranked.drop(columns=["model_len"])
sorted_quality_df = _quality_ranked.drop_duplicates(
    subset=["model"], keep="first"
).reset_index(drop=True)
del _quality_ranked
# Load multilingual results and create one hidden button per model. The buttons
# are rendered inside the "Multilingual Quality" tab, where clicking one shows
# that model's results.
multilingual_df = pd.read_csv("dashboard_data/multilingual_results.csv")
# .copy(): drop_duplicates yields a filtered frame, and the "Model" column is
# rewritten below; an explicit copy avoids pandas' SettingWithCopyWarning.
multilingual_models_df = multilingual_df[["Model"]].drop_duplicates().copy()

# Turn a model name into an HTML element id: spaces become underscores and
# double quotes, single quotes and commas are removed.
_ELEM_ID_TRANSLATION = str.maketrans({" ": "_", '"': None, "'": None, ",": None})

multilingual_models_buttons = [
    gr.Button(
        value=model,
        elem_id=str(model).translate(_ELEM_ID_TRANSLATION),
        visible=False,
    )
    for model in multilingual_models_df["Model"]
]

# Display names in the model table are rendered as clickable links.
multilingual_models_df["Model"] = multilingual_models_df["Model"].apply(
    make_multilingual_model_clickable_link
)

# Per-model confusion matrices. JSON text is UTF-8 by specification, so decode
# it explicitly instead of relying on the platform's locale encoding.
with open(
    "dashboard_data/multilingual_confusion_matrices.json", "r", encoding="utf-8"
) as file:
    confusion_matrix_map = dict(json.load(file))
# Create a mapping of model to average English WER (one row per model after
# the deduplication above).
model_to_english_wer = dict(
    zip(sorted_quality_df["model"], sorted_quality_df["average_wer"])
)
# NOTE(review): multilingual_df may hold several rows per model (e.g. forced vs.
# not-forced language tokens); dict() keeps the value of the LAST row seen for
# each model — confirm that is the intended scenario.
model_to_multilingual_wer = dict(
    zip(multilingual_df["Model"], multilingual_df["Average WER"])
)

# Benchmarked models that are not evaluated directly but should reuse the
# multilingual WER of a specific evaluated model.
_MULTILINGUAL_WER_ALIASES = {
    "openai/whisper-large-v3-v20240930/turbo/632MB": "openai/whisper-large-v3-v20240930",
}


def _find_multilingual_wer(model, wer_by_model):
    """
    Return the multilingual WER to copy onto ``model``, or None if none applies.

    Precedence (first match wins):
    1. an explicit entry in ``_MULTILINGUAL_WER_ALIASES``;
    2. ``"<base>/turbo"`` inherits the WER of ``"<base>"``;
    3. for models not ending in "en" (presumably English-only ``.en``
       variants), the LONGEST evaluated model name contained in ``model``.
       Choosing the longest (most specific) match makes the result independent
       of the CSV row order. For example, "openai/whisper-large-v3" must not
       shadow "openai/whisper-large-v3-v20240930".

    :param model: Benchmarked model name.
    :param wer_by_model: Mapping of evaluated multilingual model name to WER.
    :return: The WER value to copy, or None.
    """
    alias = _MULTILINGUAL_WER_ALIASES.get(model)
    if alias is not None and alias in wer_by_model:
        return wer_by_model[alias]
    if model.endswith("/turbo"):
        base = model.replace("/turbo", "")
        if base in wer_by_model:
            return wer_by_model[base]
    if model.endswith("en"):
        return None
    matches = [name for name in wer_by_model if name in model]
    if not matches:
        return None
    return wer_by_model[max(matches, key=len)]


# Copy over the multilingual WER to matching benchmarked models that were not
# evaluated themselves.
multilingual_models = {}
for model in benchmark_df["model"].unique().tolist():
    if model in model_to_multilingual_wer:
        continue
    copied_wer = _find_multilingual_wer(model, model_to_multilingual_wer)
    if copied_wer is not None:
        multilingual_models[model] = copied_wer

# Add the derived entries to the multilingual WER mapping
model_to_multilingual_wer.update(multilingual_models)
# Attach the English and multilingual WER to every benchmark row. Models with
# no multilingual score are labelled "English-only model"; the column is then
# stored as text so numbers and that label can share it.
benchmark_df["english_wer"] = benchmark_df["model"].map(model_to_english_wer)
_multilingual_wer = benchmark_df["model"].map(model_to_multilingual_wer)
benchmark_df["multilingual_wer"] = _multilingual_wer
benchmark_df["multilingual_wer"] = (
    benchmark_df["multilingual_wer"].fillna("English-only model").astype(str)
)
del _multilingual_wer

# Keep the newest run per (model, device, os): sort every key ascending except
# the timestamp (newest first), then keep the first row of each group.
_performance_sort_keys = ["model_len", "model", "device", "os", "timestamp"]
_performance_ranked = benchmark_df.assign(model_len=benchmark_df["model"].str.len())
_performance_ranked = _performance_ranked.sort_values(
    by=_performance_sort_keys,
    ascending=[key != "timestamp" for key in _performance_sort_keys],
)
sorted_performance_df = (
    _performance_ranked.drop(columns=["model_len"])
    .drop_duplicates(subset=["model", "device", "os"], keep="first")
    .reset_index(drop=True)
)
del _performance_ranked
def _columns_with_prefix(frame, prefix):
    """Return the column names of ``frame`` that start with ``prefix``, in order."""
    return [name for name in frame.columns if name.startswith(prefix)]


# Identify dataset-specific columns (flattened "<metric>.<dataset>" names)
dataset_wer_columns = _columns_with_prefix(sorted_quality_df, "dataset_wer.")
dataset_speed_columns = _columns_with_prefix(sorted_performance_df, "dataset_speed.")
dataset_toks_columns = _columns_with_prefix(
    sorted_performance_df, "dataset_tokens_per_second."
)

# Dataset names are the text after the last "."
QUALITY_DATASETS = [name.rsplit(".", 1)[-1] for name in dataset_wer_columns]
PERFORMANCE_DATASETS = [name.rsplit(".", 1)[-1] for name in dataset_speed_columns]
# Columns shown in the quality and performance tables, in display order.
_quality_display_columns = [
    "model",
    "average_wer",
    "qoi",
    "timestamp",
    *dataset_wer_columns,
]
_performance_display_columns = [
    "model",
    "device",
    "os",
    "english_wer",
    "multilingual_wer",
    "qoi",
    "speed",
    "tokens_per_second",
    "timestamp",
    *dataset_speed_columns,
    *dataset_toks_columns,
]

model_df = sorted_quality_df[_quality_display_columns]
performance_df = sorted_performance_df[_performance_display_columns].copy()


def _display_column_name(column):
    """Return the display label for ``column`` from COL_NAMES, or the name itself."""
    return COL_NAMES[column] if column in COL_NAMES else column


# Rename columns for clarity
performance_df = performance_df.rename(_display_column_name, axis="columns")
model_df = model_df.rename(_display_column_name, axis="columns")
def _form_label(dataset_name):
    """Return "Short-Form" for librispeech-10mins and "Long-Form" for any other dataset."""
    if dataset_name == "librispeech-10mins":
        return "Short-Form"
    return "Long-Form"


# Quality: each "dataset_wer.<name>" column becomes a "<name>" column of links.
for wer_column in dataset_wer_columns:
    dataset_name = wer_column.split(".")[-1]
    model_df = model_df.rename(columns={wer_column: dataset_name})
    model_df[dataset_name] = model_df.apply(
        make_dataset_wer_clickable_link, axis=1, args=(dataset_name,)
    )

# Performance: speed columns first, then tokens/s columns, each renamed to
# "<Short|Long>-Form <suffix>".
for source_columns, suffix in (
    (dataset_speed_columns, "Speed"),
    (dataset_toks_columns, "Tok/s"),
):
    for source_column in source_columns:
        dataset_name = source_column.split(".")[-1]
        performance_df = performance_df.rename(
            columns={source_column: f"{_form_label(dataset_name)} {suffix}"}
        )
# Keep the raw model name, then render the display name as a clickable link.
for _display_frame in (model_df, performance_df):
    _display_frame["model_raw"] = _display_frame["Model"].copy()
    _display_frame["Model"] = _display_frame["Model"].apply(
        make_model_name_clickable_link
    )

# Unique devices (in order of first appearance) and sorted OS name/version labels
PERFORMANCE_DEVICES = performance_df["Device"].unique().tolist()
PERFORMANCE_OS = sorted(
    performance_df["OS"].apply(get_os_name_and_version).unique().tolist()
)
# Build the column metadata for each table, extended with one entry per dataset.
initial_performance_column_dict = create_initial_performance_column_dict()
initial_quality_column_dict = create_initial_quality_column_dict()
performance_column_info = add_datasets_to_performance_columns(
    initial_performance_column_dict, PERFORMANCE_DATASETS
)
quality_column_info = add_datasets_to_quality_columns(
    initial_quality_column_dict, QUALITY_DATASETS
)

# Keys of the column-info mappings, in the order they are unpacked below.
_COLUMN_INFO_KEYS = (
    "column_dict",
    "AutoEvalColumn",
    "COLS",
    "TYPES",
    "ALWAYS_HERE_COLS",
    "TOGGLE_COLS",
    "SELECTED_COLS",
)

# Unpack the column dictionary, dataclass and column sets for each table
(
    updated_performance_column_dict,
    PerformanceAutoEvalColumn,
    PERFORMANCE_COLS,
    PERFORMANCE_TYPES,
    PERFORMANCE_ALWAYS_HERE_COLS,
    PERFORMANCE_TOGGLE_COLS,
    PERFORMANCE_SELECTED_COLS,
) = (performance_column_info[key] for key in _COLUMN_INFO_KEYS)
(
    updated_quality_column_dict,
    QualityAutoEvalColumn,
    QUALITY_COLS,
    QUALITY_TYPES,
    QUALITY_ALWAYS_HERE_COLS,
    QUALITY_TOGGLE_COLS,
    QUALITY_SELECTED_COLS,
) = (quality_column_info[key] for key in _COLUMN_INFO_KEYS)
def performance_filter(
    df,
    columns,
    model_query,
    exclude_models,
    devices,
    os,
    short_speed_slider,
    long_speed_slider,
    short_toks_slider,
    long_toks_slider,
):
    """
    Filters the performance DataFrame based on specified criteria.

    Model, exclude and OS terms are matched as case-insensitive literal
    substrings. In the model/exclude boxes, ``;`` separates alternatives and
    blank alternatives are ignored, so ``"tiny;"`` means just ``"tiny"``.

    :param df: The DataFrame to be filtered.
    :param columns: The columns to be included in the filtered DataFrame.
    :param model_query: The query string to filter the 'Model' column.
    :param exclude_models: Models to exclude from the results.
    :param devices: The devices to filter the 'Device' column.
    :param os: The list of operating systems to filter the 'OS' column.
    :param short_speed_slider: The (min, max) range for the 'Short-Form Speed' column.
    :param long_speed_slider: The (min, max) range for the 'Long-Form Speed' column.
    :param short_toks_slider: The (min, max) range for the 'Short-Form Tok/s' column.
    :param long_toks_slider: The (min, max) range for the 'Long-Form Tok/s' column.
    :return: The filtered DataFrame.
    """

    def literal_pattern(terms):
        # Escape each term so characters such as "." or "(" match literally,
        # and drop blank terms: an empty alternative would match every row.
        return "|".join(re.escape(t.strip()) for t in terms if t.strip())

    # Select columns based on input and always-present columns
    filtered_df = df[
        PERFORMANCE_ALWAYS_HERE_COLS
        + [c for c in PERFORMANCE_COLS if c in df.columns and c in columns]
    ]

    # Filter models based on query
    if model_query:
        pattern = literal_pattern(model_query.split(";"))
        if pattern:
            filtered_df = filtered_df[
                filtered_df["Model"].str.contains(
                    pattern, case=False, regex=True, na=False
                )
            ]

    # Exclude specified models
    if exclude_models:
        pattern = literal_pattern(exclude_models.split(";"))
        if pattern:
            filtered_df = filtered_df[
                ~filtered_df["Model"].str.contains(
                    pattern, case=False, regex=True, na=False
                )
            ]

    # Filter by devices; no selected device means no rows.
    if devices:
        filtered_df = filtered_df[filtered_df["Device"].isin(devices)]
    else:
        filtered_df = pd.DataFrame(columns=filtered_df.columns)

    # Filter by operating systems; no selected OS means no rows.
    if os:
        filtered_df = filtered_df[
            filtered_df["OS"].str.contains(
                literal_pattern(os), case=False, regex=True, na=False
            )
        ]
    else:
        filtered_df = pd.DataFrame(columns=filtered_df.columns)

    # Apply short-form and long-form speed and tokens-per-second ranges. The
    # values are coerced to numbers on the frame actually being filtered;
    # non-numeric cells become NaN and fall outside every range.
    slider_ranges = {
        "Short-Form Speed": short_speed_slider,
        "Long-Form Speed": long_speed_slider,
        "Short-Form Tok/s": short_toks_slider,
        "Long-Form Tok/s": long_toks_slider,
    }
    for column, (low, high) in slider_ranges.items():
        if column in filtered_df.columns:
            values = pd.to_numeric(filtered_df[column], errors="coerce")
            filtered_df = filtered_df[(values >= low) & (values <= high)]

    return filtered_df
def quality_filter(df, columns, model_query, wer_slider, qoi_slider, exclude_models):
    """
    Filters the quality DataFrame based on specified criteria.

    Model and exclude terms are matched as case-insensitive literal substrings;
    ``;`` separates alternatives and blank alternatives are ignored.

    :param df: The DataFrame to be filtered.
    :param columns: The columns to be included in the filtered DataFrame.
    :param model_query: The query string to filter the 'Model' column.
    :param wer_slider: The (min, max) range for the 'Average WER' column.
    :param qoi_slider: The (min, max) range for the 'QoI' column.
    :param exclude_models: Models to exclude from the results.
    :return: The filtered DataFrame.
    """

    def literal_pattern(terms):
        # Escape each term so regex metacharacters match literally, and drop
        # blank terms: an empty alternative would match every row.
        return "|".join(re.escape(t.strip()) for t in terms if t.strip())

    # Select columns based on input and always-present columns
    filtered_df = df[
        QUALITY_ALWAYS_HERE_COLS
        + [c for c in QUALITY_COLS if c in df.columns and c in columns]
    ]

    # Filter models based on query
    if model_query:
        pattern = literal_pattern(model_query.split(";"))
        if pattern:
            filtered_df = filtered_df[
                filtered_df["Model"].str.contains(
                    pattern, case=False, regex=True, na=False
                )
            ]

    # Exclude specified models
    if exclude_models:
        pattern = literal_pattern(exclude_models.split(";"))
        if pattern:
            filtered_df = filtered_df[
                ~filtered_df["Model"].str.contains(
                    pattern, case=False, regex=True, na=False
                )
            ]

    # Apply WER and QoI ranges; non-numeric cells become NaN and are dropped.
    slider_ranges = {"Average WER": wer_slider, "QoI": qoi_slider}
    for column, (low, high) in slider_ranges.items():
        if column in filtered_df.columns:
            values = pd.to_numeric(filtered_df[column], errors="coerce")
            filtered_df = filtered_df[(values >= low) & (values <= high)]

    return filtered_df
# NOTE(review): ``diff_tab`` and ``text_diff_elems`` are not referenced anywhere
# else in the visible file — possibly leftovers of a removed "Difference
# Checker" tab; confirm before deleting (Gradio components are registered at
# construction, so creation order is kept as-is).
diff_tab = gr.TabItem("Difference Checker", elem_id="diff_checker", id=2)
text_diff_elems = []
# Top-level tab container, created here and rendered inside the Blocks layout
# via ``tabs.render()``.
tabs = gr.Tabs(elem_id="tab-elems")
def update_multilingual_results(selected_model):
    """
    Updates the multilingual results display based on the selected model.

    This function processes the multilingual data for the chosen model,
    collects the average WER for each scenario (language hinted vs. predicted),
    prepares language-specific WER data, and builds the confusion-matrix plot
    for the not-forced (language predicted) scenario when one is available.

    :param selected_model: The name of the selected model, matched against the
        "Model" column of ``multilingual_df``.
    :return: A list of four Gradio updates, in this order: results markdown,
        average-WER table, per-language WER table, confusion-matrix plot.
    """

    def message_only(message):
        # Every return path must yield one update per output component (four);
        # the tables and the plot are hidden when there is nothing to show.
        return [
            gr.update(value=message),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        ]

    if selected_model is None:
        return message_only(
            "# Select a model from the table on the left to view results."
        )

    # Filter data for the selected model
    model_data = multilingual_df[multilingual_df["Model"] == selected_model]
    if model_data.empty:
        return message_only(f"# No data available for model: {selected_model}")

    # Separate data for forced (language hinted) and not forced (language
    # predicted) scenarios; ``== True`` is an element-wise pandas comparison.
    forced_data = model_data[model_data["Forced Tokens"] == True]  # noqa: E712
    not_forced_data = model_data[model_data["Forced Tokens"] == False]  # noqa: E712

    result_text = f"# Model: {selected_model}\n\n"

    # Prepare average WER data: one row per scenario that has results.
    scenarios = (
        ("Language Hinted", forced_data),
        ("Language Predicted", not_forced_data),
    )
    average_wer_data = [
        {"Scenario": scenario, "Average WER": data.iloc[0]["Average WER"]}
        for scenario, data in scenarios
        if not data.empty
    ]
    # Explicit columns keep "Average WER" present even when no scenario matched.
    average_wer_df = pd.DataFrame(
        average_wer_data, columns=["Scenario", "Average WER"]
    )
    average_wer_df["Average WER"] = average_wer_df["Average WER"].apply(
        lambda x: round(x, 2)
    )

    # Prepare language-specific WER data from the "WER_<lang>" columns
    lang_columns = [col for col in model_data.columns if col.startswith("WER_")]
    lang_wer_data = []
    for column in lang_columns:
        lang = column.split("_")[1]
        forced_wer = None if forced_data.empty else forced_data[column].iloc[0]
        not_forced_wer = (
            None if not_forced_data.empty else not_forced_data[column].iloc[0]
        )
        if forced_wer is None and not_forced_wer is None:
            continue
        lang_wer_data.append(
            {
                # Fall back to the raw code for languages missing from LANGUAGE_MAP.
                "Language": LANGUAGE_MAP.get(lang, lang),
                "Language Hinted WER": "N/A"
                if forced_wer is None
                else round(forced_wer, 2),
                "Language Predicted WER": "N/A"
                if not_forced_wer is None
                else round(not_forced_wer, 2),
            }
        )
    lang_wer_df = pd.DataFrame(lang_wer_data).fillna("No Data")

    # Create confusion matrix plot for the unforced scenario
    unforced_plot = None
    if selected_model in confusion_matrix_map:
        if "not_forced" in confusion_matrix_map[selected_model]:
            unforced_plot = create_confusion_matrix_plot(
                confusion_matrix_map[selected_model]["not_forced"]["matrix"],
                confusion_matrix_map[selected_model]["not_forced"]["labels"],
                False,
            )

    # Return updated components for Gradio interface
    return [
        gr.update(value=result_text),
        gr.update(visible=True, value=average_wer_df),
        gr.update(visible=True, value=lang_wer_df),
        gr.update(visible=unforced_plot is not None, value=unforced_plot),
    ]
# Font stack for the Gradio theme, most preferred first: a locally provided
# face, a monospace face, then generic system fallbacks.
_PREFERRED_FONTS = ["Zwizz Regular", "IBM Plex Mono"]
_FALLBACK_FONTS = ["ui-sans-serif", "system-ui", "sans-serif"]
font = _PREFERRED_FONTS + _FALLBACK_FONTS
# Define the Gradio interface. Every component created inside this ``with``
# block belongs to ``demo``, which is launched at the end of the file.
with gr.Blocks(css=css, theme=gr.themes.Base(font=font)) as demo:
    # Add header and banner to the interface
    gr.HTML(HEADER)
    gr.HTML(BANNER_TEXT, elem_classes="markdown-text")
    # Create tabs for different sections of the dashboard; ``tabs`` is the
    # module-level gr.Tabs container, rendered at this position.
    with tabs.render():
# Performance Tab
with gr.TabItem("Performance", elem_id="benchmark", id=0):
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
with gr.Column(scale=6, elem_classes="filter_models_column"):
filter_performance_models = gr.Textbox(
placeholder="πŸ” Filter Model (separate multiple queries with ';')",
label="Filter Models",
)
with gr.Column(scale=4, elem_classes="exclude_models_column"):
exclude_performance_models = gr.Textbox(
placeholder="πŸ” Exclude Model",
label="Exclude Model",
)
with gr.Row():
with gr.Accordion("See All Columns", open=False):
with gr.Row():
with gr.Column(scale=9, elem_id="performance_columns"):
performance_shown_columns = gr.CheckboxGroup(
choices=PERFORMANCE_TOGGLE_COLS,
value=PERFORMANCE_SELECTED_COLS,
label="Toggle Columns",
elem_id="column-select",
interactive=True,
)
with gr.Column(
scale=1,
min_width=200,
elem_id="performance_select_columns",
):
with gr.Row():
select_all_button = gr.Button(
"Select All",
elem_id="select-all-button",
interactive=True,
)
deselect_all_button = gr.Button(
"Deselect All",
elem_id="deselect-all-button",
interactive=True,
)
def select_all_columns():
return PERFORMANCE_TOGGLE_COLS
def deselect_all_columns():
return []
select_all_button.click(
select_all_columns,
inputs=[],
outputs=performance_shown_columns,
)
deselect_all_button.click(
deselect_all_columns,
inputs=[],
outputs=performance_shown_columns,
)
with gr.Row():
with gr.Accordion("Filter Devices", open=False):
with gr.Row():
with gr.Column(
scale=9, elem_id="filter_devices_column"
):
performance_shown_devices = gr.CheckboxGroup(
choices=PERFORMANCE_DEVICES,
value=PERFORMANCE_DEVICES,
label="Filter Devices",
interactive=True,
)
with gr.Column(
scale=1,
min_width=200,
elem_id="filter_select_devices",
):
with gr.Row():
select_all_devices_button = gr.Button(
"Select All",
elem_id="select-all-devices-button",
interactive=True,
)
deselect_all_devices_button = gr.Button(
"Deselect All",
elem_id="deselect-all-devices-button",
interactive=True,
)
def select_all_devices():
return PERFORMANCE_DEVICES
def deselect_all_devices():
return []
select_all_devices_button.click(
select_all_devices,
inputs=[],
outputs=performance_shown_devices,
)
deselect_all_devices_button.click(
deselect_all_devices,
inputs=[],
outputs=performance_shown_devices,
)
with gr.Row():
performance_shown_os = gr.CheckboxGroup(
choices=PERFORMANCE_OS,
value=PERFORMANCE_OS,
label="Filter OS",
interactive=True,
)
with gr.Column(scale=1):
with gr.Accordion("See Performance Filters"):
with gr.Row():
with gr.Row():
min_short_speed, max_short_speed = floor(
min(performance_df["Short-Form Speed"])
), ceil(max(performance_df["Short-Form Speed"]))
short_speed_slider = RangeSlider(
value=[min_short_speed, max_short_speed],
minimum=min_short_speed,
maximum=max_short_speed,
step=0.001,
label="Short-Form Speed",
)
with gr.Row():
min_long_speed, max_long_speed = floor(
min(performance_df["Long-Form Speed"])
), ceil(max(performance_df["Long-Form Speed"]))
long_speed_slider = RangeSlider(
value=[min_long_speed, max_long_speed],
minimum=min_long_speed,
maximum=max_long_speed,
step=0.001,
label="Long-Form Speed",
)
with gr.Row():
with gr.Row():
min_short_toks, max_short_toks = floor(
min(performance_df["Short-Form Tok/s"])
), ceil(max(performance_df["Short-Form Tok/s"]))
short_toks_slider = RangeSlider(
value=[min_short_toks, max_short_toks],
minimum=min_short_toks,
maximum=max_short_toks,
step=0.001,
label="Short-Form Tok/s",
)
with gr.Row():
min_long_toks, max_long_toks = floor(
min(performance_df["Long-Form Tok/s"])
), ceil(max(performance_df["Long-Form Tok/s"]))
long_toks_slider = RangeSlider(
value=[min_long_toks, max_long_toks],
minimum=min_long_toks,
maximum=max_long_toks,
step=0.001,
label="Long-Form Tok/s",
)
with gr.Row():
gr.Markdown(PERFORMANCE_TEXT, elem_classes="markdown-text")
with gr.Row():
leaderboard_df = gr.components.Dataframe(
value=performance_df[
PERFORMANCE_ALWAYS_HERE_COLS + performance_shown_columns.value
],
headers=[
PERFORMANCE_ALWAYS_HERE_COLS + performance_shown_columns.value
],
datatype=[
c.type
for c in fields(PerformanceAutoEvalColumn)
if c.name in PERFORMANCE_COLS
],
elem_id="leaderboard-table",
elem_classes="large-table",
interactive=False,
)
# Copy of the leaderboard dataframe to apply filters to
hidden_leaderboard_df = gr.components.Dataframe(
value=performance_df,
headers=PERFORMANCE_COLS,
datatype=[
c.type
for c in fields(PerformanceAutoEvalColumn)
if c.name in PERFORMANCE_COLS
],
visible=False,
)
# Inputs for the dataframe filter function
performance_filter_inputs = [
hidden_leaderboard_df,
performance_shown_columns,
filter_performance_models,
exclude_performance_models,
performance_shown_devices,
performance_shown_os,
short_speed_slider,
long_speed_slider,
short_toks_slider,
long_toks_slider,
]
filter_output = leaderboard_df
filter_performance_models.change(
performance_filter, performance_filter_inputs, filter_output
)
exclude_performance_models.change(
performance_filter, performance_filter_inputs, filter_output
)
performance_shown_columns.change(
performance_filter, performance_filter_inputs, filter_output
)
performance_shown_devices.change(
performance_filter, performance_filter_inputs, filter_output
)
performance_shown_os.change(
performance_filter, performance_filter_inputs, filter_output
)
short_speed_slider.change(
performance_filter, performance_filter_inputs, filter_output
)
long_speed_slider.change(
performance_filter, performance_filter_inputs, filter_output
)
short_toks_slider.change(
performance_filter, performance_filter_inputs, filter_output
)
long_toks_slider.change(
performance_filter, performance_filter_inputs, filter_output
)
# English Quality Tab
with gr.TabItem("English Quality", elem_id="timeline", id=1):
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
with gr.Column(scale=6, elem_classes="filter_models_column"):
filter_quality_models = gr.Textbox(
placeholder="πŸ” Filter Model (separate multiple queries with ';')",
label="Filter Models",
)
with gr.Column(scale=4, elem_classes="exclude_models_column"):
exclude_quality_models = gr.Textbox(
placeholder="πŸ” Exclude Model",
label="Exclude Model",
)
with gr.Row():
with gr.Accordion("See All Columns", open=False):
quality_shown_columns = gr.CheckboxGroup(
choices=QUALITY_TOGGLE_COLS,
value=QUALITY_SELECTED_COLS,
label="Toggle Columns",
elem_id="column-select",
interactive=True,
)
with gr.Column(scale=1):
with gr.Accordion("See Quality Filters"):
with gr.Row():
with gr.Row():
quality_min_avg_wer, quality_max_avg_wer = (
floor(min(model_df["Average WER"])),
ceil(max(model_df["Average WER"])) + 1,
)
wer_slider = RangeSlider(
value=[quality_min_avg_wer, quality_max_avg_wer],
minimum=quality_min_avg_wer,
maximum=quality_max_avg_wer,
label="Average WER",
)
with gr.Row():
quality_min_qoi, quality_max_qoi = floor(
min(model_df["QoI"])
), ceil(max(model_df["QoI"] + 1))
qoi_slider = RangeSlider(
value=[quality_min_qoi, quality_max_qoi],
minimum=quality_min_qoi,
maximum=quality_max_qoi,
label="QoI",
)
with gr.Row():
gr.Markdown(QUALITY_TEXT)
with gr.Row():
quality_leaderboard_df = gr.components.Dataframe(
value=model_df[
QUALITY_ALWAYS_HERE_COLS + quality_shown_columns.value
],
headers=[QUALITY_ALWAYS_HERE_COLS + quality_shown_columns.value],
datatype=[
c.type
for c in fields(QualityAutoEvalColumn)
if c.name in QUALITY_COLS
],
elem_id="leaderboard-table",
elem_classes="large-table",
interactive=False,
)
# Copy of the leaderboard dataframe to apply filters to
hidden_quality_leaderboard_df = gr.components.Dataframe(
value=model_df,
headers=QUALITY_COLS,
datatype=[
c.type
for c in fields(QualityAutoEvalColumn)
if c.name in QUALITY_COLS
],
visible=False,
)
# Inputs for the dataframe filter function
filter_inputs = [
hidden_quality_leaderboard_df,
quality_shown_columns,
filter_quality_models,
wer_slider,
qoi_slider,
exclude_quality_models,
]
filter_output = quality_leaderboard_df
filter_quality_models.change(
quality_filter, filter_inputs, filter_output
)
exclude_quality_models.change(
quality_filter, filter_inputs, filter_output
)
quality_shown_columns.change(
quality_filter, filter_inputs, filter_output
)
wer_slider.change(quality_filter, filter_inputs, filter_output)
qoi_slider.change(quality_filter, filter_inputs, filter_output)
# Timeline Tab
with gr.TabItem("Timeline", elem_id="timeline", id=4):
# Create subtabs for different metrics
with gr.Tabs():
with gr.TabItem("QoI", id=0):
with gr.Row():
with gr.Column(scale=6):
filter_qoi = gr.Textbox(
placeholder="πŸ” Filter Model-Device-OS (separate multiple queries with ';')",
label="Filter",
)
with gr.Column(scale=4):
exclude_qoi = gr.Textbox(
placeholder="πŸ” Exclude Model-Device-OS",
label="Exclude",
)
with gr.Row():
with gr.Column():
qoi_plot = gr.Plot(container=True)
demo.load(
lambda x, y, z: plot_metric(
x,
"qoi",
"QoI",
"QoI Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_qoi,
exclude_qoi,
],
qoi_plot,
)
filter_qoi.change(
lambda x, y, z: plot_metric(
x,
"qoi",
"QoI",
"QoI Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_qoi,
exclude_qoi,
],
qoi_plot,
)
exclude_qoi.change(
lambda x, y, z: plot_metric(
x,
"qoi",
"QoI",
"QoI Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_qoi,
exclude_qoi,
],
qoi_plot,
)
with gr.TabItem("Average WER", id=1):
with gr.Row():
with gr.Column(scale=6):
filter_average_wer = gr.Textbox(
placeholder="πŸ” Filter Model-Device-OS (separate multiple queries with ';')",
label="Filter",
)
with gr.Column(scale=4):
exclude_average_wer = gr.Textbox(
placeholder="πŸ” Exclude Model-Device-OS",
label="Exclude",
)
with gr.Row():
with gr.Column():
average_wer_plot = gr.Plot(container=True)
demo.load(
lambda x, y, z: plot_metric(
x,
"average_wer",
"Average WER",
"Average WER Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_average_wer,
exclude_average_wer,
],
average_wer_plot,
)
filter_average_wer.change(
lambda x, y, z: plot_metric(
x,
"average_wer",
"Average WER",
"Average WER Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_average_wer,
exclude_average_wer,
],
average_wer_plot,
)
exclude_average_wer.change(
lambda x, y, z: plot_metric(
x,
"average_wer",
"Average WER",
"Average WER Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_average_wer,
exclude_average_wer,
],
average_wer_plot,
)
with gr.TabItem("Speed", id=2):
with gr.Row():
with gr.Column(scale=6):
filter_speed = gr.Textbox(
placeholder="πŸ” Filter Model-Device-OS (separate multiple queries with ';')",
label="Filter",
)
with gr.Column(scale=4):
exclude_speed = gr.Textbox(
placeholder="πŸ” Exclude Model-Device-OS",
label="Exclude",
)
with gr.Row():
with gr.Column():
speed_plot = gr.Plot(container=True)
demo.load(
lambda x, y, z: plot_metric(
x,
"speed",
"Speed",
"Speed Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_speed,
exclude_speed,
],
speed_plot,
)
filter_speed.change(
lambda x, y, z: plot_metric(
x,
"speed",
"Speed",
"Speed Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_speed,
exclude_speed,
],
speed_plot,
)
exclude_speed.change(
lambda x, y, z: plot_metric(
x,
"speed",
"Speed",
"Speed Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_speed,
exclude_speed,
],
speed_plot,
)
with gr.TabItem("Tok/s", id=3):
with gr.Row():
with gr.Column(scale=6):
filter_toks = gr.Textbox(
placeholder="πŸ” Filter Model-Device-OS (separate multiple queries with ';')",
label="Filter",
)
with gr.Column(scale=4):
exclude_toks = gr.Textbox(
placeholder="πŸ” Exclude Model-Device-OS",
label="Exclude",
)
with gr.Row():
with gr.Column():
toks_plot = gr.Plot(container=True)
demo.load(
lambda x, y, z: plot_metric(
x,
"tokens_per_second",
"Tok/s",
"Tok/s Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_toks,
exclude_toks,
],
toks_plot,
)
filter_toks.change(
lambda x, y, z: plot_metric(
x,
"tokens_per_second",
"Tok/s",
"Tok/s Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_toks,
exclude_toks,
],
toks_plot,
)
exclude_toks.change(
lambda x, y, z: plot_metric(
x,
"tokens_per_second",
"Tok/s",
"Tok/s Over Time for Model-Device-OS Combinations",
y,
z,
),
[
gr.Dataframe(benchmark_df, visible=False),
filter_toks,
exclude_toks,
],
toks_plot,
)
# Multilingual Quality Tab
with gr.TabItem("Multilingual Quality", elem_id="multilingual", id=5):
if multilingual_df is not None:
with gr.Row():
with gr.Column(scale=1):
# Display table of multilingual models
model_table = gr.Dataframe(
value=multilingual_models_df,
headers=["Model"],
datatype=["html"],
elem_classes="left-side-table",
)
# Placeholders for confusion matrix plots
with gr.Row():
unforced_confusion_matrix = gr.Plot(visible=False)
with gr.Row():
forced_confusion_matrix = gr.Plot(visible=False)
with gr.Column(scale=1):
# Display area for selected model results
results_markdown = gr.Markdown(
"# Select a model from the table on the left to view results.",
elem_id="multilingual-results",
)
# Tables for displaying average WER and language-specific WER
average_wer_table = gr.Dataframe(
value=None, elem_id="average-wer-table", visible=False
)
language_wer_table = gr.Dataframe(
value=None, elem_id="general-wer-table", visible=False
)
# Set up click event to update results when a model is selected
for button in multilingual_models_buttons:
button.render()
button.click(
fn=lambda x: update_multilingual_results(x),
inputs=[button],
outputs=[
results_markdown,
average_wer_table,
language_wer_table,
unforced_confusion_matrix,
],
)
else:
# Display message if no multilingual data is available
gr.Markdown("No multilingual benchmark results available.")
# Device Support Tab
with gr.TabItem("Device Support", elem_id="device_support", id=6):
# Load device support data from CSV
support_data = pd.read_csv("dashboard_data/support_data.csv")
support_data.set_index(support_data.columns[0], inplace=True)
support_data["Model"] = support_data["Model"].apply(
lambda x: x.replace("_", "/")
)
support_data["Model"] = support_data["Model"].apply(
lambda x: make_model_name_clickable_link(x)
)
support_data = (
support_data.assign(model_len=support_data["Model"].str.len())
.sort_values(
by=["model_len"],
ascending=[True],
)
.drop(columns=["model_len"])
)
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
with gr.Column(scale=6, elem_id="filter_models_column"):
filter_support_models = gr.Textbox(
placeholder="πŸ” Filter Model (separate multiple queries with ';')",
label="Filter Models",
)
with gr.Column(scale=4, elem_classes="exclude_models_column"):
exclude_support_models = gr.Textbox(
placeholder="πŸ” Exclude Model",
label="Exclude Model",
)
with gr.Row():
with gr.Accordion("See All Columns", open=False):
with gr.Row():
with gr.Column(scale=9):
support_shown_columns = gr.CheckboxGroup(
choices=support_data.columns.tolist()[
1:
], # Exclude 'Model' column
value=support_data.columns.tolist()[1:],
label="Toggle Columns",
elem_id="support-column-select",
interactive=True,
)
with gr.Column(scale=1, min_width=200):
with gr.Row():
select_all_support_button = gr.Button(
"Select All",
elem_id="select-all-support-button",
interactive=True,
)
deselect_all_support_button = gr.Button(
"Deselect All",
elem_id="deselect-all-support-button",
interactive=True,
)
with gr.Column():
gr.Markdown(
"""
### Legend
- βœ… Supported: The model is supported and tested on this device.
- ⚠️ Failed: Either the model tests failed on this device or the Speed Factor for the test is less than 1.
- ? Not Tested: The model is supported on this device but no test information available.
- Not Supported: The model is not supported on this device as per the [WhisperKit configuration](https://huggingface.co/argmaxinc/whisperkit-coreml/blob/main/config.json).
"""
)
# Display device support data in a table
device_support_table = gr.Dataframe(
value=support_data,
headers=support_data.columns.tolist(),
datatype=["html" for _ in support_data.columns],
elem_id="device-support-table",
elem_classes="large-table",
interactive=False,
)
# Hidden dataframe to store the original data
hidden_support_df = gr.Dataframe(value=support_data, visible=False)
def filter_support_data(df, columns, model_query, exclude_models):
filtered_df = df.copy()
# Filter models based on query
if model_query:
filtered_df = filtered_df[
filtered_df["Model"].str.contains(
"|".join(q.strip() for q in model_query.split(";")),
case=False,
regex=True,
)
]
# Exclude specified models
if exclude_models:
exclude_list = [
re.escape(m.strip()) for m in exclude_models.split(";")
]
filtered_df = filtered_df[
~filtered_df["Model"].str.contains(
"|".join(exclude_list), case=False, regex=True
)
]
# Select columns
selected_columns = ["Model"] + [
col for col in columns if col in df.columns
]
filtered_df = filtered_df[selected_columns]
return filtered_df
def select_all_support_columns():
return support_data.columns.tolist()[1:] # Exclude 'Model' column
def deselect_all_support_columns():
return []
# Connect the filter function to the input components
filter_inputs = [
hidden_support_df,
support_shown_columns,
filter_support_models,
exclude_support_models,
]
filter_support_models.change(
filter_support_data, filter_inputs, device_support_table
)
exclude_support_models.change(
filter_support_data, filter_inputs, device_support_table
)
support_shown_columns.change(
filter_support_data, filter_inputs, device_support_table
)
# Connect select all and deselect all buttons
select_all_support_button.click(
select_all_support_columns,
inputs=[],
outputs=support_shown_columns,
)
deselect_all_support_button.click(
deselect_all_support_columns,
inputs=[],
outputs=support_shown_columns,
)
# Methodology Tab
with gr.TabItem("Methodology", elem_id="methodology", id=7):
gr.Markdown(METHODOLOGY_TEXT, elem_id="methodology-text")
# Citation section
with gr.Accordion("πŸ“™ Citation", open=False):
citation_button = gr.Textbox(
value=CITATION_BUTTON_TEXT,
label=CITATION_BUTTON_LABEL,
lines=7,
elem_id="citation-button",
show_copy_button=True,
)
# Launch the Gradio interface
demo.launch(debug=True, share=True, ssr_mode=False)