import pandas as pd
import plotly.express as px
import requests
import json
import gradio as gr

from src.assets.text_content import SHORT_NAMES, TEXT_NAME, MULTIMODAL_NAME, REGISTRY_URL
from src.leaderboard_utils import get_github_data


def plotly_plot(df: pd.DataFrame, list_op: list, list_co: list,
                show_all: list, show_names: list, show_legend: list,
                mobile_view: list):
    """
    Build the "% Played v/s Quality Score" scatter plot for the selected models.

    Args:
        df: Leaderboard dataframe of the latest version. Column 0 holds model
            names, column 2 is plotted on the x-axis (% played) and column 3 on
            the y-axis (quality score). The caller's dataframe is not modified.
        list_op: The list of open source models to show in the plot, updated from frontend
        list_co: The list of commercial models to show in the plot, updated from frontend
        show_all: Either [] or ["Select All Models"] - toggle view to plot all models
        show_names: Either [] or ["Show Names"] - toggle view to show model names on plot
        show_legend: Either [] or ["Show Legend"] - toggle view to show legend on plot
        mobile_view: Either [] or ["Mobile View"] - toggle view for smaller screens
    Returns:
        Fig: plotly figure of % played v/s quality score
    """
    list_columns = list(df.columns)
    model_col = list_columns[0]
    x_col, y_col = list_columns[2], list_columns[3]

    all_models = list(df[model_col].unique())
    selected_models = all_models if show_all else list_op + list_co

    # Attach short labels row-by-row through a mapping. Assigning a positional
    # list with one label per *unique* model raises a length mismatch as soon as a
    # model occupies more than one row, and it also wrote into the caller's
    # dataframe; `assign` returns a new frame instead.
    short_names = label_map(all_models)
    df = df.assign(Short=df[model_col].map(short_names))
    df = df[df[model_col].isin(selected_models)]

    scatter_kwargs = dict(
        x=x_col,
        y=y_col,
        color=model_col,
        symbol=model_col,
        hover_name=model_col,
        template="plotly_white",
    )
    if show_names:
        scatter_kwargs["text"] = "Short"
    fig = px.scatter(df, **scatter_kwargs)
    if show_names:
        fig.update_traces(textposition='top center')

    if not show_legend:
        fig.update_layout(showlegend=False)

    fig.update_layout(
        xaxis_title='% Played',
        yaxis_title='Quality Score',
        title='Overview of benchmark results',
        height=1000
    )

    # Fix both axes to [-5, 105]: the 0-100 range with a small margin.
    fig.update_xaxes(range=[-5, 105])
    fig.update_yaxes(range=[-5, 105])

    if mobile_view:
        fig.update_layout(height=300)

    if mobile_view and show_legend:
        # Taller figure with the legend placed below the plot (negative paper y);
        # axis titles are dropped and the title shortened to save space.
        fig.update_layout(height=450)
        fig.update_layout(legend=dict(
            yanchor="bottom",
            y=-5.52,
            xanchor="left",
            x=0.01
        ))

        fig.update_layout(
            xaxis_title="",
            yaxis_title="",
            title="% Played v/s Quality Score"
        )

    return fig


def shorten_model_name(full_name):
    """
    Derive a compact plot label for a model that has no curated short name.

    The name is split on '-'. If it has a single part, or none of its parts
    contains a digit (e.g. "command-r"), the first three characters are used.
    Otherwise the label is the first character of the name followed by the
    digit-bearing parts, e.g. "gpt-4o-2024-05-13" -> "g-4o-2024-05-13".

    Args:
        full_name: The full model name.
    Returns:
        The shortened name.
    """
    parts = full_name.split('-')

    # Keep only the parts containing a digit (typically sizes, versions or dates).
    digit_parts = [part for part in parts if any(char.isdigit() for char in part)]

    if len(parts) == 1 or not digit_parts:
        # Without digit-bearing parts the "<first char>-<parts>" form would end in
        # a dangling hyphen ("command-r" -> "c-"); use a 3-character prefix instead.
        return full_name[:3]

    # Parts come from splitting on '-', so they never carry leading/trailing hyphens.
    return full_name[0] + '-' + '-'.join(digit_parts)


def label_map(model_list: list) -> dict:
    """
    Map each full model name to the short label drawn on the frontend plot.

    Curated entries in SHORT_NAMES (defined in src/assets/text_content.py) take
    precedence; every other name is abbreviated with shorten_model_name.

    Args:
        model_list: Full model names.
    Returns:
        A dict from each full name to its short name.
    """
    return {
        name: (SHORT_NAMES[name] if name in SHORT_NAMES else shorten_model_name(name))
        for name in model_list
    }


def _load_registry(timeout: float) -> list:
    """
    Fetch and decode the model registry JSON from REGISTRY_URL.

    This is best-effort. On a network error, a non-200 status or an undecodable
    body, the problem is printed and an empty list is returned. Callers then see
    no registry entries. The decoded value is presumably a list of entry dicts
    (split_models indexes "model_name" / "open_weight"); TODO confirm the schema.
    """
    try:
        response = requests.get(REGISTRY_URL, timeout=timeout)
    except requests.RequestException as err:
        print(f"Failed to read JSON file: {REGISTRY_URL} Error : {err}")
        return []

    if response.status_code != 200:
        print(f"Failed to read JSON file: {REGISTRY_URL} Status Code : {response.status_code}")
        return []

    try:
        return json.loads(response.text)
    except ValueError as err:  # json.JSONDecodeError subclasses ValueError
        print(f"Failed to parse JSON file: {REGISTRY_URL} Error : {err}")
        return []


def split_models(model_list: list, timeout: float = 10):
    """
    Split the models into open source and commercial using the model registry.

    Each model is looked up by "model_name" in the registry fetched from
    REGISTRY_URL. A truthy "open_weight" flag puts it in the open list; otherwise
    it goes to the commercial list. Models absent from the registry are dropped,
    except "dolphin-2.5-mixtral-8x7b", which is treated as open. If the registry
    cannot be loaded, a message is printed and only that fallback can appear.

    Args:
        model_list: Model names to classify.
        timeout: Seconds to wait for the registry request before giving up.
    Returns:
        Tuple (open_models, commercial_models), each sorted case-insensitively.
    """
    registry = _load_registry(timeout)

    # Index the registry once instead of rescanning it for every model;
    # setdefault keeps the first entry per name, like a scan-and-break would.
    entry_by_name = {}
    for entry in registry:
        name = entry.get("model_name")
        if name is not None:
            entry_by_name.setdefault(name, entry)

    open_models = []
    commercial_models = []
    for model_name in model_list:
        entry = entry_by_name.get(model_name)
        if entry is None:
            continue  # not listed in the registry
        if entry["open_weight"]:
            open_models.append(model_name)
        else:
            commercial_models.append(model_name)

    # This model is missing from the model registry; treat it as open-weight.
    # Add it only if the registry did not classify it (avoids a duplicate), and
    # before sorting so the returned list stays ordered.
    missing_model = "dolphin-2.5-mixtral-8x7b"
    if missing_model in model_list and missing_model not in entry_by_name:
        open_models.append(missing_model)

    open_models.sort(key=lambda name: name.upper())
    commercial_models.sort(key=lambda name: name.upper())

    return open_models, commercial_models

"""
Update functions: rebuild the open/commercial model checkbox groups when the leaderboard selection changes.
"""
def update_open_models(leaderboard: str = TEXT_NAME):
    """
    Change the checkbox group of Open Models based on the leaderboard selected.

    The model names come from column 0 of the leaderboard dataframe returned by
    get_plot_df. They are classified with split_models, which fetches the model
    registry over the network.

    Args:
        leaderboard: Selected leaderboard from the frontend [Default - Text Leaderboard]
    Return:
        Updated, unchecked checkbox group for Open Models, based on the leaderboard selected
    """
    # Reuse the shared leaderboard lookup rather than duplicating the key selection.
    models = get_plot_df(leaderboard).iloc[:, 0].unique().tolist()
    open_models, _ = split_models(models)
    return gr.CheckboxGroup(
        open_models,
        value=[],
        elem_id="value-select-1",
        interactive=True,
    )

def update_closed_models(leaderboard: str = TEXT_NAME):
    """
    Change the checkbox group of Closed (commercial) Models based on the leaderboard selected.

    The model names come from column 0 of the leaderboard dataframe returned by
    get_plot_df. They are classified with split_models, which fetches the model
    registry over the network.

    Args:
        leaderboard: Selected leaderboard from the frontend [Default - Text Leaderboard]
    Return:
        Updated, unchecked checkbox group for Closed Models, based on the leaderboard selected
    """
    # Reuse the shared leaderboard lookup rather than duplicating the key selection.
    models = get_plot_df(leaderboard).iloc[:, 0].unique().tolist()
    _, commercial_models = split_models(models)
    return gr.CheckboxGroup(
        commercial_models,
        value=[],
        elem_id="value-select-2",
        interactive=True,
    )

def get_plot_df(leaderboard: str = TEXT_NAME) -> pd.DataFrame:
    """
    Return the first dataframe of the chosen leaderboard from the GitHub data.

    Any value other than TEXT_NAME selects the "multimodal" section.

    Args:
        leaderboard: Selected leaderboard name.
    Returns:
        The leaderboard's first dataframe (model data used for plotting).
    """
    source = get_github_data()
    section = "text" if leaderboard == TEXT_NAME else "multimodal"
    frames = source[section]['dataframes']
    return frames[0]


"""
Reset functions: restore the plot-option checkboxes to their unchecked defaults when the leaderboard selection changes.
"""
def reset_show_all():
    """Return a fresh, unchecked "Select All Models" checkbox group."""
    options = ["Select All Models"]
    return gr.CheckboxGroup(
        choices=options,
        value=[],
        label="Show plot for all models 🤖",
        interactive=True,
        elem_id="value-select-3",
    )

def reset_show_names():
    """Return a fresh, unchecked "Show Names" checkbox group."""
    options = ["Show Names"]
    return gr.CheckboxGroup(
        choices=options,
        value=[],
        label="Show names of models on the plot 🏷️",
        interactive=True,
        elem_id="value-select-4",
    )


def reset_show_legend():
    """Return a fresh, unchecked "Show Legend" checkbox group."""
    options = ["Show Legend"]
    return gr.CheckboxGroup(
        choices=options,
        value=[],
        label="Show legend on the plot 💡",
        interactive=True,
        elem_id="value-select-5",
    )


def reset_mobile_view():
    """Return a fresh, unchecked "Mobile View" checkbox group."""
    options = ["Mobile View"]
    return gr.CheckboxGroup(
        choices=options,
        value=[],
        label="View plot on smaller screens 📱",
        interactive=True,
        elem_id="value-select-6",
    )


if __name__ == '__main__':
    # Manual smoke test: classify sample multimodal and text model lists against
    # the live model registry and print the open/closed split for each.
    mm_model_list = ['gpt-4o-2024-05-13', 'gpt-4-1106-vision-preview', 'claude-3-opus-20240229', 'gemini-1.5-pro-latest',
                     'gemini-1.5-flash-latest', 'llava-v1.6-34b-hf', 'llava-v1.6-vicuna-13b-hf', 'idefics-80b-instruct',
                     'llava-1.5-13b-hf', 'idefics-9b-instruct']

    text_model_list = ['vicuna-33b-v1.3', 'gpt-4-0125-preview', 'gpt-4-turbo-2024-04-09', 'claude-3-5-sonnet-20240620', 'gpt-4-1106-preview',
                       'gpt-4-0613', 'gpt-4o-2024-05-13', 'claude-3-opus-20240229', 'gemini-1.5-pro-latest',
                       'Meta-Llama-3-70B-Instruct-hf', 'claude-2.1', 'gemini-1.5-flash-latest', 'claude-3-sonnet-20240229',
                       'Qwen1.5-72B-Chat', 'mistral-large-2402', 'gpt-3.5-turbo-0125', 'gemini-1.0-pro', 'command-r-plus', 'openchat_3.5',
                       'claude-3-haiku-20240307', 'sheep-duck-llama-2-70b-v1.1', 'Meta-Llama-3-8B-Instruct-hf', 'openchat-3.5-1210',
                       'WizardLM-70b-v1.0', 'openchat-3.5-0106', 'Qwen1.5-14B-Chat', 'mistral-medium-2312', 'Qwen1.5-32B-Chat',
                       'codegemma-7b-it', 'dolphin-2.5-mixtral-8x7b', 'CodeLlama-34b-Instruct-hf', 'command-r', 'gemma-1.1-7b-it',
                       'SUS-Chat-34B', 'Mixtral-8x22B-Instruct-v0.1', 'tulu-2-dpo-70b', 'Nous-Hermes-2-Mixtral-8x7B-SFT',
                       'WizardLM-13b-v1.2', 'Mistral-7B-Instruct-v0.2', 'Yi-34B-Chat', 'Mixtral-8x7B-Instruct-v0.1',
                       'Mistral-7B-Instruct-v0.1', 'Yi-1.5-34B-Chat', 'vicuna-13b-v1.5', 'Yi-1.5-6B-Chat', 'Starling-LM-7B-beta',
                       'sheep-duck-llama-2-13b', 'Yi-1.5-9B-Chat', 'gemma-1.1-2b-it', 'Qwen1.5-7B-Chat', 'gemma-7b-it',
                       'llama-2-70b-chat-hf', 'Qwen1.5-0.5B-Chat', 'Qwen1.5-1.8B-Chat']

    # The text list was previously declared but never used; exercise both.
    for list_name, model_list in (("Multimodal", mm_model_list), ("Text", text_model_list)):
        om, cm = split_models(model_list)
        print(f"{list_name} - Open")
        print(om)
        print(f"{list_name} - Closed")
        print(cm)