sprite-generator / helper.py
basil-ahmad's picture
Upload 3 files
53ef34c verified
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
def plot_images(figure, imgs):
    """Draw a batch of images on a freshly created grid of matplotlib axes.

    Nothing is returned, and the figure is neither shown nor saved here.

    Args:
        figure: Pair ``(h, w)``. The grid is created with ``w`` rows and
            ``h`` columns, so ``h`` is the number of images per row.
        imgs: Indexable batch whose ``shape[0]`` equals ``h * w``; every
            ``imgs[k]`` is handed to ``Axes.imshow`` as-is.

    Raises:
        AssertionError: If ``h * w`` differs from the number of images.
    """
    h, w = figure
    assert h * w == imgs.shape[0], "figure grid doesn't match imgs amount"

    # squeeze=False keeps ``axs`` two-dimensional even when the grid has a
    # single row or column; the default collapses it to 1-D (or to a bare
    # Axes for a 1x1 grid), which breaks the [row, col] indexing below.
    _, axs = plt.subplots(w, h, squeeze=False)

    # Images fill the grid column by column: the first ``w`` images go down
    # the first column, the next ``w`` down the second, and so on.
    for img_index in range(h * w):
        col, row = divmod(img_index, w)
        ax = axs[row, col]
        ax.imshow(imgs[img_index])
        ax.axis('off')
def denoise_image(noised_image, predicted_noise, t, betas, alphas, alpha_bar, z=None):
    """Perform one reverse-diffusion step on ``noised_image``.

    The predicted noise is scaled and subtracted to form the step mean, then
    ``sqrt(betas[t]) * z`` is added back on top of it.

    Args:
        noised_image: Current noisy sample tensor.
        predicted_noise: Noise estimate for ``noised_image``; must broadcast
            against it.
        t: Index of the current step into the three schedule tensors.
        betas: Schedule tensor; ``betas[t]`` sets the scale of the added noise.
        alphas: Schedule tensor indexed by ``t`` (presumably ``1 - betas`` --
            confirm against the caller).
        alpha_bar: Schedule tensor indexed by ``t`` (presumably the cumulative
            product of ``alphas`` -- confirm against the caller).
        z: Optional noise to add, as a tensor broadcastable to
            ``noised_image`` or a scalar. When ``None`` (the default) fresh
            standard-normal noise shaped like ``noised_image`` is drawn.
            Pass zeros (or ``0``) for a noise-free, deterministic step.

    Returns:
        Tensor equal to the step mean plus ``sqrt(betas[t]) * z``.
    """
    if z is None:
        z = torch.randn_like(noised_image)

    # Index first, then take the root: only one schedule entry is needed, so
    # there is no reason to take the square root of the entire schedule.
    noise = betas[t].sqrt() * z

    noise_scale = (1 - alphas[t]) / (1 - alpha_bar[t]).sqrt()
    mean = (noised_image - predicted_noise * noise_scale) / alphas[t].sqrt()
    return mean + noise
class DDPM(nn.Module):
    """Forward (noising) process of a diffusion model for a fixed beta schedule.

    The schedule tensors are stored as plain attributes, not registered
    buffers, so moving the module with ``.to(...)`` does not move them;
    ``forward`` transfers the values it needs to the input's device instead.
    """

    def __init__(self, betas):
        """Derive the alpha schedules from ``betas``.

        Args:
            betas: Tensor holding the beta value of every timestep.
        """
        super().__init__()
        self.betas = betas
        self.alphas = 1.0 - betas
        # Running product of the alphas along the timestep axis.
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def forward(self, x, t):
        """Noise ``x`` to timestep(s) ``t``.

        Args:
            x: Batch tensor to corrupt; the first dimension is the batch.
            t: Timestep index or indices into the schedule -- an int, or a
                tensor with one entry per batch element.

        Returns:
            Tuple ``(noised_image, noise)`` where ``noise`` is the
            standard-normal sample that was mixed into ``x`` and
            ``noised_image = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise``.
        """
        if isinstance(t, torch.Tensor):
            # An index tensor has to live on the CPU or on the indexed
            # tensor's device; a ``t`` that followed ``x`` to an accelerator
            # would otherwise fail against a CPU-resident schedule.
            t = t.to(self.alpha_bars.device)

        # Get corresponding alpha_bar_t, shaped (N, 1, ..., 1) so that it
        # broadcasts over every non-batch dimension of ``x``.
        alpha_bar_t = self.alpha_bars[t].to(x.device)
        alpha_bar_t = alpha_bar_t.view(-1, *([1] * (x.dim() - 1)))

        # Sample noise
        noise = torch.randn_like(x)

        # Compute the noised image
        noised_image = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise
        return noised_image, noise
def generate_img(model, sampler, betas, alpha, alpha_bar, batch_size, sampling_count,
                 context=None, device=None, img_shape=(3, 16, 16), init_timestep=200):
    """Generate a batch of images by iterated denoising.

    A uniform-random batch is first noised by ``sampler``; the model then
    predicts the noise at every step from ``sampling_count`` down to 1 and
    ``denoise_image`` applies the corresponding reverse step.

    Args:
        model: Callable ``model(images, timesteps, context)`` returning the
            predicted noise. It is switched to eval mode and left there.
        sampler: Callable ``sampler(images, timesteps)`` whose first return
            value is the noised batch used as the starting point.
        betas: Schedule tensor forwarded to ``denoise_image``; it is indexed
            at every step from ``sampling_count`` down to 1.
        alpha: Schedule tensor forwarded to ``denoise_image``.
        alpha_bar: Schedule tensor forwarded to ``denoise_image``.
        batch_size: Number of images to generate.
        sampling_count: Number of denoising steps (and the first step index).
        context: Optional per-image integer conditioning values; defaults to
            all zeros.
        device: Device for the images, timesteps and context; defaults to CPU.
        img_shape: Per-image shape without the batch dimension.
        init_timestep: Timestep handed to ``sampler`` for the initial noising.

    Returns:
        The batch tensor after the final denoising step.
    """
    if device is None:
        device = torch.device("cpu")

    model.eval()

    if context is None:
        context = [0] * batch_size
    # as_tensor accepts both sequences and existing tensors without the
    # copy-construct warning that torch.tensor(tensor) emits.
    context = torch.as_tensor(context, dtype=torch.int).to(device)

    with torch.no_grad():
        # Drawn on the CPU and moved afterwards so that results for a given
        # seed do not depend on the target device's random generator.
        start_images = torch.rand((batch_size, *img_shape)).to(device)
        # NOTE(review): the starting timestep is independent of
        # ``sampling_count`` -- confirm this matches how the model is used.
        start_steps = torch.full((batch_size,), init_timestep, dtype=torch.int)
        noised_img = sampler(start_images, start_steps)[0]

        for t in range(sampling_count, 0, -1):
            # The model receives the raw step index as a float column (N, 1).
            _t = torch.full((noised_img.shape[0], 1), float(t),
                            dtype=torch.float32, device=device)
            noise = model(noised_img, _t, context)
            # NOTE(review): noise is added on the last step (t == 1) as well;
            # DDPM sampling usually skips it there -- confirm intent.
            noised_img = denoise_image(noised_img, noise, t, betas, alpha, alpha_bar)

    return noised_img