File size: 686 Bytes
7c7dfcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from pathlib import Path
from diffusers import DDPMPipeline, UNet2DModel, DDPMScheduler
from diffusers.utils import make_image_grid
import torch
# Hub repository holding both the fine-tuned UNet weights and the scheduler config.
MODEL_ID = 'jmemon/ddpm-paintings-128-finetuned-celebahq'


def pick_device():
    """Return the preferred available torch device name.

    Prefers Apple-silicon 'mps' (the device this script originally hard-coded),
    then 'cuda', and falls back to 'cpu' so the script runs on any machine.
    """
    if torch.backends.mps.is_available():
        return 'mps'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'


def grid_shape(n):
    """Return ``(rows, cols)`` for the most square grid holding exactly ``n`` images.

    ``rows`` is the largest divisor of ``n`` not exceeding sqrt(n), so
    ``rows * cols == n`` always holds (e.g. 4 -> (2, 2), 6 -> (2, 3), 5 -> (1, 5)).

    Raises:
        ValueError: if ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f'need at least one image to build a grid, got {n}')
    rows = max(r for r in range(1, n + 1) if r * r <= n and n % r == 0)
    return rows, n // rows


def main(batch_size=4, seed=0, num_inference_steps=50):
    """Sample ``batch_size`` images from the fine-tuned DDPM and save them as a grid.

    The grid is written to ``grid.png`` next to this script file.

    Args:
        batch_size: number of images to sample.
        seed: seed for the sampling noise, so runs are reproducible.
        num_inference_steps: number of DDPM denoising steps.
    """
    unet = UNet2DModel.from_pretrained(MODEL_ID)
    scheduler = DDPMScheduler.from_pretrained(MODEL_ID)
    pipeline = DDPMPipeline(unet=unet, scheduler=scheduler).to(pick_device())
    # Trades some speed for lower peak memory during attention.
    pipeline.enable_attention_slicing()
    # A dedicated CPU generator seeds only this run. torch.manual_seed() would
    # reseed the global RNG, although it produces the same noise stream.
    generator = torch.Generator().manual_seed(seed)
    images = pipeline(
        batch_size=batch_size,
        generator=generator,
        num_inference_steps=num_inference_steps,
    ).images
    # Derive the layout from the actual image count; make_image_grid needs
    # rows * cols to match len(images).
    rows, cols = grid_shape(len(images))
    grid = make_image_grid(images, rows=rows, cols=cols)
    grid.save(Path(__file__).parent / 'grid.png')


if __name__ == '__main__':
    main()
|