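"""Gradio demo: Gaussian Mixture Model covariance types on the Iris dataset.

Adapted from the scikit-learn example:
https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_covariances.html
"""
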
import gradio as gr
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn import datasets
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import StratifiedKFold

# One plotting color per Iris class (and per GMM component).
colors = ["navy", "turquoise", "darkorange"]


def make_ellipses(gmm, ax):
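    """Draw an ellipse for each GMM component's covariance on the given axes.

    Only the first two feature dimensions are plotted.
    """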
    for n, color in enumerate(colors):
        if gmm.covariance_type == "full":
            covariances = gmm.covariances_[n][:2, :2]
        elif gmm.covariance_type == "tied":
            covariances = gmm.covariances_[:2, :2]
        elif gmm.covariance_type == "diag":
            covariances = np.diag(gmm.covariances_[n][:2])
        elif gmm.covariance_type == "spherical":
            covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]
        v, w = np.linalg.eigh(covariances)
        u = w[0] / np.linalg.norm(w[0])
        angle = np.arctan2(u[1], u[0])
        angle = 180 * angle / np.pi  # convert to degrees
        v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
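        # v now holds full axis lengths (eigenvalues converted from variances
        # to standard deviations, then scaled); matplotlib's Ellipse expects
        # full widths/heights, not radii, and an angle in degrees.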
        ell = mpl.patches.Ellipse(
            gmm.means_[n, :2], v[0], v[1], angle=180 + angle, color=color
        )
        ell.set_clip_box(ax.bbox)
        ell.set_alpha(0.5)
        ax.add_artist(ell)
        ax.set_aspect("equal", "datalim")


def classify_iris(cov_type):
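    """Fit a GaussianMixture with the given covariance type on Iris.

    Plots the fitted components together with the training and test points,
    then returns the path of the saved figure.
    """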
    iris = datasets.load_iris()

    # Break up the dataset into non-overlapping training (75%) and testing
    # (25%) sets.
    skf = StratifiedKFold(n_splits=4)
    # Only take the first fold.
    train_index, test_index = next(iter(skf.split(iris.data, iris.target)))

    X_train = iris.data[train_index]
    y_train = iris.target[train_index]
    X_test = iris.data[test_index]
    y_test = iris.target[test_index]

    n_classes = len(np.unique(y_train))

    # Try GMMs using different types of covariances.
    estimator = GaussianMixture(
        n_components=n_classes, covariance_type=cov_type, max_iter=20, random_state=0
    )

    # Since we have class labels for the training data, we can
    # initialize the GMM parameters in a supervised manner.
    estimator.means_init = np.array(
        [X_train[y_train == i].mean(axis=0) for i in range(n_classes)]
    )

    # Train the other parameters using the EM algorithm.
    estimator.fit(X_train)
    fig, ax = plt.subplots(figsize=(8, 6))
    make_ellipses(estimator, ax)

    # Plot the full dataset with dots, colored by true class.
    for n, color in enumerate(colors):
        data = iris.data[iris.target == n]
        ax.scatter(
            data[:, 0], data[:, 1], s=0.8, color=color, label=iris.target_names[n]
        )
    # Plot the test data with crosses.
    for n, color in enumerate(colors):
        data = X_test[y_test == n]
        ax.scatter(data[:, 0], data[:, 1], marker="x", color=color)

    y_train_pred = estimator.predict(X_train)
    train_accuracy = np.mean(y_train_pred.ravel() == y_train.ravel()) * 100
    ax.text(0.05, 0.9, "Train accuracy: %.1f" % train_accuracy, transform=ax.transAxes)

    y_test_pred = estimator.predict(X_test)
    test_accuracy = np.mean(y_test_pred.ravel() == y_test.ravel()) * 100
    ax.text(0.05, 0.8, "Test accuracy: %.1f" % test_accuracy, transform=ax.transAxes)

    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_title(cov_type.capitalize())
    ax.legend(scatterpoints=1, loc="lower right", prop=dict(size=12))

    # Save the plot to a file and return its path.
    output_path = "classification_plot.png"
    fig.savefig(output_path)
    plt.close(fig)
    return output_path


def update_plot(cov_type):
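    """Gradio callback: regenerate the figure and return it as a NumPy array."""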
    image_path = classify_iris(cov_type)
    image = Image.open(image_path)
    image_array = np.array(image)
    return image_array


iface = gr.Interface(
    fn=update_plot,
    # gr.inputs / gr.outputs were removed in Gradio 3.x; use the top-level
    # components instead, with `value` in place of `default`.
    inputs=gr.Radio(
        ["spherical", "diag", "tied", "full"],
        label="Covariance Type",
        value="spherical",
    ),
    outputs=gr.Image(type="numpy"),
    live=True,
    title="Gaussian Mixture Model Covariance",
    description="Explore different covariance types for Gaussian mixture models (GMMs) in this demonstration. GMMs are commonly used for clustering, but in this example we compare the obtained clusters with the actual classes from the dataset. By initializing the means of the Gaussians with the means of the classes in the training set, we ensure a valid comparison. The plot shows the predicted labels on both training and test data for a GMM with the selected covariance type (spherical, diagonal, tied, or full). Interestingly, while full covariance is expected to perform best, it may overfit small datasets and struggle to generalize to held-out test data. See the original scikit-learn example for more information: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_covariances.html",
    examples=[
        ["spherical"],
        ["diag"],
        ["tied"],
        ["full"],
    ],
)
iface.launch()
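
# Optional: `iface.launch(share=True)` additionally serves the demo through a
# temporary public link (assumes outbound network access is available).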