1aurent's picture
Update app.py
24c35d3 verified
raw
history blame contribute delete
No virus
1.29 kB
import torch
import spaces
import PIL.Image
import gradio as gr
import gradio.components as grc
import numpy as np
from pipeline import DDIMPipelineCustom
# Choose the torch device first (GPU when available), then load the custom
# conditional DDIM pipeline from the "1aurent/ddpm-mnist-conditional"
# checkpoint and move it onto that device in a single chained expression.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
pipeline = DDIMPipelineCustom.from_pretrained(
    "1aurent/ddpm-mnist-conditional"
).to(device=device)
@spaces.GPU
def predict(steps, seed, value, guidance):
    """Generate one conditional image and stream it to the Gradio output.

    Args:
        steps: Number of inference steps passed to the pipeline
            ("Inference Steps" slider, 1-100).
        seed: Seed for ``torch.manual_seed`` so a given seed reproduces the
            same image ("Seed" slider).
        value: Conditioning value wrapped in a 1-element tensor on ``device``
            ("Value" slider, 0-9; presumably the digit to draw — confirm
            against the pipeline).
        guidance: Guidance factor forwarded unchanged to the pipeline.

    Yields:
        ``images[0]`` of the pipeline output, exactly once. The UI declares
        ``type="pil"``; the pipeline's actual return type is defined
        elsewhere — TODO confirm.
    """
    # Slider values may arrive as floats; the step count must be an integer.
    num_steps = int(steps)
    generator = torch.manual_seed(int(seed))

    # Run the pipeline exactly once. Calling it repeatedly would keep
    # advancing the generator, so the shown image would depend on how many
    # runs happened rather than only on `seed`. A 1-step request would also
    # produce no image at all.
    result = pipeline(
        generator=generator,
        condition=torch.tensor([value], device=device),
        guidance=guidance,
        num_inference_steps=num_steps,
    )

    # Stay a generator (single yield) so the Gradio interface is unchanged.
    yield result.images[0]
# Four sliders supply predict's arguments in order: steps, seed, value,
# guidance. The generated image is declared at 28x28 and enlarged to 256px
# on the page through the CSS rule targeting its element id.
demo_inputs = [
    grc.Slider(1, 100, label='Inference Steps', value=20, step=1),
    grc.Slider(0, 2147483647, label='Seed', value=69420, step=1),
    grc.Slider(0, 9, label='Value', value=5, step=1),
    grc.Slider(-2.5, 2.5, label='Guidance Factor', value=1),
]
demo_output = gr.Image(height=28, width=28, type="pil", elem_id="output_image")

demo = gr.Interface(
    fn=predict,
    inputs=demo_inputs,
    outputs=demo_output,
    css="#output_image{width: 256px !important; height: 256px !important;}",
    title="Conditional MNIST",
    description="A DDIM scheduler and UNet model trained on the MNIST dataset for conditional image generation.",
)
# Enable the request queue (needed for generator/streaming functions), then serve.
demo.queue().launch()