import glob
import html
from pathlib import Path

import pandas as pd
import streamlit as st

# Languages whose results are measured with CER (Char Error Rate) instead of
# WER (Word Error Rate); get_note and get_note_custom choose their metric
# wording by checking membership in this list.
cer_langs = ["ja", "zh-CN", "zh-HK", "zh-TW"]


def get_note(lang):
    """Return the markdown note naming the evaluation metric for ``lang``.

    Languages listed in ``cer_langs`` use CER, every other language uses WER.
    The text ends with a lead-in ("The score used for ranking is: ") meant to
    be followed by the rendered ranking formula.
    """
    if lang in cer_langs:
        metric = "`CER` (Char Error Rate, lower is better)"
    else:
        metric = "`WER` (Word Error Rate, lower is better)"
    intro = f"**Note**: The evaluation metric for the `{lang}` language is {metric}. "
    return intro + "The score used for ranking is: "


def get_note_custom(lang):
    """Return the markdown note for a language whose metrics are self-reported.

    Names the metric (CER for languages in ``cer_langs``, WER otherwise) and
    explains that no ranking score is computed for this language.
    """
    if lang in cer_langs:
        metric = "CER (Char Error Rate)"
    else:
        metric = "WER (Word Error Rate)"
    sentences = [
        f"**Note**: The evaluation metric for the `{lang}` language is {metric}.",
        "The metrics for this language are **self-reported** by the participants and are dataset-dependent, "
        "so there's no explicit ranking score.",
        "Please refer to the model cards for more info about the datasets.",
    ]
    return " ".join(sentences)


def make_clickable(model_name):
    """Render ``model_name`` as an HTML link to its Hugging Face Hub page.

    The returned markup ends up in a table rendered with ``escape=False`` and
    ``unsafe_allow_html=True``. Both the visible text and the ``href`` are
    HTML-escaped here, so characters such as ``<``, ``>``, ``&`` or ``"`` in a
    name cannot inject markup or break out of the attribute. Ordinary
    ``user/model`` names are unchanged by the escaping.
    """
    link = html.escape("https://huggingface.co/" + model_name, quote=True)
    text = html.escape(model_name, quote=True)
    return f'<a target="_blank" href="{link}">{text}</a>'


# LaTeX label and Common Voice release number for each per-release test column.
_CV_RELEASES = {
    "cv6": ("CV61", "6.1"),
    "cv7": ("CV7", "7.0"),
    "cv8": ("CV8", "8.0"),
}


def _score_equation(cv_cols, with_hf):
    """Return the LaTeX formula describing how ``parse_df`` computes ``Score``.

    ``cv_cols`` lists the Common Voice test columns (one ``cases`` row each);
    ``with_hf`` adds the HF dev/test terms to every row. With no CV columns the
    HF-only average is returned.
    """
    if not cv_cols:
        return r"Score = \frac{HF_{Dev} + HF_{Test}}{2}"
    rows = []
    for col in cv_cols:
        label, release = _CV_RELEASES[col]
        term = label + r"_{Test}"
        if with_hf:
            term = "(" + term + r" + HF_{Dev} + HF_{Test}) / 3"
        rows.append("  " + term + r" & \text{if trained on $Common\ Voice\ " + release + r"$} \\")
    return "Score = \n\\begin{cases}\n" + "\n".join(rows) + "\n\\end{cases}"


def parse_df(df, lang):
    """Turn one language's raw results into a display-ready leaderboard.

    The columns present decide the layout:

    * Common Voice test columns (``cv6``/``cv7``/``cv8``) and/or HF columns
      (``hf_dev``/``hf_test``): a ``Score`` column is computed, rows are sorted
      by it in ascending order, and the LaTeX ranking formula is returned.
    * Otherwise a ``custom`` column: rows are sorted by it, and no formula is
      returned because those numbers are self-reported.

    Args:
        df: Raw results for one language with a ``model`` column. It is not
            modified.
        lang: Language code, used for the metric wording of the note.

    Returns:
        ``(table, note, metric_eq)``. ``table`` is the renamed table with HTML
        model links and NaN cells blanked to ``""``. ``note`` is the markdown
        note, or ``None`` when no known column layout matched. ``metric_eq`` is
        the LaTeX score formula, or ``None`` for the custom layout or when
        nothing matched.
    """
    df = df.copy()  # never mutate the caller's frame
    note = None
    metric_eq = None  # stays None for the custom layout / unknown layouts

    has_hf = "hf_dev" in df.columns
    # A cv6 column implies the CV 6.1/7.0/8.0 layout; otherwise cv8 implies 7.0/8.0.
    if "cv6" in df.columns:
        cv_cols = ["cv6", "cv7", "cv8"]
    elif "cv8" in df.columns:
        cv_cols = ["cv7", "cv8"]
    else:
        cv_cols = []

    if cv_cols or has_hf:
        hf_cols = ["hf_dev", "hf_test"] if has_hf else []
        # The row-wise max skips NaN, so a model that reports a single CV release
        # is scored on that release. NOTE(review): a model reporting several
        # releases is scored on its worst (highest) error rate; confirm intended.
        if cv_cols and has_hf:
            df["avg"] = (df[cv_cols].max(axis=1) + df["hf_dev"] + df["hf_test"]) / 3
        elif cv_cols:
            df["avg"] = df[cv_cols].max(axis=1)
        else:
            df["avg"] = (df["hf_dev"] + df["hf_test"]) / 2
        df = df.sort_values("avg", ignore_index=True)[["model", *cv_cols, *hf_cols, "avg"]]
        note = get_note(lang)
        metric_eq = _score_equation(cv_cols, has_hf)
    elif "custom" in df.columns:
        df = df[["model", "custom"]].sort_values("custom", ignore_index=True)
        note = get_note_custom(lang)

    # Build new frames rather than assigning into a column selection, which
    # pandas may treat as a view (SettingWithCopyWarning / chained assignment).
    df = df.assign(model=df["model"].apply(make_clickable))
    df = df.rename(
        columns={
            "model": "Model",
            "cv6": "CV 6.1 Test",
            "cv7": "CV 7.0 Test",
            "cv8": "CV 8.0 Test",
            "hf_dev": "HF Dev",
            "hf_test": "HF Test",
            "custom": "Custom Test",
            "avg": "Score",
        }
    )
    # Blank out missing results so they render as empty cells.
    return df.fillna(""), note, metric_eq


@st.cache()
def main():
    """Load every ``data/*.csv`` results file and prepare it for display.

    Each file is named after its language code (e.g. ``data/ja.csv``) and is
    processed by ``parse_df``.

    Returns:
        Three dicts keyed by language code: display tables, markdown notes,
        and LaTeX score formulas.
    """
    # NOTE(review): newer Streamlit releases deprecate ``st.cache`` in favour of
    # ``st.cache_data``; consider switching once the pinned version allows it.
    dataframes = {}
    notes = {}
    metric_eqs = {}
    for lang_csv in sorted(glob.glob("data/*.csv")):
        # Path.stem is separator-agnostic. Splitting on "/" breaks on Windows,
        # where glob returns "data\\<lang>.csv", and yields a bogus language key.
        lang = Path(lang_csv).stem
        df = pd.read_csv(lang_csv)
        dataframes[lang], notes[lang], metric_eqs[lang] = parse_df(df, lang)

    return dataframes, notes, metric_eqs


dataframes, notes, eval_eqs = main()

_, col_center = st.columns([3, 6])
with col_center:
    st.image("logo.png", width=200)
st.markdown("# Robust Speech Challenge Results")

lang_select = sorted(dataframes.keys())
if not lang_select:
    # With no options st.selectbox returns None and the lookups below would
    # raise KeyError, so stop rendering with an explanatory message instead.
    st.warning("No results found in `data/*.csv`.")
    st.stop()

lang = st.selectbox(
    "Language",
    lang_select,
    index=0,
)

# The note or formula may be missing, e.g. there is no ranking formula for
# languages with self-reported metrics.
if notes[lang]:
    st.markdown(notes[lang])
if eval_eqs[lang]:
    st.latex(eval_eqs[lang])
# Model cells already contain HTML links (make_clickable), so render unescaped.
st.write(dataframes[lang].to_html(escape=False, index=False), unsafe_allow_html=True)