import pandas as pd
import streamlit as st
import config
from pathlib import Path as P
import json


# Results files (names starting with "nbow") under assets/. Materialized as a
# sorted list rather than a bare glob generator so it can be iterated more than
# once and the files are always displayed in a stable, deterministic order.
nbow_results_path = sorted(P("assets").glob("nbow*"))

def display_metrics_dict(metrics, display_only_accuracy):
    """Render one model's test metrics on the Streamlit page.

    Args:
        metrics: Metrics dict for one model. Must contain ``"model_name"`` and
            ``"columns"`` (a ``"_"``-joined string of column names). When
            ``display_only_accuracy`` is true it must also contain
            ``metrics["accuracy@k"]["10"]``.
        display_only_accuracy: If true, show only accuracy@10; otherwise show
            every metric except ``model_name`` and ``columns`` as JSON.

    The caller's ``metrics`` dict is left unmodified.
    """
    # Shallow copy so popping the header fields doesn't mutate the caller's dict.
    metrics = dict(metrics)
    model_name = metrics.pop("model_name")
    columns = metrics.pop("columns").split("_")
    st.markdown(f"### columns: {columns}")
    st.markdown(f"best model {model_name}")
    if display_only_accuracy:
        st.json({"accuracy@10": metrics["accuracy@k"]["10"]})
    else:
        st.json(metrics)

def display_metrics():
    """Render the test metrics stored in every file of ``nbow_results_path``.

    Adds a sidebar checkbox (checked by default) that limits the output to
    accuracy@10. Each file is parsed as JSON and passed to
    ``display_metrics_dict``.
    """
    display_only_accuracy = st.sidebar.checkbox("display only accuracy@10", value=True)
    st.markdown("## Test metrics for best validation model on given columns")
    for results_file in nbow_results_path:
        # The context manager guarantees the handle is closed after parsing.
        with open(results_file, "r", encoding="utf-8") as f:
            metrics = json.load(f)
        display_metrics_dict(metrics, display_only_accuracy)

# Per-model metrics section is rendered first, at the top of the page.
display_metrics()

# Per-query results for the best- and worst-scoring queries, read from the
# CSV paths configured in the project's config module.
best_results_df, worst_results_df = (
    pd.read_csv(csv_path)
    for csv_path in (config.best_tasks_path, config.worst_tasks_path)
)

# Sidebar toggles for the two optional sections rendered below.
show_worst_best_statistics = st.sidebar.checkbox(
    label="show worst/best statistics grouped by area",
)
show_area_aggregated_results = st.sidebar.checkbox(
    label="show results aggregated by area",
)
if show_worst_best_statistics:
    st.markdown(
        """
    ## Worst/best queries
    The following are top 10 worst/best queries per area by number of hits.
    There are at least 10 documents per query in the test set, so number of hits/10 is the accuracy.
    """
    )
    # Explicit key: another "sort by" selectbox exists in the area-aggregated
    # section. Without keys, Streamlit identifies widgets by their parameters,
    # so the two would collide whenever their option lists happen to match.
    sort_key = st.selectbox(
        "sort by", list(best_results_df.columns), key="sort_by_queries"
    )
    st.markdown("## Queries with best results")
    st.table(best_results_df.sort_values(sort_key, ascending=False))
    st.markdown("## Queries with worst results")
    st.table(worst_results_df.sort_values(sort_key, ascending=False))

if show_area_aggregated_results:
    st.markdown("## Area aggregated results")
    # Average every numeric column per area. numeric_only=True skips any
    # non-numeric columns; without it pandas >= 2.0 raises TypeError on them.
    best_results_agg = (
        best_results_df.groupby("area").mean(numeric_only=True).reset_index()
    )
    worst_results_agg = (
        worst_results_df.groupby("area").mean(numeric_only=True).reset_index()
    )
    # Explicit key keeps this widget distinct from the other "sort by" selectbox.
    sort_key = st.selectbox(
        "sort by", list(best_results_agg.columns), key="sort_by_area"
    )
    st.markdown("Best results")
    st.table(best_results_agg.sort_values(sort_key, ascending=False))
    st.markdown("Worst results")
    st.table(worst_results_agg.sort_values(sort_key, ascending=False))