# Hugging Face Spaces page header (status: Sleeping) - web-scrape residue, not part of the program.
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from sklearn.cluster import MeanShift, estimate_bandwidth | |
from sklearn.datasets import make_blobs | |
def get_clusters_plot(n_blobs, quantile, cluster_std):
    """Cluster freshly generated blobs with Mean Shift and plot the result.

    Draws 10,000 samples with ``make_blobs``, estimates a bandwidth from a
    500-sample subset with ``estimate_bandwidth``, fits ``MeanShift`` and
    scatters each estimated cluster together with its center.

    Args:
        n_blobs: Number of blob centers to generate (the "true" cluster count).
        quantile: Quantile forwarded to ``estimate_bandwidth``.
        cluster_std: Standard deviation of every generated blob.

    Returns:
        A ``(figure, message)`` tuple: the Matplotlib figure and an HTML
        paragraph stating whether the estimated number of clusters matches
        ``n_blobs``. When the estimated bandwidth is not positive, the figure
        shows the raw samples and the message asks for a larger quantile.
    """
    # make_blobs only treats `centers` as a count when it is an integer, so a
    # float-valued UI input such as 3.0 is normalized first.
    n_blobs = int(n_blobs)

    X, _ = make_blobs(n_samples=10000, cluster_std=cluster_std, centers=n_blobs)
    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=500)

    # Draw on an explicit Axes rather than pyplot's implicit "current figure",
    # which is global state shared by concurrently running callbacks.
    fig, ax = plt.subplots()
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")

    if bandwidth <= 0:
        # A quantile of 0 makes every point its own nearest neighbour, so the
        # estimate collapses to 0 and MeanShift refuses to fit with it.
        ax.scatter(X[:, 0], X[:, 1])
        ax.set_title("Bandwidth estimate is zero")
        message = (
            '<p style="text-align: center;">'
            "The estimated bandwidth is zero, so Mean Shift cannot be fitted."
            " Try increasing the `Quantile` parameter.</p>"
        )
        plt.close(fig)
        return fig, message

    ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_
    n_clusters_ = len(np.unique(labels))

    for k, cluster_center in zip(range(n_clusters_), cluster_centers):
        my_members = labels == k
        ax.scatter(X[my_members, 0], X[my_members, 1])
        ax.plot(
            cluster_center[0],
            cluster_center[1],
            "x",
            markeredgecolor="k",
            markersize=14,
        )

    ax.set_title(f"Estimated number of clusters: {n_clusters_}")

    if n_blobs != n_clusters_:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" differs from the true number of clusters ({n_blobs})."
            + " Try changing the `Quantile` parameter.</p>"
        )
    else:
        message = (
            '<p style="text-align: center;">'
            + f"The number of estimated clusters ({n_clusters_})"
            + f" matches the true number of clusters ({n_blobs})!</p>"
        )

    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the Figure object itself is still returned for rendering.
    plt.close(fig)
    return fig, message
# Build the demo page: three sliders on the left, the cluster plot and the
# status message on the right. Any slider change (and the initial page load)
# re-runs get_clusters_plot with the current slider values.
with gr.Blocks() as demo:
    # The Markdown text is kept flush-left so it renders the same whether or
    # not the component dedents its value.
    gr.Markdown(
        """
# Mean Shift Clustering

This space shows how to use the [Mean Shift Clustering](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html) algorithm to cluster 2D data points. You can change the parameters using the sliders and see how the model performs.

This space is based on [sklearn's original demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py).
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            n_blobs = gr.Slider(
                minimum=2,
                maximum=10,
                label="Number of clusters in the data",
                step=1,
                value=3,
            )
            quantile = gr.Slider(
                # A quantile of 0 produces a zero bandwidth estimate, which
                # MeanShift rejects, so the smallest selectable value is one
                # slider step above zero.
                minimum=0.05,
                maximum=1,
                step=0.05,
                value=0.2,
                label="Quantile",
                info="Used to determine clustering's bandwidth.",
            )
            cluster_std = gr.Slider(
                minimum=0.1,
                maximum=1,
                label="Clusters' standard deviation",
                step=0.1,
                value=0.6,
            )
        with gr.Column(scale=4):
            clusters_plots = gr.Plot(label="Clusters' Plot")
            message = gr.HTML()

    # Every trigger shares the same callback, inputs and outputs, so they are
    # registered in one loop (inside the Blocks context, where event
    # listeners have to be attached).
    inputs = [n_blobs, quantile, cluster_std]
    outputs = [clusters_plots, message]
    for register in (
        n_blobs.change,
        quantile.change,
        cluster_std.change,
        demo.load,
    ):
        register(get_clusters_plot, inputs, outputs, queue=False)

if __name__ == "__main__":
    demo.launch()