File size: 2,183 Bytes
53ef34c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

def plot_images(figure, imgs):
    """Show ``imgs`` on a grid of subplots with every axis hidden.

    Args:
        figure: ``(h, w)`` pair. The subplot grid has ``w`` rows and ``h``
            columns and is filled column by column (image ``i * w + j`` goes
            to row ``j``, column ``i``).
        imgs: sequence of exactly ``h * w`` images; each element is passed
            straight to ``Axes.imshow`` (presumably HxW or HxWxC arrays —
            confirm with callers).

    Raises:
        ValueError: if ``h * w`` differs from the number of images.
    """
    h, w = figure

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if h * w != len(imgs):
        raise ValueError("figure grid doesn't match imgs amount")

    # squeeze=False keeps ``axs`` 2-D even when h or w is 1; with the default
    # squeeze=True subplots returns a 1-D array (or a bare Axes for a 1x1
    # grid) and ``axs[j, i]`` would fail.
    _, axs = plt.subplots(w, h, squeeze=False)

    for i in range(h):
        for j in range(w):
            ax = axs[j, i]
            ax.imshow(imgs[i * w + j])
            ax.axis('off')

def denoise_image(noised_image, predicted_noise, t, betas, alphas, alpha_bar, z=None):
    """Run one reverse (denoising) step of DDPM sampling.

    Computes::

        (x_t - eps * (1 - alpha_t) / sqrt(1 - alpha_bar_t)) / sqrt(alpha_t)
            + sqrt(beta_t) * z

    Args:
        noised_image: current sample ``x_t``.
        predicted_noise: the model's noise estimate ``eps`` for ``x_t``.
        t: index into the schedule tensors for this step.
        betas: per-step noise variances, indexed by ``t``.
        alphas: ``1 - betas``, indexed by ``t``.
        alpha_bar: cumulative products of ``alphas``, indexed by ``t``.
        z: noise to inject. Defaults to a fresh ``torch.randn_like(noised_image)``.
            Pass ``0`` (or a zero tensor) on the final sampling step, where
            DDPM adds no noise, or a fixed tensor for reproducible sampling.

    Returns:
        The sample for the next (less noisy) step.
    """
    if z is None:
        z = torch.randn_like(noised_image)

    alpha_t = alphas[t]
    # Weight of the predicted noise in the posterior mean.
    eps_coef = (1 - alpha_t) / (1 - alpha_bar[t]).sqrt()
    mean = (noised_image - predicted_noise * eps_coef) / alpha_t.sqrt()

    # Index before sqrt: only the one needed element is computed.
    return mean + betas[t].sqrt() * z

class DDPM(nn.Module):
    """Forward (noising) process of a DDPM for a fixed beta schedule.

    ``forward(x, t)`` returns
    ``sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise`` together with
    the Gaussian ``noise`` that was drawn.
    """

    def __init__(self, betas):
        """Store the schedule and derive ``alphas`` and ``alpha_bars``.

        Args:
            betas: tensor of per-step noise variances, indexed by timestep
                (presumably 1-D — the cumulative product runs over dim 0).
        """
        super().__init__()
        self.betas = betas
        self.alphas = 1.0 - betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def forward(self, x, t):
        """Noise ``x`` to timestep(s) ``t``.

        Args:
            x: clean batch; dim 0 is the batch dimension.
            t: timestep index per batch element (tensor, list, or a single int
                shared by the whole batch). May live on any device.

        Returns:
            ``(noised_image, noise)``, both shaped like ``x``.
        """
        device = x.device

        # Put the schedule and the indices on x's device: indexing a CPU tensor
        # with CUDA indices raises, so a GPU ``t`` would otherwise fail.
        alpha_bars = self.alpha_bars.to(device)
        alpha_bar_t = alpha_bars[torch.as_tensor(t, device=device)]

        # One coefficient per batch element, broadcast over all remaining dims
        # (equivalent to view(-1, 1, 1, 1) for 4-D image batches).
        alpha_bar_t = alpha_bar_t.reshape(-1, *([1] * (x.dim() - 1)))

        noise = torch.randn_like(x)
        noised_image = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise

        return noised_image, noise

def generate_img(model, sampler, betas, alpha, alpha_bar, batch_size, sampling_count,
                 context=None, device=None, image_shape=(3, 16, 16), start_t=200):
    """Sample a batch of images by running DDPM reverse steps.

    Puts ``model`` in eval mode (it is not switched back afterwards).

    Args:
        model: noise-prediction network, called as
            ``model(x_t, t_column, context)`` where ``t_column`` is a float32
            ``(batch, 1)`` tensor holding the raw step number.
        sampler: forward-noising callable returning ``(noised_image, noise)``
            (e.g. ``DDPM``); it builds the starting sample.
        betas: schedule tensor indexed by step number ``1 .. sampling_count``.
        alpha: ``1 - betas``, same indexing.
        alpha_bar: cumulative products of ``alpha``, same indexing.
        batch_size: number of images to generate.
        sampling_count: number of reverse steps; steps run from
            ``sampling_count`` down to ``1``.
        context: per-image integer labels; defaults to all zeros.
        device: device for the sampling tensors; defaults to CPU.
        image_shape: ``(channels, height, width)`` of one image.
        start_t: timestep the sampler uses to noise the initial draw.

    Returns:
        The batch produced by the last reverse step.
    """
    if device is None:
        device = torch.device("cpu")
    model.eval()
    if context is None:
        context = [0] * batch_size
    context = torch.tensor(context, dtype=torch.int).to(device)

    # DDPM sampling adds no noise on the final step (z = 0 at t == 1).
    # denoise_image scales its noise by sqrt(beta_t) and computes the mean
    # only from alpha/alpha_bar, so a zero beta schedule removes the noise
    # term and leaves the mean untouched.
    final_step_betas = torch.zeros_like(betas)

    with torch.no_grad():
        # NOTE(review): the start is uniform noise pushed through the forward
        # sampler at ``start_t``; standard DDPM starts from N(0, I) at
        # t = sampling_count — confirm this is intended.
        initial = torch.rand((batch_size, *image_shape)).to(device)
        start_steps = torch.full((batch_size,), start_t, dtype=torch.int)
        noised_img = sampler(initial, start_steps)[0]

        for t in range(sampling_count, 0, -1):
            t_column = torch.full((noised_img.shape[0], 1), t,
                                  dtype=torch.float32, device=device)
            predicted_noise = model(noised_img, t_column, context)
            step_betas = betas if t > 1 else final_step_betas
            noised_img = denoise_image(noised_img, predicted_noise, t,
                                       step_betas, alpha, alpha_bar)
    return noised_img