from functools import partial
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import gradio as gr
from typing import Dict, List

from src.logic.data_processing import PARTITION_OPTIONS, prepare_for_non_grouped_plotting, prepare_for_group_plotting
from src.logic.graph_settings import Grouping
from src.logic.utils import set_alpha
from datatrove.utils.stats import MetricStatsDict

def plot_scatter(
        data: Dict[str, MetricStatsDict],
        metric_name: str,
        log_scale_x: bool,
        log_scale_y: bool,
        normalization: bool,
        rounding: int,
        cumsum: bool,
        perc: bool,
        progress: gr.Progress,
):
    """Draw one line trace per run from that run's histogram.

    Args:
        data: Mapping from run name to that run's histogram statistics.
        metric_name: Label used for the figure title and the x-axis title.
        log_scale_x: Request a log x-axis. It is only applied when at least one
            trace has more than one point.
        log_scale_y: Request a log y-axis (same condition as ``log_scale_x``).
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``. It also
            selects the y-axis title ("Frequency" vs "Total").
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: Plot the running (cumulative) sum of the y values.
        perc: Multiply the y values by 100.
        progress: Gradio progress tracker wrapped around the per-run loop.

    Returns:
        A ``go.Figure`` with one trace per run, ordered by run name. An empty
        ``data`` yields an empty figure with the same layout.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly
    # Sort by run name only, so trace order and colour assignment are stable.
    data = dict(sorted(data.items(), key=lambda item: item[0]))

    # The axes are shared by every trace, so the log-scale guard must consider
    # all of them. The guard must also not rely on loop variables, which are
    # never bound when ``data`` is empty.
    max_points = 0
    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()
        max_points = max(max_points, len(x))

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"

    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        xaxis_type="log" if log_scale_x and max_points > 1 else None,
        yaxis_type="log" if log_scale_y and max_points > 1 else None,
        width=1200,
        height=600,
        showlegend=True,
    )

    return fig

def plot_bars(
        data: Dict[str, MetricStatsDict],
        metric_name: str,
        top_k: int,
        direction: PARTITION_OPTIONS,
        regex: str | None,
        rounding: int,
        log_scale_x: bool,
        log_scale_y: bool,
        show_stds: bool,
        progress: gr.Progress,
):
    """Draw one bar trace per run from its grouped statistics.

    Args:
        data: Mapping from run name to that run's statistics.
        metric_name: Label used for the figure title and the x-axis title.
        top_k: Forwarded to ``prepare_for_group_plotting``.
        direction: Forwarded to ``prepare_for_group_plotting``.
        regex: Forwarded to ``prepare_for_group_plotting``.
        rounding: Forwarded to ``prepare_for_group_plotting``.
        log_scale_x: Request a log x-axis. It is only applied when at least one
            trace has more than one bar.
        log_scale_y: Request a log y-axis (same condition as ``log_scale_x``).
        show_stds: Whether the error bars (built from the stds returned by
            ``prepare_for_group_plotting``) are visible.
        progress: Gradio progress tracker wrapped around the per-run loop.

    Returns:
        A ``go.Figure`` with one bar trace per run, in ``data`` iteration order.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly
    # The axes are shared by all traces, so base the log-scale guard on the
    # largest trace rather than on whichever trace happened to be last.
    max_points = 0

    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        x, y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
        max_points = max(max_points, len(x))

        fig.add_trace(go.Bar(
            x=x,
            y=y,
            name=f"{name} Mean",
            marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            error_y=dict(type='data', array=stds, visible=show_stds)
        ))

    fig.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type="log" if log_scale_x and max_points > 1 else None,
        yaxis_type="log" if log_scale_y and max_points > 1 else None,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )

    return fig


# Add any other necessary functions

def plot_data(
    metric_data: Dict[str, MetricStatsDict],
    metric_name: str,
    normalize: bool,
    rounding: int,
    grouping: Grouping,
    top_n: int,
    direction: PARTITION_OPTIONS,
    group_regex: str,
    log_scale_x: bool,
    log_scale_y: bool,
    cdf: bool,
    perc: bool,
    show_stds: bool,
) -> tuple[go.Figure, gr.Row, str]:
    """Build the figure for one metric, choosing the plot type from ``grouping``.

    For ``grouping == "histogram"``, per-run line plots are drawn via
    ``plot_scatter``, and a markdown min/max table is produced by
    ``generate_min_max_hist_data``. For any other grouping, bar plots are drawn
    via ``plot_bars`` and the table text is empty.

    Returns:
        ``(figure, row_update, table_markdown)``. ``row_update`` is
        ``gr.Row.update(visible=True)`` in both cases.
    """
    if grouping != "histogram":
        # Grouped statistics: bar chart, no min/max summary.
        bar_figure = plot_bars(
            data=metric_data,
            metric_name=metric_name,
            top_k=top_n,
            direction=direction,
            regex=group_regex,
            rounding=rounding,
            log_scale_x=log_scale_x,
            log_scale_y=log_scale_y,
            show_stds=show_stds,
            progress=gr.Progress(),
        )
        return bar_figure, gr.Row.update(visible=True), ""

    line_figure = plot_scatter(
        data=metric_data,
        metric_name=metric_name,
        log_scale_x=log_scale_x,
        log_scale_y=log_scale_y,
        normalization=normalize,
        rounding=rounding,
        cumsum=cdf,
        perc=perc,
        progress=gr.Progress(),
    )
    summary_table = generate_min_max_hist_data(metric_data)
    return line_figure, gr.Row.update(visible=True), summary_table

def generate_min_max_hist_data(data: Dict[str, "MetricStatsDict"]) -> str:
    """Render a markdown table with the smallest and largest histogram key per run.

    Each run's keys are converted with ``float`` and formatted to 4 decimals.
    A run whose histogram has no keys is shown as ``N/A`` rather than raising.
    Pipe characters in run names are escaped so that they do not split the
    table's columns.

    Args:
        data: Mapping from run name to that run's histogram. Only its keys are
            read.

    Returns:
        The markdown table: a header row plus separator, then one row per run
        in ``data`` iteration order. With no runs, only the header is returned.
    """
    header = "| Run | Min | Max |\n|-----|-----|-----|\n"
    rows = []
    for run, histogram in data.items():
        # Convert each key once instead of separately for min and max.
        values = [float(key) for key in histogram.keys()]
        run_cell = str(run).replace("|", "\\|")
        if values:
            rows.append(f"| {run_cell} | {min(values):.4f} | {max(values):.4f} |")
        else:
            # min()/max() of an empty sequence would raise ValueError.
            rows.append(f"| {run_cell} | N/A | N/A |")
    return header + "\n".join(rows)