File size: 4,418 Bytes
13e8963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr  # type: ignore
import plotly.express as px  # type: ignore

from backend.data import load_cot_data
from backend.envs import API, REPO_ID, TOKEN

# Logo images (hosted in the cot-eval GitHub repo) shown centred under the title.
logo1_url = (
    "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/"
    "AI2_Logo_Square.png"
)
logo2_url = (
    "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/"
    "logo_logikon_notext_withborder.png"
)

# HTML row with both logos, each linking to its organisation's website.
LOGOS = (
    '<div style="display: flex; justify-content: center;">'
    '<a href="https://allenai.org/">'
    f'<img src="{logo1_url}" alt="AI2" style="width: 30vw; min-width: 20px; max-width: 60px;">'
    '</a> '
    '<a href="https://logikon.ai">'
    f'<img src="{logo2_url}" alt="Logikon AI" style="width: 30vw; min-width: 20px; max-width: 60px; margin-left: 10px;">'
    '</a></div>'
)

# Page header: centred title followed by the logo row.
TITLE = '<h1 align="center" id="space-title"> Open CoT Dashboard</h1> ' + LOGOS

# Markdown blurb rendered below the header.
INTRODUCTION_TEXT = (
    "\n"
    "Baseline accuracies and marginal accuracy gains for specific models and CoT regimes "
    "from the [Open CoT Leaderboard](https://huggingface.co/spaces/logikon/open_cot_leaderboard).\n"
)

def restart_space():
    """Request a restart of the Space ``REPO_ID`` through the shared ``API`` client.

    Authenticates with ``TOKEN``. ``API``/``REPO_ID``/``TOKEN`` come from
    ``backend.envs`` (presumably a huggingface_hub ``HfApi`` — confirm there).
    """
    API.restart_space(token=TOKEN, repo_id=REPO_ID)

# Load the evaluation tables once at import time; every callback below reads
# these module-level frames.
try:
    df_cot_err, df_cot_regimes = load_cot_data()
except Exception:
    # Ask for a fresh restart of the Space, then re-raise. Swallowing the error
    # here would leave df_cot_err / df_cot_regimes undefined, and the app would
    # die later with a misleading NameError instead of the real cause.
    restart_space()
    raise


def plot_evals(model_id, plotly_mode, request: gr.Request):
    """Build the faceted scatter plot of base accuracy vs. marginal CoT gain.

    Args:
        model_id: Model picked in the dropdown; its points are highlighted.
        plotly_mode: ``"dark"`` selects the ``plotly_dark`` template, any
            other value the default ``plotly`` template.
        request: Gradio request, may be ``None``. If it carries a ``model``
            URL query parameter naming a model present in ``df_cot_err``,
            that model replaces ``model_id``.

    Returns:
        ``(figure, model_id)`` where ``model_id`` is the effective model
        (after any query-parameter override), so the dropdown can be synced.
    """
    df = df_cot_err.copy()
    if request and "model" in request.query_params:
        model_param = request.query_params["model"]
        # Only accept query-parameter values that name a model in the data.
        if df.model.eq(model_param).any():
            model_id = model_param

    # Two colour categories: the chosen model vs. everything else.
    df["selected"] = df.model.eq(model_id).map({True: "selected", False: "-"})

    template = "plotly_dark" if plotly_mode == "dark" else "plotly"
    fig = px.scatter(
        df,
        x="base accuracy",
        y="marginal acc. gain",
        color="selected",
        symbol="model",
        facet_col="task",
        facet_col_wrap=3,
        # Fix the legend/colour order so "selected" always maps to Orange.
        category_orders={"selected": ["selected", "-"]},
        color_discrete_sequence=["Orange", "Gray"],
        template=template,
        error_y="acc_gain-err",
        hover_data=["model", "cot accuracy"],
        width=1200,
        height=700,
    )
    fig.update_layout(title={"automargin": True})
    return fig, model_id

def get_model_table(model_id):
    """Return a styled table of the CoT-regime results for one model.

    Selects the rows of ``df_cot_regimes`` whose ``model`` equals
    ``model_id`` and keeps the task, decoding-configuration and accuracy
    columns. It renames ``temperature`` to ``temp``, shortens the chain name
    ``ReflectBeforeRun`` to ``Reflect``, and sorts by task and CoT chain.
    It then colour-codes the accuracies (20–100 scale) and the gain column
    ``delta_abs`` (diverging −20…20 scale).

    Args:
        model_id: Model identifier to filter on.

    Returns:
        A pandas ``Styler`` wrapping the filtered, index-free frame.
    """

    def make_pretty(styler):
        styler.hide(axis="index")
        styler.format(precision=1)
        styler.background_gradient(
            axis=None,
            subset=["acc_base", "acc_cot"],
            vmin=20, vmax=100, cmap="YlGnBu",
        )
        # The gain column in this table is "delta_abs" (see the column
        # selection below); a subset naming an absent column makes pandas
        # raise KeyError when the Styler is computed.
        styler.background_gradient(
            axis=None,
            subset=["delta_abs"],
            vmin=-20, vmax=20, cmap="coolwarm",
        )
        # Request bold text for the "task" column.
        styler.set_table_styles(
            {"task": [{"selector": "", "props": [("font-weight", "bold")]}]},
            overwrite=False,
        )
        return styler

    columns = [
        "task", "cot_chain", "best_of", "temperature", "top_k", "top_p",
        "acc_base", "acc_cot", "delta_abs",
    ]
    df_cot_model = (
        df_cot_regimes.loc[df_cot_regimes.model.eq(model_id), columns]
        .rename(columns={"temperature": "temp"})
        .replace({"cot_chain": "ReflectBeforeRun"}, "Reflect")
        .sort_values(["task", "cot_chain"])
        .reset_index(drop=True)
    )

    return df_cot_model.style.pipe(make_pretty)

def styled_model_table(model_id, request: gr.Request):
    """Return ``get_model_table`` for the effective model.

    The effective model is ``model_id``, unless ``request`` carries a
    ``model`` URL query parameter that names a model present in
    ``df_cot_regimes``; then that model wins.
    """
    if not (request and "model" in request.query_params):
        return get_model_table(model_id)

    candidate = request.query_params["model"]
    known_models = df_cot_regimes.model.to_list()
    return get_model_table(candidate if candidate in known_models else model_id)


def update_plot(model_id, theme):
    """Redraw the plot for the current dropdown/radio values.

    ``plot_evals`` lets a ``?model=...`` query parameter override
    ``model_id``. Every event request carries the page URL's query
    parameters (NOTE(review): Gradio behaviour — confirm for the pinned
    version), so on "Update" clicks that override would silently revert the
    user's dropdown choice. Passing ``request=None`` disables it.
    """
    return plot_evals(model_id, theme, None)


def update_table(model_id):
    """Rebuild the model table for the dropdown value, ignoring URL query params."""
    return styled_model_table(model_id, None)


DEFAULT_MODEL = "allenai/tulu-2-70b"

model_choices = list(df_cot_err.model.unique())
# Fall back to the first available model if the preferred default is not in
# the data, so the dropdown never starts on a value outside its choices.
if DEFAULT_MODEL in model_choices:
    default_model = DEFAULT_MODEL
else:
    default_model = model_choices[0] if model_choices else None

with gr.Blocks() as demo:

    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)
    with gr.Row():
        model_list = gr.Dropdown(model_choices, value=default_model, label="Model", scale=2)
        plotly_mode = gr.Radio(["dark", "light"], value="dark", label="Plot theme", scale=1)
        submit = gr.Button("Update", scale=1)
    table = gr.DataFrame()
    plot = gr.Plot(label="evals")

    # Button clicks use the widget values as-is (no URL override).
    submit.click(update_plot, [model_list, plotly_mode], [plot, model_list])
    submit.click(update_table, model_list, table)
    # Page load honours a ?model=... query parameter; plot_evals also writes
    # the effective model back into the dropdown.
    demo.load(plot_evals, [model_list, plotly_mode], [plot, model_list])
    demo.load(styled_model_table, model_list, table)

demo.launch()