Spaces:
Sleeping
Sleeping
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn as nn | |
def plot_images(figure, imgs):
    """Show ``imgs`` on a grid of subplots with the axes hidden.

    Args:
        figure: ``(h, w)`` grid size. ``h`` is the number of grid columns and
            ``w`` the number of grid rows (the grid is created with
            ``plt.subplots(w, h)``).
        imgs: indexable batch of images whose first dimension has length
            ``h * w``; each item is passed unchanged to ``Axes.imshow``.

    Images fill the grid column by column: top-to-bottom within a column,
    then the next column to the right.

    Raises:
        AssertionError: if ``h * w`` differs from ``imgs.shape[0]``.
    """
    h, w = figure
    assert h * w == imgs.shape[0], "figure grid doesn't match imgs amount"
    # squeeze=False keeps ``axs`` 2-D even when h or w is 1; with the default
    # squeeze=True subplots returns a 1-D array (or a single Axes) and
    # ``axs[j, i]`` would fail.
    _, axs = plt.subplots(w, h, squeeze=False)
    for i in range(h):
        for j in range(w):
            ax = axs[j, i]
            # Same order as a running counter incremented in the inner loop.
            ax.imshow(imgs[i * w + j])
            ax.axis('off')
def denoise_image(noised_image, predicted_noise, t, betas, alphas, alpha_bar, add_noise=None):
    """Perform one DDPM reverse step ``x_t -> x_{t-1}``.

    Computes
    ``mean = (x_t - eps * (1 - alpha_t) / sqrt(1 - alpha_bar_t)) / sqrt(alpha_t)``
    and returns ``mean + sqrt(beta_t) * z`` with ``z ~ N(0, I)``, or just
    ``mean`` when no noise is added.

    Args:
        noised_image: current sample ``x_t``.
        predicted_noise: the model's noise estimate ``eps`` for ``x_t``
            (broadcastable against ``noised_image``).
        t: timestep used to index ``betas``, ``alphas`` and ``alpha_bar``.
            Assumed to be a scalar (Python int or 0-d tensor); a per-sample
            tensor would not broadcast against an image batch.
        betas: noise-variance schedule tensor.
        alphas: ``1 - betas`` schedule tensor.
        alpha_bar: cumulative product of ``alphas``.
        add_noise: whether to add the stochastic term. ``None`` (default)
            follows DDPM Algorithm 2 with 1-indexed timesteps: noise is
            added only when ``t > 1``, so the final step producing ``x_0``
            is deterministic. Pass ``True``/``False`` to override.

    Returns:
        The sample ``x_{t-1}``, same shape as ``noised_image``.
    """
    alpha_t = alphas[t]
    mean = (noised_image - predicted_noise * ((1 - alpha_t) / (1 - alpha_bar[t]).sqrt())) / alpha_t.sqrt()
    if add_noise is None:
        add_noise = t > 1
    if not add_noise:
        # Final step: adding sigma_1 * z here would leave residual noise in x_0.
        return mean
    z = torch.randn_like(noised_image)
    return mean + betas[t].sqrt() * z
class DDPM(nn.Module):
    """Forward (noising) process of a denoising diffusion model.

    Given a noise schedule ``betas``, ``forward`` draws
    ``x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise`` with
    ``noise ~ N(0, I)``, where ``alpha_bar`` is the cumulative product of
    ``alphas = 1 - betas``.
    """

    def __init__(self, betas):
        """Store the schedule and derive ``alphas`` and ``alpha_bars``.

        Args:
            betas: tensor of per-step noise variances; assumed 1-D (one entry
                per timestep) since the cumulative product runs over dim 0.
        """
        super().__init__()
        self.betas = betas
        self.alphas = 1.0 - betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def forward(self, x, t):
        """Noise ``x`` to timestep ``t``.

        Args:
            x: clean batch; the first dimension is the batch dimension.
            t: timestep index (int, or an index tensor with one entry per
                sample) into ``alpha_bars``.

        Returns:
            ``(noised_image, noise)`` — both the same shape as ``x``.
        """
        device = x.device
        # Move the schedule to x's device *before* indexing: indexing a CPU
        # tensor with CUDA indices raises, while the reverse is allowed.
        alpha_bar_t = self.alpha_bars.to(device)[t]
        # One coefficient per sample, broadcast over all remaining dims of x
        # (identical to view(-1, 1, 1, 1) for 4-D image batches).
        alpha_bar_t = alpha_bar_t.view(-1, *([1] * (x.dim() - 1)))
        noise = torch.randn_like(x)
        noised_image = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise
        return noised_image, noise
def generate_img(model, sampler, betas, alpha, alpha_bar, batch_size, sampling_count,
                 context=None, device=None, image_shape=(3, 16, 16)):
    """Generate a batch of images by running the reverse diffusion loop.

    Args:
        model: noise-prediction network, called as
            ``model(x_t, t_batch, context)``.
        sampler: forward-noising callable (e.g. a ``DDPM`` instance) returning
            ``(noised_image, noise)``; only the noised image is used.
        betas: noise schedule, forwarded to ``denoise_image``.
        alpha: ``1 - betas`` schedule, forwarded to ``denoise_image``.
        alpha_bar: cumulative-product schedule, forwarded to ``denoise_image``.
        batch_size: number of images to generate.
        sampling_count: starting timestep; the reverse loop runs
            ``t = sampling_count, ..., 1``.
        context: optional per-sample context values; defaults to all zeros.
        device: device to run on; defaults to CPU.
        image_shape: per-image shape appended to the batch dimension.
            Defaults to ``(3, 16, 16)``.

    Returns:
        The final sample tensor of shape ``(batch_size, *image_shape)``.
    """
    if device is None:
        device = torch.device("cpu")
    model.eval()
    if context is None:
        context = [0] * batch_size
    context = torch.tensor(context, dtype=torch.int).to(device)
    with torch.no_grad():
        # Noise the starting image to the same timestep the reverse loop
        # begins at. This was previously hard-coded to 200, which mismatched
        # the loop whenever sampling_count != 200.
        # NOTE(review): torch.rand is uniform [0, 1), not N(0, I); the
        # forward noising hides this only if alpha_bar at this step is ~0.
        start_t = torch.full((batch_size,), sampling_count, dtype=torch.int)
        noised_img = sampler(torch.rand((batch_size, *image_shape)).to(device), start_t)[0]
        for t in range(sampling_count, 0, -1):
            # Per-sample timestep column of shape (B, 1), passed unnormalized
            # as float — presumably what the model was trained with; confirm.
            t_batch = torch.full((noised_img.shape[0], 1), t, dtype=torch.float32).to(device)
            noise = model(noised_img, t_batch, context)
            noised_img = denoise_image(noised_img, noise, t, betas, alpha, alpha_bar)
    return noised_img