"""Gradio demo: density estimation with a two-component Gaussian mixture model (GMM).

Adapted from the scikit-learn example:
https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_pdf.html
"""

import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from PIL import Image  # used to load the rendered plot back as a PIL image
from sklearn import mixture

import gradio as gr

def generate_gaussian_mixture(n_samples):
    # Fix the seed so the generated data (and therefore the plot) is reproducible.
    np.random.seed(0)

    # First component: a spherical Gaussian centered at (20, 20).
    shifted_gaussian = np.random.randn(n_samples, 2) + np.array([20, 20])

    # Second component: a zero-centered Gaussian stretched by a linear transformation.
    C = np.array([[0.0, -0.7], [3.5, 0.7]])
    stretched_gaussian = np.dot(np.random.randn(n_samples, 2), C)

    # Concatenate the two components into a single training set.
    X_train = np.vstack([shifted_gaussian, stretched_gaussian])

    # Fit a two-component Gaussian mixture model with full covariance matrices.
    clf = mixture.GaussianMixture(n_components=2, covariance_type="full")
    clf.fit(X_train)

    # Evaluate the negative log-likelihood of the fitted model on a regular grid.
    x = np.linspace(-20.0, 30.0)
    y = np.linspace(-20.0, 40.0)
    X, Y = np.meshgrid(x, y)
    XX = np.array([X.ravel(), Y.ravel()]).T
    Z = -clf.score_samples(XX)
    Z = Z.reshape(X.shape)

    # Draw log-spaced contours of the predicted density and overlay the training data.
    fig, ax = plt.subplots()
    CS = ax.contour(
        X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0), levels=np.logspace(0, 3, 10)
    )
    fig.colorbar(CS, shrink=0.8, extend="both")
    ax.scatter(X_train[:, 0], X_train[:, 1], 0.8)

    ax.set_title("Negative log-likelihood predicted by a GMM")
    ax.axis("tight")

    # Save the figure to a temporary file and return its path.
    temp_dir = tempfile.mkdtemp()
    temp_file_path = os.path.join(temp_dir, "gmm_plot.png")
    fig.savefig(temp_file_path)
    plt.close(fig)

    return temp_file_path


def plot_to_image(file_path):
    # Load the saved figure as a PIL image (matching the "pil" output component),
    # then remove the temporary file.
    with Image.open(file_path) as img:
        image = img.copy()
    os.remove(file_path)
    return image
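
# Hypothetical quick check (not part of the original demo): chaining the two
# helpers renders the plot for 300 samples and loads it back as a PIL image.
#
#     img = plot_to_image(generate_gaussian_mixture(300))
#     img.show()
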
def gmm_demo(n_samples):
    # Gradio passes the slider value as a number; cast it to int before sampling,
    # then return the rendered plot as a PIL image for the Image output.
    return plot_to_image(generate_gaussian_mixture(int(n_samples)))


inputs = gr.Slider(100, 1000, step=100, value=300, label="Number of Samples")
outputs = gr.Image(type="pil", label="GMM Plot")

title = "Density Estimation for a Gaussian Mixture"
description = "In this example, you can visualize the density estimation of a mixture of two Gaussians using a Gaussian Mixture Model (GMM). The data used to fit the model is generated from two Gaussians with distinct centers and covariance matrices. By adjusting the number of samples, you can observe how the GMM captures the underlying distribution and produces a contour plot of the estimated density. This interactive application lets you explore the behavior of the GMM and gain insight into modeling complex data distributions with mixture models. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_pdf.html"

gr.Interface(gmm_demo, inputs, outputs, title=title, description=description, live=True).launch()
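
# Optional: launch() also accepts share=True for a temporary public link and
# server_port=<int> to pin the local port, if the defaults need to change.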