# Hugging Face Spaces status (scraped page header): Build error
# --- Unconditional generation with a pretrained DDPM pipeline -------------
from diffusers import DDPMPipeline

# Download the CelebA-HQ 256 DDPM checkpoint and move it to the GPU
# (requires a CUDA device).
image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
image_pipe.to("cuda")

# Run the pipeline's full sampling loop and keep the generated images.
images = image_pipe().images
# --- Loading a UNet and round-tripping a randomly initialised copy ---------
from diffusers import UNet2DModel

repo_id = "google/ddpm-church-256"
model = UNet2DModel.from_pretrained(repo_id)

# Same architecture as `model` (built from its config) but with freshly
# initialised weights; save it to ./my_model and load it back.
model_random = UNet2DModel(**model.config)
model_random.save_pretrained("my_model")
model_random = UNet2DModel.from_pretrained("my_model")
import torch

# Seed the RNG so the starting noise is reproducible.
torch.manual_seed(0)

# One random sample shaped (1, in_channels, sample_size, sample_size),
# matching the input the UNet was configured for.
noisy_sample = torch.randn(
    1, model.config.in_channels, model.config.sample_size, model.config.sample_size
)

# Single forward pass at timestep 2; inference only, so no autograd graph.
with torch.no_grad():
    noisy_residual = model(sample=noisy_sample, timestep=2).sample
from diffusers import DDPMScheduler

# Load the scheduler configuration shipped with the checkpoint. Passing a
# repo id / path string to `from_config` is deprecated; `from_pretrained`
# is the supported loader for schedulers.
scheduler = DDPMScheduler.from_pretrained(repo_id)

# Round-trip the scheduler config through ./my_scheduler. The directory has
# to be written before it can be loaded.
scheduler.save_config("my_scheduler")
new_scheduler = DDPMScheduler.from_pretrained("my_scheduler")

# Take one scheduler step at timestep 2 using the residual computed above.
less_noisy_sample = scheduler.step(
    model_output=noisy_residual, timestep=2, sample=noisy_sample
).prev_sample
# Image conversion dependencies for `display_sample`.
import PIL.Image
import numpy as np
def sample_to_uint8(sample):
    """Convert a batch of samples in [-1, 1] into a uint8 NHWC NumPy array.

    Values are mapped from [-1, 1] to [0, 255]. Anything outside [-1, 1]
    (which intermediate denoising steps can produce) is clamped first;
    without the clamp, ``astype(np.uint8)`` would wrap those values around
    and show up as speckle noise.

    Args:
        sample: 4-D tensor laid out as (batch, channels, height, width) --
            the ``permute(0, 2, 3, 1)`` below assumes this order.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` shaped (batch, height, width, channels).
    """
    # detach() so .numpy() also works on tensors that track gradients.
    pixels = sample.detach().cpu().permute(0, 2, 3, 1)
    pixels = (pixels.clamp(-1.0, 1.0) + 1.0) * 127.5
    return pixels.numpy().astype(np.uint8)


def display_sample(sample, i, show=None):
    """Show the first image of ``sample`` together with a step caption.

    Args:
        sample: tensor laid out as (batch, channels, height, width) with
            values nominally in [-1, 1].
        i: step number, used only in the caption.
        show: callable that receives the caption string and then the PIL
            image. Defaults to ``display``, the IPython/Jupyter builtin,
            which is only defined inside a notebook -- pass another sink
            when running as a plain script.
    """
    if show is None:
        show = display  # noqa: F821 -- provided by IPython/Jupyter
    image_pil = PIL.Image.fromarray(sample_to_uint8(sample)[0])
    show(f"Image at step {i}")
    show(image_pil)
import tqdm

# Run the denoising loop on the GPU. The model has to live on the same
# device as the sample, otherwise the first forward pass fails with a
# device-mismatch error.
noisy_sample = noisy_sample.to("cuda")
model.to(noisy_sample.device)

sample = noisy_sample
for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
    # 1. Model output for the current sample at timestep t (inference only).
    with torch.no_grad():
        residual = model(sample, t).sample
    # 2. The scheduler turns that output into the previous, less noisy sample.
    sample = scheduler.step(residual, t, sample).prev_sample
    # 3. Show progress every 50 steps.
    if (i + 1) % 50 == 0:
        display_sample(sample, i + 1)
from diffusers import DDIMScheduler

# Load the same checkpoint config into a DDIM scheduler. Passing a repo id
# string to `from_config` is deprecated; use `from_pretrained`.
scheduler = DDIMScheduler.from_pretrained(repo_id)

# DDIMScheduler.step raises a ValueError until set_timesteps() has been
# called; this also selects a 50-step schedule instead of iterating over
# every training timestep.
num_inference_steps = 50
scheduler.set_timesteps(num_inference_steps=num_inference_steps)

import tqdm

# Keep the model on the same device as the starting noise.
model.to(noisy_sample.device)

sample = noisy_sample
for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
    # Model output at timestep t (inference only).
    with torch.no_grad():
        residual = model(sample, t).sample
    # Scheduler step to the previous, less noisy sample.
    sample = scheduler.step(residual, t, sample).prev_sample
    # Show progress every 10 steps.
    if (i + 1) % 10 == 0:
        display_sample(sample, i + 1)