jmemon's picture
Upload folder using huggingface_hub
7c7dfcd
raw
history blame
686 Bytes
from pathlib import Path
from diffusers import DDPMPipeline, UNet2DModel, DDPMScheduler
from diffusers.utils import make_image_grid
import torch
def select_device() -> str:
    """Return the name of the torch device to run inference on.

    Prefers Apple's Metal backend ('mps'), which is what this script
    originally hard-coded, then CUDA, and finally falls back to 'cpu' so the
    script still runs on machines that have neither accelerator.
    """
    # Older torch builds have no `torch.backends.mps` attribute at all, so
    # look it up defensively instead of assuming it exists.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'


if __name__ == '__main__':
    # Both the UNet weights and the scheduler config are loaded from the same
    # Hub repository, so name it once.
    model_id = 'jmemon/ddpm-paintings-128-finetuned-celebahq'

    unet = UNet2DModel.from_pretrained(model_id)
    scheduler = DDPMScheduler.from_pretrained(model_id)

    # Assemble the pipeline from its parts and move it to whichever device is
    # actually available, rather than unconditionally requesting 'mps'.
    pipeline = DDPMPipeline(unet=unet, scheduler=scheduler).to(select_device())
    pipeline.enable_attention_slicing()

    # Sample a batch of 4 images with a fixed seed and 50 inference steps.
    images = pipeline(
        batch_size=4,
        generator=torch.manual_seed(0),
        num_inference_steps=50,
    ).images

    # Tile the 4 samples into a 2x2 grid and write it next to this script.
    grid = make_image_grid(images, 2, 2)
    grid.save(Path(__file__).parent / 'grid.png')