Spaces:
Runtime error
Runtime error
import inspect | |
from tkinter import Image | |
from typing import List, Optional, Union | |
import numpy as np | |
import torch | |
import PIL | |
from PIL import Image | |
from tqdm.auto import tqdm | |
from diffusion_arch import DensePosteriorConditionalUNet | |
from guided_diffusion.script_util import create_gaussian_diffusion | |
import torch.nn.functional as F | |
import torchvision.transforms.functional as TF | |
from einops import rearrange | |
from kornia.morphology import dilation | |
from tqdm import tqdm | |
def preprocess_image(image):
    """Convert a PIL image into a (1, 3, H, W) float32 tensor scaled to [-1, 1].

    The image is first converted to RGB so that RGBA, greyscale and palette
    inputs produce exactly three channels. Without this, RGBA yields four
    channels and single-band modes make ``transpose(2, 0, 1)`` fail. Width and
    height are then rounded down to multiples of 32 with Lanczos resampling.

    Args:
        image: PIL image.

    Returns:
        torch.Tensor of shape (1, 3, H', W') with H', W' multiples of 32.

    Raises:
        ValueError: if either side is smaller than 32 pixels, which would
            otherwise round down to an empty image.
    """
    image = image.convert("RGB")
    w, h = image.size
    if w < 32 or h < 32:
        raise ValueError(f"image must be at least 32x32 pixels, got {w}x{h}")
    # Resize to an integer multiple of 32 (rounding down).
    w, h = w - w % 32, h - h % 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    arr = np.array(image).astype(np.float32) / 255.0
    tensor = torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0)
    # Map [0, 1] -> [-1, 1].
    return 2.0 * tensor - 1.0
def preprocess_mask(mask):
    """Turn a PIL mask into a binary (1, 3, H, W) float32 tensor.

    The mask is read as single-channel luminance ("L"). Its size is rounded
    down to multiples of 32 using nearest-neighbour resampling, so no new grey
    levels are introduced. Every non-zero pixel becomes 1.0 and zero pixels
    stay 0.0. The single channel is repeated three times.
    """
    grey = mask.convert("L")
    width, height = grey.size
    target_size = (width - width % 32, height - height % 32)
    grey = grey.resize(target_size, resample=PIL.Image.NEAREST)
    values = np.array(grey).astype(np.float32) / 255.0
    three_channel = np.repeat(values[None, ...], 3, axis=0)
    binary = torch.from_numpy(three_channel).unsqueeze(0)
    binary[binary > 0] = 1
    return binary
class DiffusionPipeline:
    """Mask-conditioned diffusion inpainting around ``DensePosteriorConditionalUNet``.

    Pixels where the (dilated) mask is non-zero are generated by reverse
    diffusion. Before every model call, the remaining pixels are replaced by a
    noised copy of the input image (``q_sample``), so the known content guides
    sampling.
    """

    # Sliding-window tiling used by __call__: window side and step between windows.
    patch_size = 256
    patch_stride = 64
    # quick_solve resizes image and mask to this square size before sampling.
    quick_resolution = 512

    def __init__(self, device, checkpoint_path='net_g_400000.pth'):
        """Build the UNet, move it to ``device`` and load its EMA weights.

        Args:
            device: torch device (or device string) used for all inference.
            checkpoint_path: checkpoint file whose ``"params_ema"`` entry holds
                the model state dict.
        """
        super().__init__()
        self.device = device
        # NOTE(review): in_channels=9 is presumably 3 noisy-image channels plus
        # the 6-channel 'latent' (lq + mask) passed via model_kwargs, and
        # out_channels=6 presumably mean + variance because learn_sigma=True.
        # Confirm against diffusion_arch.
        self.model = DensePosteriorConditionalUNet(
            in_channels=9,
            model_channels=256,
            out_channels=6,
            num_res_blocks=2,
            attention_resolutions=[8, 16, 32],
            dropout=0.0,
            channel_mult=(1, 1, 2, 2, 4, 4),
            num_classes=None,
            use_checkpoint=False,
            use_fp16=False,
            num_heads=4,
            num_head_channels=64,
            num_heads_upsample=-1,
            use_scale_shift_norm=True,
            resblock_updown=True,
            use_new_attention_order=True,
        )
        self.model.eval()
        self.model.to(self.device)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.model.load_state_dict(checkpoint["params_ema"], strict=True)

    def _build_diffusion(self, diffusion_step):
        """Create the sampler respaced to ``diffusion_step`` steps ("ddim<N>")
        out of a 1000-step linear schedule.

        The sampler is also stored on ``self.eval_gaussian_diffusion``.
        """
        self.eval_gaussian_diffusion = create_gaussian_diffusion(
            steps=1000,
            learn_sigma=True,
            noise_schedule='linear',
            use_kl=False,
            timestep_respacing="ddim" + str(diffusion_step),
            predict_xstart=False,
            rescale_timesteps=False,
            rescale_learned_sigmas=False,
            p2_gamma=1,
            p2_k=1,
        )
        return self.eval_gaussian_diffusion

    def _prepare_inputs(self, lq, mask, dkernel):
        """Turn PIL image and mask into device tensors and dilate the mask.

        Uses a ``dkernel`` x ``dkernel`` square kernel for the dilation.
        ``lq`` is converted to RGB first so non-RGB inputs give 3 channels.
        """
        lq = preprocess_image(lq.convert("RGB")).to(self.device)
        mask = preprocess_mask(mask).to(self.device)
        mask = dilation(mask, torch.ones(dkernel, dkernel, device=self.device))
        return lq, mask

    @staticmethod
    def _grid_padding(size, kernel_size, stride):
        """Return how many pixels to append so sliding windows tile ``size`` fully.

        After padding, the length is at least ``kernel_size`` and
        ``length - kernel_size`` is a multiple of ``stride``. Every pixel is
        therefore covered by at least one window.
        """
        if size <= kernel_size:
            return kernel_size - size
        return (-(size - kernel_size)) % stride

    @torch.no_grad()
    def _denoise_step(self, diffusion, img, lq, mask, i):
        """Run one reverse-diffusion step at respaced timestep ``i``.

        Returns ``(next_img, out)``, where ``out`` is the ``p_mean_variance``
        dict. Runs under ``no_grad``: without it, the autograd graph is chained
        through every step and ``.numpy()`` on the result raises.
        """
        t = torch.tensor([i] * img.size(0), device=self.device)
        # Keep generated content inside the mask; re-impose the noised known
        # image outside it.
        img = img * mask + diffusion.q_sample(lq, t) * (1 - mask)
        out = diffusion.p_mean_variance(
            self.model,
            img.contiguous(),
            t,
            model_kwargs={'latent': torch.cat((lq, mask), dim=1)},
        )
        # No noise is added at the final step (t == 0).
        nonzero_mask = (t != 0).float().view(-1, *([1] * (img.dim() - 1)))
        noise = torch.randn_like(img)
        img = out["mean"] + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise
        return img, out

    def _sample(self, diffusion, lq, mask):
        """Run the full reverse loop from Gaussian noise and return the final sample."""
        img = torch.randn_like(lq)
        for i in reversed(range(diffusion.num_timesteps)):
            img, _ = self._denoise_step(diffusion, img, lq, mask, i)
        return img

    @staticmethod
    def _to_pil(x, size):
        """Convert the first item of an (N, 3, H, W) batch to a PIL image resized to ``size``.

        Values are mapped from [-1, 1] to [0, 1] and clamped before the uint8
        conversion.
        """
        arr = (x / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()[0]
        return PIL.Image.fromarray(np.uint8(arr * 255.)).resize(
            size, resample=PIL.Image.LANCZOS
        )

    def __call__(self, lq, mask, dkernel, diffusion_step):
        """Inpaint ``lq`` at full resolution using overlapping windows.

        Args:
            lq: PIL image to inpaint.
            mask: PIL image. Non-zero pixels (after dilation) mark the region
                to regenerate.
            dkernel: side length of the square dilation kernel applied to the mask.
            diffusion_step: number of respaced sampling steps.

        Returns:
            PIL image with the same size as ``lq``.
        """
        diffusion = self._build_diffusion(diffusion_step)
        ow, oh = lq.size
        lq, mask = self._prepare_inputs(lq, mask, dkernel)

        # ======== PADDING ========
        # Pad right/bottom so the window grid covers every pixel. Otherwise a
        # small image yields no windows, and an uncovered border gets a zero
        # normalisation weight (0/0 -> NaN). The image is edge-replicated; the
        # mask is zero-padded so the padding itself is never inpainted.
        kernel_size, stride = self.patch_size, self.patch_stride
        _, _, h, w = mask.shape
        pad_h = self._grid_padding(h, kernel_size, stride)
        pad_w = self._grid_padding(w, kernel_size, stride)
        lq = F.pad(lq, (0, pad_w, 0, pad_h), mode='replicate')
        mask = F.pad(mask, (0, pad_w, 0, pad_h))
        padded_size = (h + pad_h, w + pad_w)

        # ======== SPLIT INTO WINDOWS ========
        mask = F.unfold(mask, kernel_size=kernel_size, stride=stride)
        lq = F.unfold(lq, kernel_size=kernel_size, stride=stride)
        num_windows = mask.shape[-1]
        mask = rearrange(mask, 'n (c3 h w) l -> (n l) c3 h w', h=kernel_size, w=kernel_size)
        lq = rearrange(lq, 'n (c3 h w) l -> (n l) c3 h w', h=kernel_size, w=kernel_size)

        # ======== SAMPLING ========
        sub_imgs = []
        for sub_lq, sub_mask in zip(lq.unsqueeze(1), mask.unsqueeze(1)):
            if torch.sum(sub_mask) > 1:
                sub_imgs.append(self._sample(diffusion, sub_lq, sub_mask))
            else:
                # Windows without masked pixels are passed through unchanged.
                sub_imgs.append(sub_lq)
        img = torch.cat(sub_imgs, dim=0)

        # ======== MERGE WINDOWS ========
        img = rearrange(img, '(n l) c3 h w -> n (c3 h w) l', h=kernel_size, w=kernel_size, l=num_windows)
        img = F.fold(img, padded_size, kernel_size=kernel_size, stride=stride)
        # Average overlapping windows by dividing by each pixel's coverage count.
        norm_map = F.fold(
            F.unfold(torch.ones_like(img), kernel_size, stride=stride),
            padded_size,
            kernel_size,
            stride=stride,
        )
        img = img / norm_map
        img = img[:, :, :h, :w]
        return self._to_pil(img, (ow, oh))

    def quick_solve(self, lq, mask, dkernel, diffusion_step):
        """Inpaint at a fixed square resolution, streaming intermediate results.

        Image and mask are resized to ``quick_resolution`` squared. This is a
        generator:
        - after each reverse step it yields the model's ``pred_xstart``;
        - after the last step it yields the final sample.
        Every yielded image is resized back to the input size.

        Args are as for ``__call__``.
        """
        diffusion = self._build_diffusion(diffusion_step)
        ow, oh = lq.size
        side = self.quick_resolution
        lq = lq.resize((side, side), resample=PIL.Image.LANCZOS)
        mask = mask.resize((side, side), resample=PIL.Image.NEAREST)
        lq, mask = self._prepare_inputs(lq, mask, dkernel)
        img = torch.randn_like(lq)
        for i in reversed(range(diffusion.num_timesteps)):
            # no_grad is applied per step (inside _denoise_step), not across the
            # yield, so grad mode is never left disabled while the caller runs.
            img, out = self._denoise_step(diffusion, img, lq, mask, i)
            yield self._to_pil(out["pred_xstart"], (ow, oh))
        yield self._to_pil(img, (ow, oh))