from pathlib import Path | |
from diffusers import DDPMPipeline, UNet2DModel, DDPMScheduler | |
from diffusers.utils import make_image_grid | |
import torch | |
# Hugging Face Hub repo providing both the fine-tuned UNet weights and the
# DDPM scheduler config used to sample from it.
MODEL_ID = 'jmemon/ddpm-paintings-128-finetuned-celebahq'
BATCH_SIZE = 4
NUM_INFERENCE_STEPS = 50
SEED = 0


def pick_device() -> str:
    """Return the torch device name to run on: 'mps', else 'cuda', else 'cpu'.

    Apple-silicon Metal is preferred, but the script falls back so it also
    runs on CUDA or CPU-only machines instead of failing in ``.to('mps')``.
    """
    if torch.backends.mps.is_available():
        return 'mps'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'


def grid_shape(n: int) -> "tuple[int, int]":
    """Return ``(rows, cols)`` with ``rows * cols == n``, as close to square as possible.

    ``make_image_grid`` requires the image count to equal ``rows * cols``
    exactly, so ``rows`` is the largest divisor of ``n`` not exceeding
    ``sqrt(n)`` (e.g. 4 -> (2, 2), 6 -> (2, 3), 5 -> (1, 5)).

    Raises:
        ValueError: if ``n`` is less than 1.
    """
    import math

    if n < 1:
        raise ValueError(f'need at least one image to build a grid, got {n}')
    rows = math.isqrt(n)
    while n % rows:
        rows -= 1
    return rows, n // rows


def main() -> None:
    """Sample ``BATCH_SIZE`` images from the fine-tuned DDPM and save them as
    a single grid image, ``grid.png``, next to this script."""
    device = pick_device()

    unet = UNet2DModel.from_pretrained(MODEL_ID)
    scheduler = DDPMScheduler.from_pretrained(MODEL_ID)
    pipeline = DDPMPipeline(unet=unet, scheduler=scheduler).to(device)
    # Trades some speed for lower peak attention memory.
    # NOTE(review): may be a no-op if UNet2DModel exposes no attention
    # slicing hook; verify against the installed diffusers version.
    pipeline.enable_attention_slicing()

    # A dedicated, seeded CPU generator gives reproducible noise on every
    # device without reseeding torch's global RNG (as torch.manual_seed does).
    generator = torch.Generator().manual_seed(SEED)
    images = pipeline(
        batch_size=BATCH_SIZE,
        generator=generator,
        num_inference_steps=NUM_INFERENCE_STEPS,
    ).images

    # Derive the layout from the actual image count so changing BATCH_SIZE
    # cannot violate make_image_grid's rows * cols == len(images) check.
    rows, cols = grid_shape(len(images))
    grid = make_image_grid(images, rows, cols)
    grid.save(Path(__file__).parent / 'grid.png')


if __name__ == '__main__':
    main()