import matplotlib.pyplot as plt
import torch
import torch.nn as nn


def plot_images(figure, imgs):
    """Show ``imgs`` on a grid of matplotlib subplots.

    Args:
        figure: ``(h, w)`` pair. The grid is created with ``w`` rows and
            ``h`` columns, and images are placed column by column.
        imgs: indexable batch whose ``shape[0]`` equals ``h * w``; each
            ``imgs[k]`` is handed directly to ``Axes.imshow``.
            NOTE(review): imshow expects (H, W) or (H, W, C) data; a
            channel-first batch (such as ``generate_img``'s output) presumably
            needs permuting before it is passed here — confirm with callers.

    Raises:
        ValueError: if ``h * w`` does not match the number of images.
    """
    h, w = figure
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if h * w != imgs.shape[0]:
        raise ValueError("figure grid doesn't match imgs amount")
    # squeeze=False keeps ``axs`` 2-D even for 1xN, Nx1 and 1x1 grids; without
    # it plt.subplots returns a 1-D array / bare Axes and ``axs[j, i]`` fails.
    _, axs = plt.subplots(w, h, squeeze=False)
    img_index = 0
    for i in range(h):
        for j in range(w):
            axs[j, i].imshow(imgs[img_index])
            axs[j, i].axis("off")
            img_index += 1


def denoise_image(noised_image, predicted_noise, t, betas, alphas, alpha_bar, z=None):
    """Apply one DDPM reverse (denoising) step at timestep ``t``.

    Computes
    ``mean = (x_t - eps * (1 - alpha_t) / sqrt(1 - alpha_bar_t)) / sqrt(alpha_t)``
    and returns ``mean + sqrt(beta_t) * z``.

    Args:
        noised_image: current sample ``x_t``.
        predicted_noise: the model's noise estimate ``eps`` for ``x_t``.
        t: integer timestep used to index the schedules below.
        betas, alphas, alpha_bar: 1-D noise-schedule tensors.
        z: noise to inject. ``None`` (default) draws fresh standard-normal
            noise shaped like ``noised_image``. Pass zeros to return the
            deterministic mean, as DDPM sampling does on its final step.

    Returns:
        The sample for the previous timestep, shaped like ``noised_image``.
    """
    if z is None:
        z = torch.randn_like(noised_image)
    noise = betas[t].sqrt() * z
    mean = (
        noised_image
        - predicted_noise * ((1 - alphas[t]) / (1 - alpha_bar[t]).sqrt())
    ) / alphas[t].sqrt()
    return mean + noise


class DDPM(nn.Module):
    """Forward (noising) process of a DDPM for a fixed ``betas`` schedule.

    Attributes:
        betas: the schedule passed to the constructor.
        alphas: ``1 - betas``.
        alpha_bars: cumulative product of ``alphas`` along dim 0.
    """

    def __init__(self, betas):
        super().__init__()
        self.betas = betas
        self.alphas = 1.0 - betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def forward(self, x, t):
        """Noise a clean batch ``x`` to timesteps ``t`` in a single jump.

        Args:
            x: batch tensor; assumed (batch, C, H, W), since ``alpha_bar_t``
                is reshaped to ``(-1, 1, 1, 1)`` for broadcasting.
            t: timestep index per sample (tensor of any int dtype/device, or
                a python int shared by the whole batch).

        Returns:
            ``(noised_image, noise)``, where
            ``noised_image = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise``
            and ``noise`` is standard normal, shaped like ``x``.
        """
        device = x.device
        # Index on the schedule's own device with a long tensor, so that a GPU
        # ``t`` (or int32 / python int) works with a CPU ``alpha_bars`` table.
        t_idx = torch.as_tensor(t, device=self.alpha_bars.device).long()
        alpha_bar_t = self.alpha_bars[t_idx].view(-1, 1, 1, 1).to(device)
        noise = torch.randn_like(x)
        noised_image = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise
        return noised_image, noise


def generate_img(model, sampler, betas, alpha, alpha_bar, batch_size, sampling_count,
                 context=None, device=None, img_shape=(3, 16, 16), start_t=200):
    """Sample a batch of images by running the DDPM reverse process.

    Args:
        model: noise predictor called as ``model(x_t, t, context)``, where
            ``t`` is a float tensor of shape ``(batch, 1)``. It is switched to
            eval mode and left there.
        sampler: forward process (e.g. ``DDPM``) used to build the start sample.
        betas, alpha, alpha_bar: 1-D noise schedules indexed by timestep.
        batch_size: number of images to generate.
        sampling_count: first reverse timestep; steps run
            ``sampling_count, ..., 1``.
        context: per-sample integer labels. Defaults to all zeros.
        device: target device. Defaults to CPU.
        img_shape: per-image shape ``(C, H, W)``. Defaults to the previously
            hard-coded ``(3, 16, 16)``.
        start_t: timestep at which ``sampler`` noises the start sample.
            Defaults to the previously hard-coded 200.

    Returns:
        Tensor of shape ``(batch_size, *img_shape)`` on ``device``.
    """
    if device is None:
        device = torch.device("cpu")
    model.eval()
    if context is None:
        context = [0] * batch_size
    context = torch.as_tensor(context, dtype=torch.int, device=device)

    with torch.no_grad():
        # NOTE(review): standard DDPM starts from x_T ~ N(0, I). Here uniform
        # noise is pushed through the forward process at ``start_t``, which is
        # also independent of ``sampling_count``. Kept for compatibility;
        # confirm this is intended.
        start = torch.rand((batch_size, *img_shape)).to(device)
        start_steps = torch.full((batch_size,), start_t, dtype=torch.int)
        noised_img = sampler(start, start_steps)[0]

        for t in range(sampling_count, 0, -1):
            _t = torch.full((noised_img.shape[0], 1), float(t),
                            dtype=torch.float32, device=device)
            predicted = model(noised_img, _t, context)
            # DDPM sampling (Ho et al. 2020, Alg. 2) adds no noise on the last
            # step; otherwise the returned images keep a residual noise term.
            z = torch.zeros_like(noised_img) if t == 1 else None
            noised_img = denoise_image(noised_img, predicted, t, betas, alpha, alpha_bar, z=z)
    return noised_img