Sujatha's picture
Update app.py
a720b40 verified
import gradio as gr
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN, Birch, MeanShift
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
from scipy.cluster.hierarchy import linkage, dendrogram
import matplotlib.pyplot as plt
import skfuzzy as fuzz # For Fuzzy C-Means
import hdbscan
# Algorithm names accepted by apply_clustering; they must match the
# dropdown choices offered by the Gradio interface.
_SUPPORTED_ALGORITHMS = (
    "KMeans",
    "Fuzzy C-Means (FCM)",
    "Agglomerative Hierarchical Clustering (AHC)",
    "BIRCH",
    "DBSCAN",
    "HDBSCAN",
    "Mean-Shift",
    "Gaussian Mixture Models (GMM)",
)


def _load_numeric_matrix(dataset):
    """Read the uploaded CSV and return its numeric columns as a float array.

    ``dataset`` may be an object exposing ``.name`` or a plain path string,
    so this works whichever form the upload component delivers. Non-numeric
    columns are dropped, and so are rows with missing values, because PCA and
    the clustering estimators cannot handle strings or NaN.
    """
    path = getattr(dataset, "name", dataset)
    frame = pd.read_csv(path)
    numeric = frame.select_dtypes(include="number").dropna()
    return numeric.to_numpy(dtype=float)


def _project_to_2d(data_matrix):
    """Return a two-column projection of *data_matrix* for the scatter plot.

    Uses PCA when there are at least two features. A single feature cannot be
    reduced to two PCA components, so it is plotted along the x-axis with the
    y coordinate fixed at 0.
    """
    if data_matrix.shape[1] >= 2:
        return PCA(n_components=2).fit_transform(data_matrix)
    return np.column_stack([data_matrix[:, 0], np.zeros(data_matrix.shape[0])])


def _fit_labels(algorithm, n_clusters, data_matrix):
    """Fit the named flat-clustering *algorithm* and return one label per row.

    ``n_clusters`` is only used by KMeans, FCM, BIRCH and GMM; DBSCAN,
    HDBSCAN and Mean-Shift choose the number of clusters themselves.
    Fixed seeds (42) keep results reproducible between identical requests.
    """
    if algorithm == "KMeans":
        return KMeans(n_clusters=n_clusters, random_state=42).fit_predict(data_matrix)
    if algorithm == "Fuzzy C-Means (FCM)":
        # cmeans is given the transposed matrix (features x samples); each
        # sample is assigned to the cluster with its highest membership.
        _, membership, *_ = fuzz.cmeans(
            data_matrix.T, n_clusters, 2, error=0.005, maxiter=1000, seed=42
        )
        return np.argmax(membership, axis=0)
    if algorithm == "BIRCH":
        return Birch(n_clusters=n_clusters).fit_predict(data_matrix)
    if algorithm == "DBSCAN":
        return DBSCAN(eps=0.5, min_samples=5).fit_predict(data_matrix)
    if algorithm == "HDBSCAN":
        return hdbscan.HDBSCAN(min_samples=5).fit_predict(data_matrix)
    if algorithm == "Mean-Shift":
        return MeanShift().fit_predict(data_matrix)
    if algorithm == "Gaussian Mixture Models (GMM)":
        model = GaussianMixture(n_components=n_clusters, random_state=42)
        model.fit(data_matrix)
        return model.predict(data_matrix)
    raise ValueError(f"Algorithm not supported: {algorithm}")


def _save_dendrogram(data_matrix, path="dendrogram.png"):
    """Save a Ward-linkage dendrogram of *data_matrix* to *path*; return *path*."""
    # 'ward' minimizes the variance of merged clusters.
    Z = linkage(data_matrix, "ward")
    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        dendrogram(Z, ax=ax)
        ax.set_title("Dendrogram for Agglomerative Clustering")
        ax.set_xlabel("Sample index")
        ax.set_ylabel("Distance")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Release the figure even on failure so repeated requests do not
        # accumulate open matplotlib figures.
        plt.close(fig)
    return path


def _save_cluster_plot(points, labels, algorithm, path="clusters_plot.png"):
    """Scatter *points* coloured by *labels*, save to *path*, and return *path*."""
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        scatter = ax.scatter(points[:, 0], points[:, 1], c=labels, cmap="viridis", s=50)
        fig.colorbar(scatter, ax=ax, label="Cluster Label")
        ax.set_title(f"Clusters Visualization ({algorithm})")
        ax.set_xlabel("PCA Component 1")
        ax.set_ylabel("PCA Component 2")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def apply_clustering(algorithm, n_clusters, dataset):
    """Cluster an uploaded CSV with *algorithm* and render a visualization.

    Args:
        algorithm: One of the names in ``_SUPPORTED_ALGORITHMS``.
        n_clusters: Requested number of clusters (used by KMeans, FCM, BIRCH
            and GMM). Coerced to ``int`` because a slider may deliver a float.
        dataset: The uploaded CSV, as a path string or an object with a
            ``.name`` attribute holding the path; ``None`` if nothing was
            uploaded.

    Returns:
        ``(message, image_path)``. ``image_path`` is ``"dendrogram.png"`` for
        AHC and ``"clusters_plot.png"`` for the other algorithms, or ``None``
        when clustering could not be performed (``message`` explains why).
    """
    if algorithm not in _SUPPORTED_ALGORITHMS:
        return "Algorithm not supported yet.", None
    if dataset is None:
        return "Please upload a CSV dataset.", None

    try:
        data_matrix = _load_numeric_matrix(dataset)
    except (OSError, ValueError) as err:
        # pandas reports empty/malformed CSVs with ValueError subclasses.
        return f"Could not read the dataset: {err}", None
    if data_matrix.shape[1] == 0 or data_matrix.shape[0] < 2:
        return "The dataset needs at least two complete rows of numeric data.", None

    n_clusters = int(n_clusters)
    try:
        if algorithm == "Agglomerative Hierarchical Clustering (AHC)":
            # AHC is visualized as a dendrogram instead of a scatter plot.
            return "Agglomerative clustering applied successfully.", _save_dendrogram(data_matrix)
        labels = _fit_labels(algorithm, n_clusters, data_matrix)
        points = _project_to_2d(data_matrix)
    except ValueError as err:
        # Estimators reject invalid input (e.g. more clusters than samples)
        # with ValueError; show it in the UI instead of a traceback.
        return f"{algorithm} failed: {err}", None
    return (
        f"{algorithm} clustering applied successfully.",
        _save_cluster_plot(points, labels, algorithm),
    )
# Gradio Interface
def main_interface():
    """Build the Gradio UI around ``apply_clustering`` and launch it.

    The inputs are passed to ``apply_clustering`` in the order
    (algorithm, n_clusters, dataset). Its ``(message, image_path)`` result
    fills the "Result" text box and the "Cluster Visualization" image.
    """
    # apply_clustering reads the upload with pandas.read_csv, so only offer
    # CSV files in the picker.
    dataset = gr.File(label="Upload Dataset (CSV format)", file_types=[".csv"])
    algorithm = gr.Dropdown(
        choices=[
            "KMeans",
            "Fuzzy C-Means (FCM)",
            "Agglomerative Hierarchical Clustering (AHC)",
            "BIRCH",
            "DBSCAN",
            "HDBSCAN",
            "Mean-Shift",
            "Gaussian Mixture Models (GMM)",
        ],
        # Pre-select a valid choice; without it, submitting immediately sends
        # no algorithm and the handler answers "Algorithm not supported yet."
        value="KMeans",
        label="Select Algorithm",
    )
    n_clusters = gr.Slider(
        minimum=2,
        maximum=10,
        value=3,
        step=1,
        label="Number of Clusters (for applicable algorithms)",
    )
    output_text = gr.Textbox(label="Result")
    output_image = gr.Image(label="Cluster Visualization")
    gr.Interface(
        fn=apply_clustering,
        inputs=[algorithm, n_clusters, dataset],
        outputs=[output_text, output_image],
    ).launch()
# Script entry point: start the web app only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main_interface()