|
import gradio as gr
import hdbscan
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skfuzzy as fuzz
from matplotlib.figure import Figure
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.cluster import (
    AgglomerativeClustering,
    Birch,
    DBSCAN,
    KMeans,
    MeanShift,
)
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
|
|
|
# Algorithms whose result depends on the user-selected number of clusters.
_COUNT_BASED_ALGORITHMS = frozenset(
    {
        "KMeans",
        "Fuzzy C-Means (FCM)",
        "BIRCH",
        "Gaussian Mixture Models (GMM)",
    }
)
_AHC_ALGORITHM = "Agglomerative Hierarchical Clustering (AHC)"
_SUPPORTED_ALGORITHMS = _COUNT_BASED_ALGORITHMS | {
    _AHC_ALGORITHM,
    "DBSCAN",
    "HDBSCAN",
    "Mean-Shift",
}


def _load_numeric_matrix(dataset):
    """Read the uploaded CSV and return its numeric, NaN-free rows as a 2-D float array."""
    # Accept both an upload object exposing ``.name`` and a plain path string.
    csv_path = getattr(dataset, "name", dataset)
    frame = pd.read_csv(csv_path)
    numeric = frame.select_dtypes(include=[np.number])
    # Drop columns that are entirely empty first so they do not wipe out every
    # row, then drop the rows that still contain a missing value.
    numeric = numeric.dropna(axis=1, how="all").dropna()
    return numeric.to_numpy(dtype=float)


def _fit_labels(algorithm, n_clusters, matrix):
    """Fit the selected (non-AHC) algorithm on *matrix* and return one label per row."""
    if algorithm == "KMeans":
        return KMeans(n_clusters=n_clusters, random_state=42).fit_predict(matrix)
    if algorithm == "Fuzzy C-Means (FCM)":
        # cmeans takes the data transposed (features x samples); the second
        # return value is the membership matrix (clusters x samples), so the
        # hard label of a sample is the row index of its largest membership.
        _, membership, *_ = fuzz.cmeans(
            matrix.T, n_clusters, 2, error=0.005, maxiter=1000
        )
        return np.argmax(membership, axis=0)
    if algorithm == "BIRCH":
        return Birch(n_clusters=n_clusters).fit_predict(matrix)
    if algorithm == "DBSCAN":
        return DBSCAN(eps=0.5, min_samples=5).fit_predict(matrix)
    if algorithm == "HDBSCAN":
        return hdbscan.HDBSCAN(min_samples=5).fit_predict(matrix)
    if algorithm == "Mean-Shift":
        return MeanShift().fit_predict(matrix)
    if algorithm == "Gaussian Mixture Models (GMM)":
        mixture = GaussianMixture(n_components=n_clusters)
        mixture.fit(matrix)
        return mixture.predict(matrix)
    raise ValueError(f"Unsupported algorithm: {algorithm!r}")


def _save_dendrogram(matrix, path="dendrogram.png"):
    """Draw the Ward-linkage dendrogram of *matrix* and save it to *path*."""
    # An explicit Figure (instead of pyplot's global state) keeps concurrent
    # web requests from drawing into each other's figures.
    fig = Figure(figsize=(10, 7))
    ax = fig.subplots()
    dendrogram(linkage(matrix, "ward"), ax=ax)
    ax.set_title("Dendrogram for Agglomerative Clustering")
    ax.set_xlabel("Sample index")
    ax.set_ylabel("Distance")
    fig.tight_layout()
    fig.savefig(path)
    return path


def _save_cluster_scatter(points, labels, algorithm, path="clusters_plot.png"):
    """Save a scatter plot of the 2-D *points* coloured by *labels* to *path*."""
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    scatter = ax.scatter(points[:, 0], points[:, 1], c=labels, cmap="viridis", s=50)
    fig.colorbar(scatter, ax=ax, label="Cluster Label")
    ax.set_title(f"Clusters Visualization ({algorithm})")
    ax.set_xlabel("PCA Component 1")
    ax.set_ylabel("PCA Component 2")
    fig.tight_layout()
    fig.savefig(path)
    return path


def apply_clustering(algorithm, n_clusters, dataset):
    """Cluster an uploaded CSV dataset and render a picture of the result.

    Args:
        algorithm: Display name of the clustering algorithm, as offered by the
            UI dropdown.
        n_clusters: Requested number of clusters. Only used by KMeans, Fuzzy
            C-Means, BIRCH and Gaussian Mixture Models.
        dataset: Uploaded file object exposing ``.name`` (a path), or a path
            string. ``None`` when nothing has been uploaded.

    Returns:
        A ``(message, image_path)`` tuple. On success ``image_path`` is the
        saved PNG ("dendrogram.png" for agglomerative clustering, otherwise
        "clusters_plot.png"). When the request cannot be served,
        ``image_path`` is ``None`` and ``message`` explains why.
    """
    # Validate the cheap things first so bad requests never touch the disk.
    if algorithm not in _SUPPORTED_ALGORITHMS:
        return "Algorithm not supported yet.", None
    if dataset is None:
        return "Please upload a CSV dataset first.", None

    try:
        matrix = _load_numeric_matrix(dataset)
    except (OSError, ValueError) as exc:
        # pandas' EmptyDataError / ParserError and decoding errors are all
        # ValueError subclasses; OSError covers unreadable paths.
        return f"Could not read the dataset as CSV: {exc}", None

    n_samples, n_features = matrix.shape
    if n_samples == 0 or n_features == 0:
        return "The dataset contains no usable numeric data.", None
    if n_samples < 2:
        return "The dataset needs at least two rows of numeric data.", None

    if algorithm == _AHC_ALGORITHM:
        # This branch only reports the hierarchy as a dendrogram; flat labels
        # are not displayed, so no flat clustering model is fitted.
        image_path = _save_dendrogram(matrix)
        return "Agglomerative clustering applied successfully.", image_path

    if n_features < 2:
        return (
            "The dataset needs at least two numeric columns for the 2-D "
            "PCA visualization.",
            None,
        )

    # The slider may deliver a float (e.g. 3.0); the estimators expect an int.
    n_clusters = int(n_clusters)
    if algorithm in _COUNT_BASED_ALGORITHMS and n_clusters > n_samples:
        return (
            f"Cannot form {n_clusters} clusters from only {n_samples} "
            "rows of numeric data.",
            None,
        )

    labels = _fit_labels(algorithm, n_clusters, matrix)

    # Project to two principal components purely for display purposes; the
    # clustering itself runs on the full feature matrix.
    projected = PCA(n_components=2).fit_transform(matrix)
    image_path = _save_cluster_scatter(projected, labels, algorithm)

    return f"{algorithm} clustering applied successfully.", image_path
|
|
|
|
|
def main_interface():
    """Assemble the Gradio front end for ``apply_clustering`` and launch it.

    The form offers a CSV upload, an algorithm dropdown and a cluster-count
    slider; the outputs are a status text box and an image.
    """
    supported_algorithms = [
        "KMeans",
        "Fuzzy C-Means (FCM)",
        "Agglomerative Hierarchical Clustering (AHC)",
        "BIRCH",
        "DBSCAN",
        "HDBSCAN",
        "Mean-Shift",
        "Gaussian Mixture Models (GMM)",
    ]

    # Input widgets.
    csv_upload = gr.File(label="Upload Dataset (CSV format)")
    algorithm_picker = gr.Dropdown(
        choices=supported_algorithms,
        label="Select Algorithm",
    )
    cluster_count = gr.Slider(
        minimum=2,
        maximum=10,
        value=3,
        step=1,
        label="Number of Clusters (for applicable algorithms)",
    )

    # Output widgets.
    result_box = gr.Textbox(label="Result")
    result_image = gr.Image(label="Cluster Visualization")

    # The input order must match apply_clustering's parameter order.
    app = gr.Interface(
        fn=apply_clustering,
        inputs=[algorithm_picker, cluster_count, csv_upload],
        outputs=[result_box, result_image],
    )
    app.launch()
|
|
|
|
|
# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    main_interface()
|
|
|
|
|
|