"""
==========================================================
Gradio demo to plot multi-class SGD on the iris dataset
==========================================================

Plot the decision surface of a multi-class SGD classifier on the iris dataset.
The hyperplanes corresponding to the three one-versus-all (OVA) classifiers
are drawn as dashed lines.

Created by Syed Affan <saffand03@gmail.com>

"""
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.linear_model import SGDClassifier
from sklearn.inspection import DecisionBoundaryDisplay
import matplotlib.cm

def make_plot(alpha, max_iter, Standardize):
    """Fit an SGD classifier on two iris features and plot its decision surface.

    Only the first two feature columns of the iris data are used so that the
    decision surface can be drawn in 2-D. The one-versus-all (OVA) hyperplanes
    learned by the classifier are overlaid as dashed lines.

    Args:
        alpha: Regularization strength forwarded to ``SGDClassifier``.
        max_iter: Maximum number of passes over the training data. Coerced to
            ``int`` because UI widgets may deliver a float, while
            ``SGDClassifier`` validates this parameter as an integer.
        Standardize: If truthy, each feature is centred to zero mean and
            scaled to unit variance before fitting.

    Returns:
        A ``(figure, markdown)`` pair: the matplotlib ``Figure`` holding the
        plot, and a Markdown string reporting the accuracy on the whole
        (training) dataset.
    """
    # Build the figure outside pyplot: figures created with plt.figure() stay
    # registered in pyplot's global state until explicitly closed, so a
    # long-running server would accumulate one per request.
    from matplotlib.figure import Figure

    iris = datasets.load_iris()
    # Keep only the first two feature columns so the surface is 2-D.
    X = iris.data[:, :2]
    y = iris.target
    colors = "bry"

    # A private RNG instead of np.random.seed(): NumPy's global RNG is shared
    # process-wide. The same RandomState is handed to SGDClassifier below so
    # the fit stays deterministic, drawing from the same seeded stream.
    rng = np.random.RandomState(13)
    order = np.arange(X.shape[0])
    rng.shuffle(order)
    X = X[order]
    y = y[order]

    if Standardize:
        X = (X - X.mean(axis=0)) / X.std(axis=0)

    clf = SGDClassifier(
        alpha=alpha, max_iter=int(max_iter), random_state=rng
    ).fit(X, y)
    accuracy = clf.score(X, y)
    acc = f'### The Accuracy on the entire dataset: {accuracy}'

    fig = Figure()
    ax = fig.add_subplot()
    DecisionBoundaryDisplay.from_estimator(
        clf,
        X,
        cmap=matplotlib.cm.Paired,
        ax=ax,
        response_method="predict",
        xlabel=iris.feature_names[0],
        ylabel=iris.feature_names[1],
    )
    ax.axis("tight")

    # Overlay the training points, one colour per class. `c` is a literal
    # colour, so no colormap is involved (a cmap here would be ignored).
    for label, color in zip(clf.classes_, colors):
        mask = y == label
        ax.scatter(
            X[mask, 0],
            X[mask, 1],
            c=color,
            label=iris.target_names[label],
            edgecolor="black",
            s=20,
        )
    ax.set_title("Decision surface of multi-class SGD")
    ax.axis("tight")

    # Draw each OVA hyperplane w0*x + w1*y + b = 0 across the visible x-range.
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    for (w0, w1), b, color in zip(clf.coef_, clf.intercept_, colors):
        if w1 != 0:
            def line(x0):
                return (-(x0 * w0) - b) / w1

            ax.plot([xmin, xmax], [line(xmin), line(xmax)], ls="--", color=color)
        elif w0 != 0:
            # Zero y-weight: the hyperplane is vertical at x = -b / w0.
            ax.axvline(-b / w0, ls="--", color=color)
        # If both weights are zero there is no hyperplane to draw.

    # Hyperplane endpoints can lie far outside the data (e.g. when w1 is tiny),
    # which would make autoscaling zoom out; keep the view on the data.
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.legend()

    return fig, acc

# Page title for the Gradio app, and the Markdown description shown in it.
title = "Plot multi-class SGD on the iris dataset"

# Plain string (no interpolation needed). Hyperparameters are a Markdown list
# so each renders on its own line rather than merging into one paragraph.
model_card = """
## Description
This interactive demo is based on the [Plot multi-class SGD on the iris dataset](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_iris.html#sphx-glr-auto-examples-linear-model-plot-sgd-iris-py) example from the popular [scikit-learn](https://scikit-learn.org/stable/) library, which is a widely-used library for machine learning in Python.
This demo plots the decision surface of multi-class SGD on the iris dataset. The hyperplanes corresponding to the three one-versus-all (OVA) classifiers are represented by the dashed lines.

You can play with the following hyperparameters:

- `alpha` is a constant that multiplies the regularization term. The higher the value, the stronger the regularization.
- `max_iter` is the maximum number of passes over the training data (aka epochs).
- `Standardize` centers each feature to zero mean and scales it to unit variance before fitting.

## Dataset
[Iris Dataset](https://en.wikipedia.org/wiki/Iris_flower_data_set)

## Model
`SGDClassifier` from [sklearn.linear_model](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model) is the estimator used in this example.

"""

# Build the Gradio page: description, three hyperparameter inputs, a submit
# button, and two outputs (the plot and the accuracy Markdown) that receive
# the (figure, markdown) pair returned by make_plot.
with gr.Blocks(title=title) as demo:
    gr.Markdown('''
            <div>
            <h1 style='text-align: center'>Plot multi-class SGD on iris dataset</h1>
            </div>
        ''')

    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/sulpha\">sulpha</a>")
    d0 = gr.Slider(0.001, 5, step=0.001, value=0.001, label='alpha')
    # Integer steps starting at 1 so the default of 100 lies on the slider's
    # grid; the previous step of 10 from 1 only allowed 1, 11, ..., 101, ...
    d1 = gr.Slider(1, 1001, step=1, value=100, label='max_iter')
    d2 = gr.Checkbox(value=True, label='Standardize')

    btn = gr.Button(value='Submit')
    # Output components are created in the same order as before (plot, then
    # Markdown), so the page layout is unchanged.
    plot_output = gr.Plot()
    accuracy_output = gr.Markdown()
    btn.click(make_plot, inputs=[d0, d1, d2], outputs=[plot_output, accuracy_output])

demo.launch()