import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from itertools import combinations
from functools import partial

plt.rcParams.update({'figure.dpi': 100})  # default resolution for every figure created below

from sklearn.datasets import load_iris
from sklearn.ensemble import (
    RandomForestClassifier,
    ExtraTreesClassifier,
    AdaBoostClassifier,
)
from sklearn.tree import DecisionTreeClassifier

import gradio as gr

# ========================================

# One colour per Iris class, in LABELS order: red, yellow, blue.
C1, C2, C3 = '#ff0000', '#ffff00', '#0000ff'
CMAP = ListedColormap([C1, C2, C3])
GRANULARITY = 0.05  # spacing of the prediction mesh used for the decision surface
SEED = 1
N_ESTIMATORS = 30  # initial ensemble size of the estimators below

FEATURES = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
LABELS = ["Setosa", "Versicolour", "Virginica"]

iris = load_iris()

# Estimators, addressed by ``model_idx`` (0 = single tree, 3 = AdaBoost).
MODELS = [
    DecisionTreeClassifier(max_depth=None),
    RandomForestClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    ExtraTreesClassifier(n_estimators=N_ESTIMATORS, n_jobs=-1),
    AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), n_estimators=N_ESTIMATORS),
]

# Display names taken from the estimator classes, so they cannot drift from MODELS.
MODEL_NAMES = [type(estimator).__name__ for estimator in MODELS]

# ========================================

def _as_int_or_none(value):
    """Round a numeric UI value to ``int``; pass ``None`` through unchanged.

    Slider widgets may deliver floats such as ``30.0``; recent scikit-learn
    versions may reject those for integer hyper-parameters.
    """
    return None if value is None else int(round(value))


def _build_model(model_idx, n_estimators, max_depth):
    """Return a fresh, configured copy of ``MODELS[model_idx]``.

    The shared estimator in ``MODELS`` is cloned instead of mutated, so
    concurrent requests cannot overwrite each other's hyper-parameters or
    fitted state.
    """
    from sklearn.base import clone

    params = {'random_state': SEED}
    if model_idx != 0:  # index 0 is the single decision tree: no ensemble size
        params['n_estimators'] = _as_int_or_none(n_estimators)
    if model_idx == 3:  # AdaBoost: the depth applies to its base estimator
        params['estimator__max_depth'] = _as_int_or_none(max_depth)
    else:
        params['max_depth'] = _as_int_or_none(max_depth)
    return clone(MODELS[model_idx]).set_params(**params)


def create_plot(feature_string, n_estimators, max_depth, model_idx):
    """Fit one tree-based model on two Iris features and plot its decision surface.

    Parameters
    ----------
    feature_string : str
        Two names from ``FEATURES`` separated by a comma, e.g.
        ``"Sepal Length, Sepal Width"``; surrounding whitespace is ignored.
    n_estimators : int or float
        Ensemble size; ignored for the single decision tree (``model_idx == 0``).
        Rounded to ``int``.
    max_depth : int, float or None
        Maximum tree depth (for AdaBoost, of its base estimator). Rounded to
        ``int``; ``None`` means unlimited.
    model_idx : int
        Index into ``MODELS`` / ``MODEL_NAMES``.

    Returns
    -------
    matplotlib.figure.Figure
        Filled decision regions, the standardized samples coloured by class,
        and the training accuracy in the title.

    Raises
    ------
    ValueError
        If a name in ``feature_string`` is not one of ``FEATURES``.
    """
    from matplotlib.figure import Figure

    feature_names = [name.strip() for name in feature_string.split(',')]
    idx_x = FEATURES.index(feature_names[0])
    idx_y = FEATURES.index(feature_names[1])

    # Private RNG instead of np.random.seed(): concurrent events would otherwise
    # interleave on the global NumPy state. RandomState(SEED) produces the same
    # permutation that seeding the global generator did.
    rng = np.random.RandomState(SEED)

    X = iris.data[:, [idx_x, idx_y]]
    y = iris.target
    order = rng.permutation(X.shape[0])
    X, y = X[order], y[order]

    # Standardize each feature to zero mean and unit standard deviation.
    X = (X - X.mean(0)) / X.std(0)

    model = _build_model(model_idx, n_estimators, max_depth)
    model.fit(X, y)
    # Accuracy on the same samples the model was trained on.
    score = round(model.score(X, y), 3)

    # Predict on a regular mesh extending one unit beyond the data on each side.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(
        np.arange(x_min, x_max, GRANULARITY),
        np.arange(y_min, y_max, GRANULARITY),
    )
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Build the figure without pyplot: figures made by plt.figure() stay
    # registered with pyplot until closed, so returning one per request leaks.
    fig = Figure(figsize=(4, 3.5), layout='constrained')
    ax = fig.add_subplot(111)

    class_colors = [C1, C2, C3]
    # One filled band per class, with boundaries halfway between labels, so each
    # class keeps its own colour even when only some classes are predicted
    # (e.g. a depth-1 tree). Automatic levels would rescale the colours.
    class_levels = np.arange(len(LABELS) + 1) - 0.5
    ax.contourf(xx, yy, Z, levels=class_levels, colors=class_colors, alpha=0.65)

    for class_idx, (label, color) in enumerate(zip(LABELS, class_colors)):
        mask = y == class_idx
        ax.scatter(X[mask, 0], X[mask, 1], color=color, edgecolor='k', s=40, label=label)

    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[1])
    ax.legend()
    ax.set_title(f'{MODEL_NAMES[model_idx]} | Score: {score}')

    return fig

def iter_grid(n_rows, n_cols):
    """Walk an ``n_rows`` x ``n_cols`` Gradio grid, yielding ``None`` once per cell.

    Each yield happens while a ``gr.Column`` nested in a ``gr.Row`` context is
    open. Components the caller creates in its loop body are therefore placed
    in that cell, in row-major order.
    """
    def _row_cells():
        for _col in range(n_cols):
            with gr.Column():
                yield

    for _row in range(n_rows):
        with gr.Row():
            yield from _row_cells()

# Markdown shown at the top of the controls column of the app.
info = '''
# Plot the decision surfaces of ensembles of trees on the Iris dataset

This plot compares the **decision surfaces** learned by a decision tree classifier, a random forest classifier, an extra-trees classifier, and an AdaBoost classifier.

The Iris dataset has **four features** in total. For visualization purposes, select **two features at a time** with the dropdown box below. All features are normalized to zero mean and unit standard deviation.

Use the sliders to play around with the **number of estimators** in the ensembles and the **max depth** of the trees (the number of estimators has no effect on the single decision tree). The score in each plot title is the model's accuracy on the data it was trained on.

Created by [@huabdul](https://huggingface.co/huabdul) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_iris.html).
'''

# App layout: controls on the left, a 2x2 grid with one plot per model on the
# right. Every plot is redrawn on page load and whenever any control changes.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(info)
            feature_pairs = [f'{first}, {second}' for first, second in combinations(FEATURES, 2)]
            dd = gr.Dropdown(feature_pairs, value=feature_pairs[0], interactive=True, label="Input features")
            slider_estimators = gr.Slider(1, 100, value=30, step=1, label='n_estimators')
            slider_max_depth = gr.Slider(1, 50, value=10, step=1, label='max_depth')

        with gr.Column(scale=2):
            plot_inputs = [dd, slider_estimators, slider_max_depth]
            # The grid cell is pulled first, so iteration stops once either the
            # cells or the models run out.
            for _cell, model_idx in zip(iter_grid(2, 2), range(len(MODELS))):
                plot = gr.Plot(show_label=False)
                redraw = partial(create_plot, model_idx=model_idx)

                for register in (dd.change, slider_estimators.change, slider_max_depth.change, demo.load):
                    register(redraw, inputs=plot_inputs, outputs=[plot])

demo.launch()