# ddpm-mnist / app.py
# 1aurent's picture
# Update app.py
# 55ac5d5 verified
# raw
# history blame
# 1.01 kB
from diffusers import DiffusionPipeline
import spaces
import torch
import PIL.Image
import gradio as gr
import gradio.components as grc
import numpy as np
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the pretrained "1aurent/ddpm-mnist" pipeline and place it on `device`.
pipeline = DiffusionPipeline.from_pretrained("1aurent/ddpm-mnist").to(device=device)
@spaces.GPU
def predict(steps, seed):
    """Yield one sample per step count 1, 2, ..., ``steps``, all from the same seed.

    Each yielded image comes from a separate full ``pipeline`` call, so the UI
    shows a preview that refines as the number of inference steps grows. The
    last image is exactly what a single run with (``seed``, ``steps``) produces.

    Args:
        steps: Number of inference steps for the final image. Slider values may
            arrive as floats, so the value is converted with ``int``.
        seed: Random seed. Every pipeline call gets a fresh generator seeded with it.

    Yields:
        The first entry of ``pipeline(...).images`` for each step count.
    """
    steps = int(steps)
    seed = int(seed)
    # Include `steps` itself; the old range(1, steps) stopped one short and
    # yielded nothing at all when steps == 1.
    for num_steps in range(1, steps + 1):
        # Re-seed a private CPU generator for every call. Reusing one generator
        # would let its state advance between calls, so each image would start
        # from different noise. It also avoids reseeding torch's global RNG.
        generator = torch.Generator().manual_seed(seed)
        yield pipeline(generator=generator, num_inference_steps=num_steps).images[0]
# Web UI: two sliders feed predict(steps, seed). predict is a generator, so
# the output image updates as new samples are yielded. The 28x28 output is
# enlarged on screen via the CSS rule targeting its element id.
demo = gr.Interface(
    fn=predict,
    inputs=[
        grc.Slider(minimum=1, maximum=100, value=12, step=1, label='Inference Steps'),
        grc.Slider(minimum=0, maximum=2147483647, value=69420, step=1, label='Seed'),
    ],
    outputs=gr.Image(type="pil", height=28, width=28, elem_id="output_image"),
    title="Unconditional MNIST",
    description="A DDIM scheduler and UNet model trained on the MNIST dataset for unconditional image generation.",
    css="#output_image{width: 256px !important; height: 256px !important;}",
)
demo.queue().launch()