"""Gradio demo for different clustering techiniques

Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html

"""

import math
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import (
    AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth
)
from sklearn.datasets import make_blobs, make_circles, make_moons
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler


# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8*" and
# removed the old names in 3.8, so prefer the new name whenever it exists and
# only fall back to the legacy name on older matplotlib versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')


# Seed shared by the dataset generators and the seeded models
SEED = 0
# Largest number of clusters that can be requested
MAX_CLUSTERS = 10
# Number of points in every generated dataset
N_SAMPLES = 1000
# Number of plot columns in the result grid
N_COLS = 3
# does not affect size in webpage
FIGSIZE = (7, 7)
# One color per cluster
COLORS = [
    'blue',
    'orange',
    'green',
    'red',
    'purple',
    'brown',
    'pink',
    'gray',
    'olive',
    'cyan',
]
assert MAX_CLUSTERS <= len(COLORS), "Not enough different colors for all clusters"
np.random.seed(SEED)


def normalize(X):
    """Return ``X`` rescaled by a freshly fitted ``StandardScaler``."""
    scaler = StandardScaler()
    return scaler.fit_transform(X)


def get_regular(n_clusters):
    """Generate ``n_clusters`` equally sized blobs with fixed centers.

    Args:
        n_clusters: number of blobs; must not exceed the ten predefined
            centers, otherwise the assertion below fails.

    Returns:
        Tuple ``(X, labels)`` of the normalized points and the blob index of
        each point.
    """
    # candidate centers are ordered in a spiral pattern around the origin
    spiral_centers = [
        [0, 0], [1, 0], [1, 1], [0, 1], [-1, 1],
        [-1, 0], [-1, -1], [0, -1], [1, -1], [2, -1],
    ]
    centers = spiral_centers[:n_clusters]
    assert len(centers) == n_clusters
    points, blob_ids = make_blobs(
        n_samples=N_SAMPLES,
        centers=centers,
        cluster_std=0.25,
        random_state=SEED,
    )
    return normalize(points), blob_ids


def get_circles(n_clusters):
    """Generate the ``make_circles`` dataset (factor 0.5, noise 0.05).

    ``n_clusters`` is not used; it is accepted so that all dataset generators
    share the same signature.

    Returns:
        Tuple ``(X, labels)`` of the normalized points and the labels produced
        by ``make_circles``.
    """
    points, circle_ids = make_circles(
        n_samples=N_SAMPLES,
        factor=0.5,
        noise=0.05,
        random_state=SEED,
    )
    return normalize(points), circle_ids


def get_moons(n_clusters):
    """Generate the ``make_moons`` dataset (noise 0.05).

    ``n_clusters`` is not used; it is accepted so that all dataset generators
    share the same signature.

    Returns:
        Tuple ``(X, labels)`` of the normalized points and the labels produced
        by ``make_moons``.
    """
    points, moon_ids = make_moons(
        n_samples=N_SAMPLES,
        noise=0.05,
        random_state=SEED,
    )
    return normalize(points), moon_ids


def get_noise(n_clusters):
    """Generate structureless data: uniform random points with random labels.

    The global NumPy random state is re-seeded with ``SEED`` on every call so
    that the same points and labels are produced each time.

    Args:
        n_clusters: exclusive upper bound of the random integer labels.

    Returns:
        Tuple ``(X, labels)`` of the normalized points and the random labels.
    """
    np.random.seed(SEED)
    # draw the points first and the labels second to keep the random stream
    # consumption order fixed
    points = np.random.rand(N_SAMPLES, 2)
    random_ids = np.random.randint(0, n_clusters, size=(N_SAMPLES,))
    return normalize(points), random_ids


def get_anisotropic(n_clusters):
    """Generate blobs that are stretched by a fixed linear transformation.

    Args:
        n_clusters: number of blobs to generate.

    Returns:
        Tuple ``(X, labels)`` of the normalized, transformed points and the
        blob index of each point.
    """
    X, labels = make_blobs(n_samples=N_SAMPLES, centers=n_clusters, random_state=170)
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X = np.dot(X, transformation)
    # Normalize like every other dataset generator: the models are configured
    # with fixed, scale-dependent parameters (e.g. DBSCAN's eps).
    return normalize(X), labels


def get_varied(n_clusters):
    """Generate blobs whose standard deviations differ from blob to blob.

    Args:
        n_clusters: number of blobs; must not exceed the ten predefined
            standard deviations, otherwise the assertion below fails.

    Returns:
        Tuple ``(X, labels)`` of the normalized points and the blob index of
        each point.
    """
    all_stds = [1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0, 2.5, 0.5, 1.0]
    cluster_std = all_stds[:n_clusters]
    assert len(cluster_std) == n_clusters
    points, blob_ids = make_blobs(
        n_samples=N_SAMPLES,
        centers=n_clusters,
        cluster_std=cluster_std,
        random_state=SEED,
    )
    return normalize(points), blob_ids


def get_spiral(n_clusters):
    """Generate noisy points along a spiral.

    ``n_clusters`` is not used; every point receives label 0. The global NumPy
    random state is re-seeded with ``SEED`` on every call.

    Returns:
        Tuple ``(X, labels)`` of the normalized points and an all-zero integer
        label array.
    """
    # from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering.html
    np.random.seed(SEED)
    angle = 1.5 * np.pi * (1 + 3 * np.random.rand(1, N_SAMPLES))
    # stack the x and y coordinates as the two rows of a (2, N_SAMPLES) array
    coords = np.concatenate((angle * np.cos(angle), angle * np.sin(angle)))
    coords = coords + 0.7 * np.random.randn(2, N_SAMPLES)
    points = np.ascontiguousarray(coords.T)

    return normalize(points), np.zeros(N_SAMPLES, dtype=int)


# Dataset name -> generator function; each generator takes the number of
# clusters and returns a tuple (X, labels). Insertion order is preserved.
DATA_MAPPING = dict(
    regular=get_regular,
    circles=get_circles,
    moons=get_moons,
    spiral=get_spiral,
    noise=get_noise,
    anisotropic=get_anisotropic,
    varied=get_varied,
)


def get_groundtruth_model(X, labels, n_clusters, **kwargs):
    """Return a stand-in "model" that exposes the true labels as ``labels_``.

    It has the same signature as the real model factories so it can be used
    interchangeably with them; ``X``, ``n_clusters`` and ``kwargs`` are
    ignored.

    Args:
        X: unused.
        labels: true labels, stored unchanged on the returned object.
        n_clusters: unused.
        **kwargs: unused.

    Returns:
        An object whose ``labels_`` attribute is ``labels``.
    """
    # dummy model to show true label distribution
    class Dummy:
        def __init__(self, y):
            # use the constructor argument rather than the enclosing scope
            self.labels_ = y

    return Dummy(labels)


def get_kmeans(X, labels, n_clusters, **kwargs):
    """Fit a seeded ``KMeans`` model with ``n_clusters`` clusters on ``X``.

    Args:
        X: data to cluster.
        labels: unused.
        n_clusters: number of clusters.
        **kwargs: extra estimator parameters applied through ``set_params``.

    Returns:
        The fitted estimator.
    """
    estimator = KMeans(
        n_clusters=n_clusters,
        init="k-means++",
        n_init=10,
        random_state=SEED,
    )
    return estimator.set_params(**kwargs).fit(X)


def get_dbscan(X, labels, n_clusters, **kwargs):
    """Fit a ``DBSCAN`` model with ``eps=0.3`` on ``X``.

    Args:
        X: data to cluster.
        labels: unused.
        n_clusters: unused.
        **kwargs: extra estimator parameters applied through ``set_params``.

    Returns:
        The fitted estimator.
    """
    return DBSCAN(eps=0.3).set_params(**kwargs).fit(X)


def get_agglomerative(X, labels, n_clusters, **kwargs):
    """Fit a ward-linkage ``AgglomerativeClustering`` model on ``X``.

    The merges are constrained by a k-nearest-neighbors graph of ``X`` that
    uses ``n_clusters`` neighbors per point.

    Args:
        X: data to cluster.
        labels: unused.
        n_clusters: number of clusters, also used as the neighbor count.
        **kwargs: extra estimator parameters applied through ``set_params``.

    Returns:
        The fitted estimator.
    """
    knn_graph = kneighbors_graph(X, n_neighbors=n_clusters, include_self=False)
    # average the graph with its transpose to make the connectivity symmetric
    symmetric_graph = 0.5 * (knn_graph + knn_graph.T)
    estimator = AgglomerativeClustering(
        connectivity=symmetric_graph,
        linkage="ward",
        n_clusters=n_clusters,
    )
    estimator.set_params(**kwargs)
    return estimator.fit(X)


def get_meanshift(X, labels, n_clusters, **kwargs):
    """Fit a ``MeanShift`` model on ``X``.

    The bandwidth comes from ``estimate_bandwidth`` with ``quantile=0.25``.

    Args:
        X: data to cluster.
        labels: unused.
        n_clusters: unused.
        **kwargs: extra estimator parameters applied through ``set_params``.

    Returns:
        The fitted estimator.
    """
    estimator = MeanShift(
        bandwidth=estimate_bandwidth(X, quantile=0.25),
        bin_seeding=True,
    )
    return estimator.set_params(**kwargs).fit(X)


def get_spectral(X, labels, n_clusters, **kwargs):
    """Fit a seeded nearest-neighbors ``SpectralClustering`` model on ``X``.

    Args:
        X: data to cluster.
        labels: unused.
        n_clusters: number of clusters.
        **kwargs: extra estimator parameters applied through ``set_params``;
            they may override ``random_state``.

    Returns:
        The fitted estimator.
    """
    model = SpectralClustering(
        n_clusters=n_clusters,
        eigen_solver="arpack",
        affinity="nearest_neighbors",
        # seed the eigen solver and label assignment so repeated calls give
        # the same result, like the other seeded models in this file
        random_state=SEED,
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_optics(X, labels, n_clusters, **kwargs):
    """Fit an ``OPTICS`` model with fixed hyper-parameters on ``X``.

    Args:
        X: data to cluster.
        labels: unused.
        n_clusters: unused.
        **kwargs: extra estimator parameters applied through ``set_params``.

    Returns:
        The fitted estimator.
    """
    estimator = OPTICS(min_samples=7, xi=0.05, min_cluster_size=0.1)
    return estimator.set_params(**kwargs).fit(X)


def get_birch(X, labels, n_clusters, **kwargs):
    """Fit a ``Birch`` model with ``n_clusters`` clusters on ``X``.

    Args:
        X: data to cluster.
        labels: unused.
        n_clusters: number of clusters.
        **kwargs: extra estimator parameters applied through ``set_params``.

    Returns:
        The fitted estimator.
    """
    return Birch(n_clusters=n_clusters).set_params(**kwargs).fit(X)


def get_gaussianmixture(X, labels, n_clusters, **kwargs):
    """Fit a seeded full-covariance ``GaussianMixture`` on ``X``.

    Args:
        X: data to fit.
        labels: unused.
        n_clusters: number of mixture components.
        **kwargs: extra estimator parameters applied through ``set_params``.

    Returns:
        The fitted estimator.
    """
    estimator = GaussianMixture(
        covariance_type="full",
        n_components=n_clusters,
        random_state=SEED,
    )
    return estimator.set_params(**kwargs).fit(X)


# Display name -> model factory, in display order. Every factory is called as
# factory(X, labels, n_clusters=...) and returns a fitted (or dummy) model.
_MODEL_FACTORIES = (
    ('True labels', get_groundtruth_model),
    ('KMeans', get_kmeans),
    ('DBSCAN', get_dbscan),
    ('MeanShift', get_meanshift),
    ('SpectralClustering', get_spectral),
    ('OPTICS', get_optics),
    ('Birch', get_birch),
    ('GaussianMixture', get_gaussianmixture),
    ('AgglomerativeClustering', get_agglomerative),
)
MODEL_MAPPING = dict(_MODEL_FACTORIES)


def plot_clusters(ax, X, labels):
    """Scatter-plot ``X`` on ``ax`` with one color per cluster label.

    Args:
        ax: matplotlib axes to draw on.
        X: 2-D array; columns 0 and 1 are used as x and y coordinates.
        labels: array-like with one cluster label per row of ``X``; the label
            -1 marks outliers.

    Returns:
        The axes that were passed in.
    """
    labels = np.asarray(labels)
    is_outlier = labels == -1  # -1 signifies outliers, which we plot separately

    # np.unique yields the distinct labels in sorted order. Colors wrap around
    # so that clusters beyond len(COLORS) are still drawn instead of being
    # silently dropped (density-based models pick their own cluster count).
    for position, label in enumerate(np.unique(labels[~is_outlier])):
        members = labels == label
        ax.scatter(X[members, 0], X[members, 1], color=COLORS[position % len(COLORS)])

    # show outliers (if any)
    if is_outlier.any():
        ax.scatter(X[is_outlier, 0], X[is_outlier, 1], c='k', marker='x')

    # grid(None) would toggle the grid; switch it off explicitly instead
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax


def cluster(dataset: str, n_clusters: int, clustering_algorithm: str):
    """Generate a dataset, cluster it and return a figure of the result.

    Args:
        dataset: key into ``DATA_MAPPING`` selecting the dataset generator.
        n_clusters: requested number of clusters, either as a number or as a
            dict holding the number under the key ``'value'``.
        clustering_algorithm: key into ``MODEL_MAPPING`` selecting the model;
            also used as the plot title.

    Returns:
        The matplotlib figure showing the predicted labels.
    """
    if isinstance(n_clusters, dict):
        n_clusters = n_clusters['value']
    # always coerce to int, whichever form the value arrived in
    n_clusters = int(n_clusters)

    X, labels = DATA_MAPPING[dataset](n_clusters)
    model = MODEL_MAPPING[clustering_algorithm](X, labels, n_clusters=n_clusters)
    if hasattr(model, "labels_"):
        y_pred = model.labels_.astype(int)
    else:
        # models without a labels_ attribute are asked to predict instead
        y_pred = model.predict(X)

    fig, ax = plt.subplots(figsize=FIGSIZE)

    plot_clusters(ax, X, y_pred)
    ax.set_title(clustering_algorithm, fontsize=16)

    # Unregister the figure from pyplot so figures do not pile up over the
    # lifetime of the app; the returned Figure object itself stays usable.
    plt.close(fig)
    return fig


# Page title and introductory text shown at the top of the demo
title = "Clustering with Scikit-learn"
description = (
    "This example shows how different clustering algorithms work. Simply pick "
    "the dataset and the number of clusters to see how the clustering algorithms work. "
    "Colored circles are (predicted) labels and black x are outliers."
)


def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` grid of gradio containers.

    Every yield happens while a ``gr.Column`` nested inside a ``gr.Row`` is
    open, so whatever the caller creates while the generator is suspended is
    created inside those two context managers. The yielded value is ``None``.
    """
    # create a grid using gradio Block
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield


with gr.Blocks(title=title) as demo:
    gr.HTML(f"<b>{title}</b>")
    gr.Markdown(description)

    # shared inputs: dataset selector and requested number of clusters
    input_models = list(MODEL_MAPPING)
    input_data = gr.Radio(list(DATA_MAPPING), value="regular", label="dataset")
    input_n_clusters = gr.Slider(
        minimum=1, maximum=MAX_CLUSTERS, value=4, step=1, label='Number of clusters'
    )

    # one plot per model, laid out in a grid with N_COLS columns
    n_rows = math.ceil(len(input_models) / N_COLS)
    for position, _ in enumerate(iter_grid(n_rows, N_COLS)):
        if position >= len(input_models):
            # the grid may have more cells than there are models
            break

        input_model = input_models[position]
        plot = gr.Plot(label=input_model)
        fn = partial(cluster, clustering_algorithm=input_model)
        # redraw this model's plot whenever either input changes
        for trigger in (input_data, input_n_clusters):
            trigger.change(fn=fn, inputs=[input_data, input_n_clusters], outputs=plot)


demo.launch()