import gradio as gr
import datasets
import logging
from fetch_utils import check_dataset_and_get_config, check_dataset_and_get_split

def get_records_from_dataset_repo(dataset_id):
    dataset_config = check_dataset_and_get_config(dataset_id)
    logging.info(f"Dataset {dataset_id} has configs {dataset_config}")
    dataset_split = check_dataset_and_get_split(dataset_id, dataset_config[0])
    logging.info(f"Dataset {dataset_id} has splits {dataset_split}")
    try:
        ds = datasets.load_dataset(dataset_id, dataset_config[0])[dataset_split[0]]
        df = ds.to_pandas()
        return df
    except Exception as e:
        logging.warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
        return None

def get_model_ids(ds):
    # return the unique model ids, sorted so the dropdown order is stable
    model_ids = sorted(ds['model_id'].dropna().unique().tolist())
    logging.info(f"Found model ids: {model_ids}")
    return model_ids

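def get_dataset_ids(ds):
    # return the unique dataset ids; avoid naming a local `datasets`,
    # which would shadow the imported `datasets` module
    dataset_ids = sorted(ds['dataset_id'].dropna().unique().tolist())
    logging.info(f"Found dataset ids: {dataset_ids}")
    return dataset_ids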

def get_types(ds):
    # set all types for each column
    types = [str(t) for t in ds.dtypes.to_list()]
    types = [t.replace('object', 'markdown') for t in types]
    types = [t.replace('float64', 'number') for t in types]
    types = [t.replace('int64', 'number') for t in types]
    return types

def get_display_df(df):
    # render the id and link columns as clickable links
    display_df = df.copy()
    link_prefixes = {
        'model_id': 'https://huggingface.co/',
        'dataset_id': 'https://huggingface.co/datasets/',
        'report_link': '',
    }
    for column, prefix in link_prefixes.items():
        # skip columns the user has hidden and columns with no values
        if column in display_df.columns and display_df[column].any():
            # use <a>, not <p>: only anchors honour the href attribute
            display_df[column] = display_df[column].apply(
                lambda x, prefix=prefix: f'<a href="{prefix}{x}" target="_blank" style="color:blue">🔗{x}</a>')
    return display_df

def get_demo():
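    records = get_records_from_dataset_repo('ZeroCommand/test-giskard-report')
    # get_records_from_dataset_repo returns None when loading fails
    if records is None:
        raise RuntimeError('Failed to load records from ZeroCommand/test-giskard-report')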

    model_ids = get_model_ids(records)
    dataset_ids = get_dataset_ids(records)

    column_names = records.columns.tolist()
    default_columns = ['model_id', 'dataset_id', 'total_issue', 'report_link']
    # set the default columns to show
    default_df = records[default_columns]
    types = get_types(default_df)
    display_df = get_display_df(default_df)

    with gr.Row():
        task_select = gr.Dropdown(label='Task', choices=['text_classification', 'tabular'], value='text_classification', interactive=True)
        model_select = gr.Dropdown(label='Model id', choices=model_ids, interactive=True)
        dataset_select = gr.Dropdown(label='Dataset id', choices=dataset_ids, interactive=True)
    
    with gr.Row():
        columns_select = gr.CheckboxGroup(label='Show columns', choices=column_names, value=default_columns, interactive=True)

    with gr.Row():
        leaderboard_df = gr.DataFrame(display_df, datatype=types, interactive=False)
    
    @gr.on(triggers=[model_select.change, dataset_select.change, columns_select.change, task_select.change],
           inputs=[model_select, dataset_select, columns_select, task_select],
           outputs=[leaderboard_df])
    def filter_table(model_id, dataset_id, columns, task):
        # filter the table based on task
        df = records[(records['hf_pipeline_type'] == task)]
        # narrow the task-filtered table by model_id and dataset_id;
        # filter `df`, not `records`, so earlier filters are not discarded
        if model_id:
            df = df[(df['model_id'] == model_id)]
        if dataset_id:
            df = df[(df['dataset_id'] == dataset_id)]

        # filter the table based on the columns
        df = df[columns]
        types = get_types(df)
        display_df = get_display_df(df)
        return (
            gr.update(value=display_df, datatype=types, interactive=False)
        )