# Author: caliex — last commit "Update app.py" (0c29b07)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from sklearn import mixture
import gradio as gr
import tempfile
import os
def generate_gaussian_mixture(n_samples):
    """Fit a two-component GMM to synthetic data and render its density as a PNG.

    ``n_samples`` points are drawn around (20, 20) and another ``n_samples``
    from a zero-centred Gaussian transformed by a fixed linear map. A
    full-covariance ``GaussianMixture`` with two components is fitted to the
    union, and the negative log-likelihood it predicts is drawn as
    log-spaced contours on top of a scatter of the training points.

    Args:
        n_samples: Number of points drawn for *each* of the two components.
            Coerced with ``int()`` because UI sliders may deliver floats,
            which ``np.random.randn`` rejects.

    Returns:
        Path to a newly created ``.png`` file. The caller owns the file and
        is responsible for deleting it (``plot_to_image`` does so).
    """
    n_samples = int(n_samples)

    # Seed the global NumPy RNG so the sample draws below are reproducible.
    # No ``random_state`` is passed to GaussianMixture, so (per scikit-learn's
    # random_state=None convention) its initialisation also draws from this
    # global RNG.
    np.random.seed(0)

    # Spherical component centred on (20, 20).
    shifted_gaussian = np.random.randn(n_samples, 2) + np.array([20, 20])
    # Zero-centred component, stretched/sheared by the linear map C.
    C = np.array([[0.0, -0.7], [3.5, 0.7]])
    stretched_gaussian = np.dot(np.random.randn(n_samples, 2), C)
    X_train = np.vstack([shifted_gaussian, stretched_gaussian])

    clf = mixture.GaussianMixture(n_components=2, covariance_type="full")
    clf.fit(X_train)

    # Evaluate the model's negative log-likelihood on a grid (np.linspace
    # defaults to 50 points per axis) for the contour plot.
    x = np.linspace(-20.0, 30.0)
    y = np.linspace(-20.0, 40.0)
    X, Y = np.meshgrid(x, y)
    XX = np.array([X.ravel(), Y.ravel()]).T
    Z = -clf.score_samples(XX).reshape(X.shape)

    # Build the figure with the object-oriented API instead of pyplot: pyplot
    # keeps global figure state, which Matplotlib advises against in a web
    # server where handlers may run concurrently (live=True re-runs this on
    # every slider change). A bare Figure needs no plt.close().
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    CS = ax.contour(
        X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0), levels=np.logspace(0, 3, 10)
    )
    fig.colorbar(CS, shrink=0.8, extend="both")
    ax.scatter(X_train[:, 0], X_train[:, 1], 0.8)
    ax.set_title("Negative log-likelihood predicted by a GMM")
    ax.axis("tight")

    # mkstemp creates one uniquely named file directly in the temp directory,
    # so deleting that file is complete cleanup (no per-call directory is left
    # behind). Close the OS-level handle first so savefig can reopen the path
    # by name, which Windows requires.
    fd, temp_file_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    fig.savefig(temp_file_path)
    return temp_file_path
def plot_to_image(file_path):
    """Read a file's raw bytes, then delete the file.

    Args:
        file_path: Path of the file to consume, e.g. the PNG path returned by
            ``generate_gaussian_mixture``.

    Returns:
        The file contents as ``bytes``.

    Raises:
        OSError: If the file cannot be opened or read (e.g.
            ``FileNotFoundError``). The file is still removed when it exists,
            so a failed read does not leak it.
    """
    import contextlib

    try:
        with open(file_path, "rb") as f:
            return f.read()
    finally:
        # Delete in ``finally`` so the file is cleaned up even when reading
        # fails. A file that is already gone is ignored, so this cleanup never
        # masks the original exception raised by ``open``.
        with contextlib.suppress(FileNotFoundError):
            os.remove(file_path)
def _render_gmm_plot(n_samples):
    """Render the GMM plot and return it as an image array, leaving no temp file.

    ``generate_gaussian_mixture`` writes a PNG to disk; ``plot_to_image`` reads
    those bytes back and deletes the file, and the PNG is decoded in memory so
    Gradio receives an array rather than a path to a file that would otherwise
    accumulate on disk (one per slider change with ``live=True``).
    """
    import io

    png_bytes = plot_to_image(generate_gaussian_mixture(n_samples))
    return plt.imread(io.BytesIO(png_bytes), format="png")


# gr.Slider / gr.Image replace the deprecated gr.inputs / gr.outputs wrappers
# (removed in Gradio 4) with the same settings.
inputs = gr.Slider(
    minimum=100, maximum=1000, step=100, value=300, label="Number of Samples"
)
outputs = gr.Image(type="pil", label="GMM Plot")
title = "Density Estimation for a Gaussian mixture"
description = "In this example, you can visualize the density estimation of a mixture of two Gaussians using a Gaussian Mixture Model (GMM). The data used for the model is generated from two Gaussians with distinct centers and covariance matrices. By adjusting the number of samples, you can observe how the GMM captures the underlying distribution and generates a contour plot representing the estimated density. This interactive application allows you to explore the behavior of the GMM and gain insights into the modeling of complex data distributions using mixture models. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_pdf.html"

# gr.Interface exposes no callable ``postprocess`` hook, so the temp-file
# cleanup is performed inside the wrapped function (_render_gmm_plot) instead.
demo = gr.Interface(
    _render_gmm_plot,
    inputs,
    outputs,
    title=title,
    description=description,
    live=True,
)

if __name__ == "__main__":
    demo.launch()