from typing import List

import gradio as gr
import pandas as pd

from src.leaderboard import get_leaderboard_df
from src.llm_perf import get_llm_perf_df
# from attention_implementations import get_attn_decode_fig, get_attn_prefill_fig
# from custom_kernels import get_kernel_decode_fig, get_kernel_prefill_fig


def create_control_panel(
    machine: str,
    backends: List[str],
    hardware_provider: str,
    hardware_type: str,
):
    """Build the collapsible "Control Panel" and return its components.

    The fixed arguments (``machine``, ``backends``, ``hardware_type``) are
    wrapped in ``gr.State`` so they can be passed as inputs to callbacks.
    The ``backends`` State holds the caller-supplied list unchanged.

    Args:
        machine: Machine identifier stored in a ``gr.State``.
        backends: Backend names stored in a ``gr.State``.
        hardware_provider: Selects the quantization choices; only ``"ARM"``
            is supported.
        hardware_type: Hardware type stored in a ``gr.State``.

    Returns:
        A tuple ``(filter_button, machine_value, backends_value,
        hardware_type_value, memory_slider, quantization_checkboxes)``.

    Raises:
        ValueError: If ``hardware_provider`` is not supported.
    """
    # Validate the provider before creating any Gradio component, so an
    # unsupported provider fails fast without creating stray State objects.
    if hardware_provider == "ARM":
        quantizations = ["Q8_0", "Q4_K_M", "Q4_0_4_4"]
    else:
        raise ValueError(f"Unknown hardware provider: {hardware_provider}")

    # fixed values, forwarded to callbacks as inputs
    machine_value = gr.State(value=machine)
    backends_value = gr.State(value=backends)
    hardware_type_value = gr.State(value=hardware_type)

    # collapsed by default
    with gr.Accordion("Control Panel", open=False, elem_id="control-panel"):
        with gr.Row():
            with gr.Column(scale=2, variant="panel"):
                # upper bound on the "Model Size (GB)" column; starts at the maximum
                memory_slider = gr.Slider(
                    label="Model Size (GB)",
                    info="🎚️ Slide to maximum Model Size",
                    minimum=0,
                    maximum=16,
                    value=16,
                    elem_id="memory-slider",
                )
            with gr.Column(scale=1, variant="panel"):
                # every quantization scheme is selected initially
                quantization_checkboxes = gr.CheckboxGroup(
                    label="Quantizations",
                    choices=quantizations,
                    value=quantizations,
                    info="☑️ Select the quantization schemes",
                    elem_id="quantization-checkboxes",
                    elem_classes="boxed-option",
                )
        with gr.Row():
            filter_button = gr.Button(
                value="Filter 🚀",
                elem_id="filter-button",
                elem_classes="boxed-option",
            )

    return (
        filter_button,
        machine_value,
        backends_value,
        hardware_type_value,
        memory_slider,
        quantization_checkboxes,
    )

def filter_rows_fn(
    machine,
    backends,
    hardware_type,
    # inputs
    memory,
    quantizations,
    # interactive
    columns,
    search,
):
    """Filter benchmark rows by the control-panel inputs and build the table.

    Args:
        machine, backends, hardware_type: Fixed values forwarded to
            ``get_llm_perf_df`` to load the benchmark data.
        memory: Keep rows whose "Model Size (GB)" is <= this value.
        quantizations: Keep rows whose "Quantization" is in this collection.
        columns: Columns to display, forwarded to ``select_columns_fn``.
        search: Case-insensitive literal substring matched against "Model".

    Returns:
        A one-element list holding the filtered, column-selected DataFrame.
    """
    llm_perf_df = get_llm_perf_df(
        machine=machine, backends=backends, hardware_type=hardware_type
    )
    # The size comparison must be parenthesized: `&` binds tighter than `<=`,
    # so without parentheses the whole mask would be compared to `memory`.
    # regex=False: the search text is user input and may contain regex
    # metacharacters; na=False keeps missing model names out of the mask.
    mask = (
        llm_perf_df["Model"].str.contains(
            search, case=False, regex=False, na=False
        )
        & llm_perf_df["Quantization"].isin(quantizations)
        & (llm_perf_df["Model Size (GB)"] <= memory)
    )
    filtered_llm_perf_df = llm_perf_df[mask]
    selected_filtered_llm_perf_df = select_columns_fn(
        machine, backends, hardware_type, columns, search, filtered_llm_perf_df
    )
    # NOTE: figure outputs (prefill/decode plots) are currently disabled; any
    # re-enabled figure must be appended here in the same order as the
    # outputs registered for the filter button.
    return [selected_filtered_llm_perf_df]


def create_control_callback(
    # button
    filter_button,
    # fixed
    machine_value,
    backends_value,
    hardware_type_value,
    # inputs
    memory_slider,
    quantization_checkboxes,
    # interactive
    columns_checkboxes,
    search_bar,
    # outputs
    leaderboard_table,
    # attn_prefill_plot,
    # attn_decode_plot,
    # fa2_prefill_plot,
    # fa2_decode_plot,
    # quant_prefill_plot,
    # quant_decode_plot,
):
    """Register ``filter_rows_fn`` as the click handler of ``filter_button``.

    The concatenated input list must follow the positional parameter order of
    ``filter_rows_fn``: fixed states, then panel inputs, then interactive
    widgets. The only output is the leaderboard table.
    """
    fixed_states = [machine_value, backends_value, hardware_type_value]
    panel_inputs = [memory_slider, quantization_checkboxes]
    interactive_inputs = [columns_checkboxes, search_bar]

    filter_button.click(
        fn=filter_rows_fn,
        inputs=fixed_states + panel_inputs + interactive_inputs,
        outputs=[leaderboard_table],
    )


def select_columns_fn(
    machine, backends, hardware_type, columns, search, llm_perf_df=None
):
    """Return the leaderboard table restricted to matching rows and ``columns``.

    Args:
        machine, backends, hardware_type: Forwarded to ``get_llm_perf_df``
            when ``llm_perf_df`` is not supplied.
        columns: Column labels to keep.
        search: Case-insensitive literal substring matched against "Model".
        llm_perf_df: Optional pre-loaded (possibly pre-filtered) benchmark
            DataFrame; loaded via ``get_llm_perf_df`` when None.

    Returns:
        The DataFrame built by ``get_leaderboard_df`` with non-matching rows
        dropped and only ``columns`` kept.
    """
    if llm_perf_df is None:
        llm_perf_df = get_llm_perf_df(
            machine=machine,
            backends=backends,
            hardware_type=hardware_type,
        )

    selected_leaderboard_df = get_leaderboard_df(llm_perf_df)
    # Literal match: search text comes from a text box, and characters such
    # as "(" or "+" would raise re.error if interpreted as a regex. na=False
    # excludes rows with a missing model name instead of breaking the mask.
    matches = selected_leaderboard_df["Model"].str.contains(
        search, case=False, regex=False, na=False
    )
    selected_leaderboard_df = selected_leaderboard_df[matches]
    selected_leaderboard_df = selected_leaderboard_df[columns]

    return selected_leaderboard_df


def create_select_callback(
    # fixed
    machine_value,
    backends_value,
    hardware_type_value,
    # interactive
    columns_checkboxes,
    search_bar,
    # outputs
    leaderboard_table,
):
    """Re-render the leaderboard when the column selection or search text changes.

    Both widgets trigger ``select_columns_fn`` with the same inputs. The inputs
    follow its positional parameter order: fixed states, then the column
    selection, then the search text.
    """
    for trigger in (columns_checkboxes, search_bar):
        trigger.change(
            fn=select_columns_fn,
            inputs=[
                machine_value,
                backends_value,
                hardware_type_value,
                columns_checkboxes,
                search_bar,
            ],
            outputs=[leaderboard_table],
        )