import os
import gradio as gr
import json
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from datasets import load_dataset
from plotly.subplots import make_subplots


# Benchmark categories, in the order they are plotted around each radar chart.
CATEGORIES = ["task-solving", "math-reasoning", "general-instruction", "natural-question", "safety"]

# Axis label drawn for each category value.
CATEGORIES_NAMES = {
    "task-solving": 'Task-solving',
    "math-reasoning": 'Math',
    "general-instruction": 'General-instruction',
    "natural-question": 'NaturalQA',
    "safety": 'Safety',
}

# Language codes in plotting order: angular order in the aggregated chart and
# row-major subplot order in the per-language grid.
LANGS = ['en', 'vi', 'id', 'ms', 'tl', 'th', 'km', 'lo', 'my']

# Label shown for each language code.
LANG_NAMES = {
    'en': 'eng',
    'vi': 'vie',
    'th': 'tha',
    'id': 'ind',
    'km': 'khm',
    'lo': 'lao',
    'ms': 'msa',
    'my': 'mya',
    'tl': 'tgl',
}


def _env_flag(name, default="0"):
    """Parse environment variable ``name`` as a boolean flag.

    Integers are accepted as before (non-zero means True). The spellings
    true/false, yes/no and on/off are also accepted, case-insensitively.
    An unset or blank variable falls back to ``default``.
    """
    raw = os.environ.get(name, "").strip().lower() or default
    if raw in ("true", "yes", "on"):
        return True
    if raw in ("false", "no", "off"):
        return False
    return bool(int(raw))


FORCE_DOWNLOAD = _env_flag("FORCE_DOWNLOAD")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
DATA_SET_REPO_PATH = os.environ.get("DATA_SET_REPO_PATH", "")
PERFORMANCE_FILENAME = os.environ.get("PERFORMANCE_FILENAME", "gpt4_single_json.csv")


# Raw model ids (values of the CSV ``model`` column) -> display names.
# Insertion order fixes each model's trace colour and legend group in the plots.
rename_map = {
    "seallm13b10L6k_a_5a1R1_seaall_sft4x_1_5a1_r2_0_dpo_8_40000s": "SeaLLM-13b",
    "polylm": "PolyLM-13b",
    "qwen": "Qwen-14b",
    "gpt-3.5-turbo": "GPT-3.5-turbo",
}

# Not read anywhere in this file; kept so existing references keep working.
MODEL_DFRAME = None

def get_model_df():
    """Download the Sea-Bench score CSV from the Hugging Face Hub and load it.

    Uses the module-level settings ``DATA_SET_REPO_PATH`` (dataset repo id),
    ``PERFORMANCE_FILENAME``, ``HF_TOKEN`` and ``FORCE_DOWNLOAD``. The file is
    stored under ``./hf_cache``.

    Returns:
        pandas.DataFrame read from the CSV. Other functions in this file read
        its ``model``, ``category``, ``lang`` and ``score`` columns.

    Raises:
        ValueError: if ``DATA_SET_REPO_PATH`` or ``HF_TOKEN`` is empty.
    """
    # Local import: huggingface_hub is only needed when data is actually fetched.
    from huggingface_hub import hf_hub_download

    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``, and a missing setting would then go unreported here.
    if not DATA_SET_REPO_PATH:
        raise ValueError("DATA_SET_REPO_PATH environment variable is not set")
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN environment variable is not set")

    file_path = hf_hub_download(
        repo_id=DATA_SET_REPO_PATH,
        filename=PERFORMANCE_FILENAME,
        force_download=FORCE_DOWNLOAD,
        local_dir='./hf_cache',
        repo_type="dataset",
        token=HF_TOKEN,
    )
    print(f'Downloaded file at {file_path} from {DATA_SET_REPO_PATH} / {PERFORMANCE_FILENAME}')
    return pd.read_csv(file_path)


def aggregate_df(df, model_dict, category_name, categories):
    """Average the ``score`` per (model, category) for the models in ``model_dict``.

    Args:
        df: raw score table with ``model``, ``score`` and ``category_name`` columns.
            Rows with a negative score (unparseable ratings) are ignored.
        model_dict: mapping raw model id -> display name. Only these models are
            kept, and their ids are replaced by the display names.
        category_name: column to group by (e.g. ``"category"`` or ``"lang"``).
        categories: category values to report, in output order.

    Returns:
        DataFrame with columns ``model``, ``category_name`` and ``score``.
        Models present in ``df`` appear in reverse ``model_dict`` order. Each
        model has one row per entry of ``categories``; the score is NaN when
        no valid rows exist. An empty (but correctly-columned) frame is
        returned when no target model occurs in ``df``.
    """
    columns = ["model", category_name, "score"]
    # Filter out score-format errors (<1% of cases, encoded as negative scores).
    valid = df[df["score"] >= 0]
    # Presence is judged on the unfiltered table, so a model whose scores are
    # all invalid still yields NaN rows.
    present = set(df["model"].unique())

    rows = []
    for model in reversed(list(model_dict)):
        if model not in present:
            continue
        model_rows = valid[valid["model"] == model]
        for cat in categories:
            score = model_rows.loc[model_rows[category_name] == cat, "score"].mean()
            rows.append({"model": model, category_name: cat, "score": score})

    # Explicit columns keep the schema even when ``rows`` is empty.
    df_score = pd.DataFrame(rows, columns=columns)
    # Rename only the model column. Every id here is a key of model_dict, and
    # the one-shot map cannot chain renames.
    df_score["model"] = df_score["model"].map(model_dict)
    return df_score


def polar_subplot(fig, dframe, model_names, category_label, category_names, row, col, showlegend=True):
    """Add one closed radar (Scatterpolar) trace per model to the subplot at (row, col).

    Args:
        fig: figure from ``make_subplots`` whose (row, col) cell has type ``polar``.
        dframe: aggregated scores with ``model`` (display names),
            ``category_label`` and ``score`` columns, e.g. from ``aggregate_df``.
        model_names: sequence of ``(raw_id, display_name)`` pairs. A model's
            position selects its colour and legend group, so the same model
            looks identical across subplots.
        category_label: column of ``dframe`` holding the angular categories.
        category_names: mapping from category value to its axis label; values
            missing from the mapping are shown unchanged.
        row, col: 1-based subplot position.
        showlegend: whether these traces get legend entries.
    """
    colors = px.colors.qualitative.Plotly
    for i, (_, model_name) in enumerate(model_names):
        model_rows = dframe[dframe['model'] == model_name]
        if model_rows.empty:
            # Nothing to draw for this model here, and an empty trace cannot be closed.
            continue
        theta = [category_names.get(c, c) for c in model_rows[category_label]]
        scores = model_rows['score'].tolist()
        polar = go.Scatterpolar(
            name=model_name,
            # Repeat the first point so the radar polygon is closed.
            r=scores + scores[:1],
            theta=theta + theta[:1],
            legendgroup=f'{i}',
            # Cycle the palette instead of overflowing it when there are many models.
            marker=dict(color=colors[i % len(colors)]),
            hovertemplate="""Score: %{r:.2f}""",
            showlegend=showlegend,
        )
        fig.add_trace(polar, row, col)


def plot_agg_fn():
    """Build the aggregated Sea-Bench figure: scores by category and by language.

    Fetches the score table with ``get_model_df`` and averages it per model with
    ``aggregate_df``. Two side-by-side radar subplots are drawn via
    ``polar_subplot``; both share colours and legend groups, and only the
    left subplot adds legend entries.

    Returns:
        plotly Figure with two polar subplots (radial range 0-10).
    """
    df = get_model_df()
    model_names = list(rename_map.items())

    cat_df = aggregate_df(df, rename_map, "category", CATEGORIES)
    lang_df = aggregate_df(df, rename_map, "lang", LANGS)

    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'polar'}] * 2],
        subplot_titles=("By Category", "By Language"),
    )
    # Place both subplot titles at y=1.05 (paper coordinates).
    for annotation in fig.layout.annotations:
        annotation.y = 1.05

    polar_subplot(fig, cat_df, model_names, 'category', CATEGORIES_NAMES, 1, 1, showlegend=True)
    # Same models/colours as the left subplot; avoid duplicate legend entries.
    polar_subplot(fig, lang_df, model_names, 'lang', LANG_NAMES, 1, 2, showlegend=False)

    polar_config = dict(
        angularaxis=dict(
            rotation=90,  # start position of angular axis
        ),
        radialaxis=dict(
            range=[0, 10],
        ),
    )

    fig.update_layout(
        polar=polar_config,
        polar2=polar_config,
        title='Sea-Bench (rated by GPT-4)',
    )
    return fig


def plot_by_lang_fn():
    """Build the per-language Sea-Bench figure: one radar subplot per language.

    Languages are laid out row-major, three per row, in ``LANGS`` order. Each
    subplot shows the per-category average scores of the models in
    ``rename_map``. The ``safety`` category appears only for the languages
    listed in ``has_safety``. Only the first subplot adds legend entries.

    Returns:
        plotly Figure (radial range 0-10 on every subplot).
    """
    df = get_model_df()
    model_names = list(rename_map.items())

    n_cols = 3
    # Ceiling division so the grid grows with the number of languages.
    n_rows = max(1, -(-len(LANGS) // n_cols))
    fig = make_subplots(
        rows=n_rows, cols=n_cols,
        specs=[[{'type': 'polar'}] * n_cols for _ in range(n_rows)],
        # Titles must follow LANGS (the order subplots are filled in below),
        # not the insertion order of LANG_NAMES, which differs.
        subplot_titles=[LANG_NAMES[lang] for lang in LANGS],
    )
    for ano in fig.layout.annotations:
        ano.y = ano.y + 0.02
    has_safety = ['vi', 'id', 'th']

    for lang_id, lang in enumerate(LANGS):
        cat_names = CATEGORIES if lang in has_safety else [x for x in CATEGORIES if x != 'safety']
        cat_lang_df = aggregate_df(df[df['lang'] == lang], rename_map, "category", cat_names)
        row, col = divmod(lang_id, n_cols)
        polar_subplot(fig, cat_lang_df, model_names, 'category', CATEGORIES_NAMES,
                      row + 1, col + 1, showlegend=lang_id == 0)

    polar_config = dict(
        angularaxis=dict(
            rotation=90,  # start position of angular axis
        ),
        radialaxis=dict(
            range=[0, 10],
        ),
    )
    layer_kwargs = {f"polar{i}": polar_config for i in range(1, n_rows * n_cols + 1)}
    fig.update_layout(
        title='Sea-Bench - By language (rated by GPT-4)',
        # 1000 px for the original 3-row grid, scaled per row.
        height=round(1000 * n_rows / 3),
        **layer_kwargs
    )
    return fig


def both_plot():
    """Return ``(aggregated_figure, per_language_figure)``.

    The tuple order matches the two ``gr.Plot`` outputs wired up in
    ``attach_plot_to_demo``. Each figure builder calls ``get_model_df`` on
    its own.
    """
    aggregated = plot_agg_fn()
    per_language = plot_by_lang_fn()
    return aggregated, per_language


def attach_plot_to_demo(demo: gr.Blocks):
    """Add a collapsed "benchmarks" section with on-demand plots to the current layout.

    The section contains a button and two initially empty plots. Clicking the
    button runs ``both_plot`` and fills the plots, so nothing is downloaded
    at page load.

    Args:
        demo: the enclosing Blocks app. It is not referenced directly; the
            components are created in whatever layout context is active
            (presumably ``demo``'s — confirm at the call site).
    """
    with gr.Accordion("Psst... wanna see some performance benchmarks?", open=False):
        load_button = gr.Button("Load benchmark results")
        agg_plot = gr.Plot(label="Aggregated")
        by_lang_plot = gr.Plot(label='By language')

    load_button.click(both_plot, [], [agg_plot, by_lang_plot])