# giskard-evaluator / app_leaderboard.py
# Last commit cbb886a by ZeroCommand: "add leaderboard ui and refactor code" (file size 4.17 kB)
import gradio as gr
import datasets
import logging
from fetch_utils import check_dataset_and_get_config, check_dataset_and_get_split
def get_records_from_dataset_repo(dataset_id):
    """Load a Hub dataset repo into a pandas DataFrame.

    Uses the first config reported by ``check_dataset_and_get_config`` and the
    first split reported by ``check_dataset_and_get_split`` for that config.

    Args:
        dataset_id: Hugging Face dataset repo id.

    Returns:
        The selected split as a pandas DataFrame, or ``None`` if loading or
        converting the split raised (a warning is logged in that case).
        Failures in the two ``check_*`` helpers, or an empty config list, are
        not caught here.
    """
    configs = check_dataset_and_get_config(dataset_id)
    logging.info(f"Dataset {dataset_id} has configs {configs}")
    first_config = configs[0]

    splits = check_dataset_and_get_split(dataset_id, first_config)
    logging.info(f"Dataset {dataset_id} has splits {splits}")

    try:
        loaded = datasets.load_dataset(dataset_id, first_config)
        # Picking the split happens inside the try, so an empty split list is
        # reported as a load failure rather than raised.
        return loaded[splits[0]].to_pandas()
    except Exception as err:
        logging.warning(f"Failed to load dataset {dataset_id} with config {configs}: {err}")
        return None
def get_model_ids(ds):
    """Return the distinct values of the ``model_id`` column of *ds*.

    Values are returned in order of first appearance. That keeps the dropdown
    choices built from them stable between runs; a plain ``set`` would give an
    order that depends on the string hash seed.

    Args:
        ds: A pandas DataFrame with a ``model_id`` column.

    Returns:
        A list of unique model ids.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    model_ids = list(dict.fromkeys(ds['model_id'].tolist()))
    logging.info(f"Found {len(model_ids)} unique model ids")
    return model_ids
def get_dataset_ids(ds):
    """Return the distinct values of the ``dataset_id`` column of *ds*.

    Values are returned in order of first appearance so the dropdown choices
    built from them are stable between runs (unlike ``set`` ordering).

    Args:
        ds: A pandas DataFrame with a ``dataset_id`` column.

    Returns:
        A list of unique dataset ids.
    """
    # Named ``dataset_ids`` rather than ``datasets`` so the module-level
    # ``datasets`` import is not shadowed. dict.fromkeys keeps first-seen order.
    dataset_ids = list(dict.fromkeys(ds['dataset_id'].tolist()))
    logging.info(f"Found {len(dataset_ids)} unique dataset ids")
    return dataset_ids
def get_types(ds):
    """Return one Gradio ``DataFrame`` datatype string per column of *ds*.

    Column types map as follows:

    - integer, unsigned and floating-point columns become ``"number"``;
    - boolean columns become ``"bool"``;
    - every other column (object/strings, categoricals, datetimes, ...)
      becomes ``"markdown"``, so the HTML produced by ``get_display_df`` is
      rendered.

    Args:
        ds: A pandas DataFrame.

    Returns:
        A list of datatype strings, in column order.
    """
    # dtype.kind is a one-character code shared by NumPy dtypes and pandas
    # extension dtypes ('i'/'u' ints, 'f' floats, 'b' bools, 'O' objects, ...).
    # Mapping by kind covers int32, float32, nullable Int64, etc., not only
    # the exact names 'int64'/'float64'.
    kind_to_type = {'i': 'number', 'u': 'number', 'f': 'number', 'b': 'bool'}
    return [kind_to_type.get(dtype.kind, 'markdown') for dtype in ds.dtypes]
def get_display_df(df):
    """Return a copy of *df* with the id/link columns rendered as HTML links.

    Each link column is rendered as follows:

    - ``model_id`` cells link to ``https://huggingface.co/<id>``;
    - ``dataset_id`` cells link to ``https://huggingface.co/datasets/<id>``;
    - ``report_link`` cells link to the URL they contain.

    Any of these columns may be absent, for example when it is hidden via the
    column picker; it is then skipped. Missing values (NaN/None) are left
    untouched. *df* itself is not modified.

    Args:
        df: A pandas DataFrame of leaderboard records.

    Returns:
        A new DataFrame with the link columns replaced by HTML strings.
    """
    import html

    # Column name -> URL template; "{}" is replaced by the cell value.
    link_templates = {
        'model_id': 'https://huggingface.co/{}',
        'dataset_id': 'https://huggingface.co/datasets/{}',
        'report_link': '{}',
    }

    def _as_link(value, template):
        # Cell values come from a remote dataset, so escape them before
        # embedding them in HTML. An <a> tag is required for a clickable link;
        # href on a <p> element has no effect.
        text = str(value)
        url = html.escape(template.format(text))
        return f'<a href="{url}" target="_blank" style="color:blue">🔗{html.escape(text)}</a>'

    display_df = df.copy()
    for column, template in link_templates.items():
        # Skip columns that are not present, and columns with no truthy value.
        if column not in display_df.columns or not display_df[column].any():
            continue
        display_df[column] = display_df[column].map(
            lambda value, template=template: _as_link(value, template),
            na_action='ignore',
        )
    return display_df
def get_demo():
    """Build the leaderboard widgets and wire up the table filters.

    Loads the evaluation records from the ``ZeroCommand/test-giskard-report``
    dataset repo. It then lays out three rows:

    - task/model/dataset dropdowns;
    - a column picker;
    - a read-only ``gr.DataFrame`` of the records.

    Components are created inside ``gr.Row()`` blocks, so this is presumably
    called within an enclosing ``gr.Blocks`` context (confirm against caller).

    Raises:
        RuntimeError: if the records dataset could not be loaded.
    """
    records = get_records_from_dataset_repo('ZeroCommand/test-giskard-report')
    if records is None:
        # get_records_from_dataset_repo logs the underlying error and returns
        # None; fail here with a clear message instead of a TypeError later.
        raise RuntimeError("Could not load leaderboard records from 'ZeroCommand/test-giskard-report'")

    model_ids = get_model_ids(records)
    dataset_ids = get_dataset_ids(records)
    column_names = records.columns.tolist()

    # Columns shown before the user changes the column picker.
    default_columns = ['model_id', 'dataset_id', 'total_issue', 'report_link']
    default_df = records[default_columns]
    types = get_types(default_df)
    display_df = get_display_df(default_df)

    with gr.Row():
        task_select = gr.Dropdown(
            label='Task',
            choices=['text_classification', 'tabular'],
            value='text_classification',
            interactive=True,
        )
        model_select = gr.Dropdown(label='Model id', choices=model_ids, interactive=True)
        dataset_select = gr.Dropdown(label='Dataset id', choices=dataset_ids, interactive=True)

    with gr.Row():
        columns_select = gr.CheckboxGroup(
            label='Show columns',
            choices=column_names,
            value=default_columns,
            interactive=True,
        )

    with gr.Row():
        leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)

    @gr.on(triggers=[model_select.change, dataset_select.change, columns_select.change, task_select.change],
           inputs=[model_select, dataset_select, columns_select, task_select],
           outputs=[leaderboard_df])
    def filter_table(model_id, dataset_id, columns, task):
        """Re-render the table for the current task/model/dataset/column selection."""
        # Filters are applied cumulatively: each one narrows the previous
        # result, so the task, model and dataset selections all take effect.
        df = records[records['hf_pipeline_type'] == task]
        if model_id:
            df = df[df['model_id'] == model_id]
        if dataset_id:
            df = df[df['dataset_id'] == dataset_id]
        # Keep only the columns ticked in the column picker.
        df = df[columns]
        filtered_types = get_types(df)
        filtered_display_df = get_display_df(df)
        return gr.update(value=filtered_display_df, datatype=filtered_types, interactive=False)