basil-ahmad committed on
Commit 53ef34c
1 Parent(s): 6c82165

Upload 3 files

Files changed (3)
  1. app.py +32 -0
  2. helper.py +63 -0
  3. model.py +138 -0
app.py ADDED
@@ -0,0 +1,32 @@
+ import streamlit as st
+ from helper import generate_img, DDPM
+ import torch
+ from cv2 import resize, INTER_NEAREST
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Linear noise schedule: betas ramp from beta1 to beta2 across the diffusion steps
+ timesteps = 500
+ beta1 = 1e-4
+ beta2 = 0.02
+ betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
+ betas = betas.to(device)
+ alpha = 1.0 - betas
+ alpha_bar = torch.cumprod(alpha, dim=0).to(device)
+
+ # Load the trained model and build the forward-noising sampler
+ model = torch.load("model.pt", map_location=device)
+ sampler = DDPM(betas)
+
+ label_to_index = {l: i for i, l in enumerate(['hero', 'non-hero -not recommended-', 'food', 'spells & weapons', 'side-facing'])}
+
+ sampling_count = 300
+ batch_size = 1
+ context = st.radio('Pick one:', list(label_to_index.keys()))
+
+ if st.button("Generate"):
+     index = [label_to_index[context]]
+     # Run the reverse-diffusion loop on the same device as the model and the schedule
+     img = generate_img(model, sampler, betas, alpha, alpha_bar, batch_size, sampling_count, context=index, device=device)
+     img = img.cpu().detach().permute(0, 2, 3, 1).numpy()[0]
+     img = resize(img, (320, 320), interpolation=INTER_NEAREST)  # upscale the 16x16 sample
+     st.write(context)
+     st.image(img, clamp=True)
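+
+ # Launch locally with the standard Streamlit CLI: `streamlit run app.py`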
helper.py ADDED
@@ -0,0 +1,63 @@
+ import matplotlib.pyplot as plt
+ import torch
+ import torch.nn as nn
+
+ def plot_images(figure, imgs):
+     # Show a batch of images on an h-by-w grid of subplots
+     h, w = figure
+
+     assert h * w == imgs.shape[0], "figure grid doesn't match number of images"
+
+     # squeeze=False keeps axs 2-D even when h or w is 1
+     _, axs = plt.subplots(w, h, squeeze=False)
+
+     img_index = 0
+     for i in range(h):
+         for j in range(w):
+             axs[j, i].imshow(imgs[img_index])
+             axs[j, i].axis('off')
+             img_index += 1
+
+ def denoise_image(noised_image, predicted_noise, t, betas, alphas, alpha_bar):
+     # One reverse step: the DDPM posterior mean of x_{t-1} given x_t and the
+     # predicted noise, plus fresh Gaussian noise scaled by sqrt(beta_t)
+     z = torch.randn_like(noised_image)
+     noise = betas.sqrt()[t] * z
+     mean = (noised_image - predicted_noise * ((1 - alphas[t]) / (1 - alpha_bar[t]).sqrt())) / alphas[t].sqrt()
+     return mean + noise
+
+ class DDPM(nn.Module):
+     # Forward (noising) process: sample x_t ~ q(x_t | x_0) in closed form
+     def __init__(self, betas):
+         super(DDPM, self).__init__()
+         self.betas = betas
+         self.alphas = 1.0 - betas
+         self.alpha_bars = torch.cumprod(self.alphas, dim=0)
+
+     def forward(self, x, t):
+         device = x.device
+
+         # Get the corresponding alpha_bar_t, broadcastable over (B, C, H, W)
+         alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1).to(device)
+
+         # Sample noise
+         noise = torch.randn_like(x)
+
+         # Compute the noised image: sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
+         noised_image = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise
+
+         return noised_image, noise
+
+ def generate_img(model, sampler, betas, alpha, alpha_bar, batch_size, sampling_count, context=None, device=None):
+     if device is None:
+         device = torch.device("cpu")
+     model.eval()
+     if context is None:
+         context = [0 for _ in range(batch_size)]
+     context = torch.tensor(context, dtype=torch.int).to(device)
+
+     with torch.no_grad():
+         # Start from a uniform-random image pushed through the forward process at t=200
+         noised_img = sampler(torch.rand((batch_size, 3, 16, 16)).to(device),
+                              torch.ones(batch_size, dtype=torch.int) * 200)[0]
+
+         # Walk the reverse chain, removing the predicted noise one step at a time
+         for t in range(sampling_count, 0, -1):
+             _t = torch.tensor([[t for _ in range(noised_img.shape[0])]], dtype=torch.float32).to(device).T
+             noise = model(noised_img, _t, context)
+             noised_img = denoise_image(noised_img, noise, t, betas, alpha, alpha_bar)
+     return noised_img
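+
+ # Usage sketch (assumes a trained `model` and the schedule tensors built in app.py):
+ #   imgs = generate_img(model, DDPM(betas), betas, alpha, alpha_bar, 4, 300)
+ #   plot_images((2, 2), imgs.cpu().permute(0, 2, 3, 1).numpy())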
model.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+ import torch.nn as nn
+
+ class ContextEmbedding(nn.Module):
+     # Embed a class label and project it to the conditioning dimension
+     def __init__(self, vocab_size, embedding_dim):
+         super(ContextEmbedding, self).__init__()
+
+         self.embedd_fc1 = nn.Embedding(vocab_size, embedding_dim)
+         self.fc1 = nn.Linear(embedding_dim, embedding_dim)
+
+     def forward(self, x):
+         x = self.embedd_fc1(x)
+         x = self.fc1(x)
+         return x
+
+ # Note: the class name is misspelled ("Ebedding"); renaming it would break
+ # loading the pickled checkpoint in model.pt, so it is left as-is
+ class TimeEbedding(nn.Module):
+     # Two-layer MLP mapping the scalar timestep to a channel-wise embedding
+     def __init__(self, time_dim, hidden_dim, out_dim):
+         super(TimeEbedding, self).__init__()
+         self.fc1 = nn.Linear(time_dim, hidden_dim)
+         self.gelu = nn.GELU()
+         self.fc2 = nn.Linear(hidden_dim, out_dim)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.gelu(x)
+         x = self.fc2(x)
+         return x
+
+ class ConvBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, vocab_size, embedding_dim, residual=False):
+         super(ConvBlock, self).__init__()
+
+         self.conv1 = nn.Conv2d(in_channels, out_channels, 2, padding="same")
+         self.conv2 = nn.Conv2d(out_channels, out_channels, 2, padding="same")
+         # Shortcut conv that matches the channel count for the residual branch
+         self.downsample = nn.Conv2d(in_channels, out_channels, 3, padding="same", bias=False)
+         self.relu = nn.ReLU()
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.te = TimeEbedding(1, out_channels, out_channels)
+         self.context_embedding = ContextEmbedding(vocab_size, embedding_dim)
+         self.__residual = residual
+
+     def forward(self, x, t, context):
+         x1 = self.downsample(x)
+         x = self.conv1(x)
+         x = self.relu(x)
+         x = self.conv2(x)
+         x = self.relu(x)
+
+         # Broadcast the time and context embeddings over the spatial dimensions
+         time_embedding = self.te(t)[:, :, None, None]
+         context = self.context_embedding(context)[:, :, None, None]
+
+         # Conditioning: scale by the context embedding, shift by the time embedding
+         if self.__residual:
+             return self.bn(x + x1) * context + time_embedding
+         return x * context + time_embedding
+
+ class UNET(nn.Module):
+     def __init__(self, in_channels, out_channels, residual=False):
+         super(UNET, self).__init__()
+
+         # Encoder (the commented-out lines are a deeper 4-level variant)
+         self.convBD1 = ConvBlock(in_channels, 64, 5, 64, residual)
+         self.convBD2 = ConvBlock(64, 128, 5, 128, residual)
+         self.convBD3 = ConvBlock(128, 256, 5, 256, residual)
+         # self.convBD4 = ConvBlock(256, 512, 5, 512, residual)
+
+         self.bottle_neck = ConvBlock(256, 512, 5, 512, residual)
+         # self.bottle_neck = ConvBlock(512, 1024, 5, 1024, residual)
+
+         # Decoder
+         # self.convBU1 = ConvBlock(1024, 512, 5, 512, residual)
+         self.convBU2 = ConvBlock(512, 256, 5, 256, residual)
+         self.convBU3 = ConvBlock(256, 128, 5, 128, residual)
+         self.convBU4 = ConvBlock(128, 64, 5, 64, residual)
+
+         self.convT1 = nn.ConvTranspose2d(1024, 512, 2, 2)  # only used by the 4-level variant
+         self.convT2 = nn.ConvTranspose2d(512, 256, 2, 2)
+         self.convT3 = nn.ConvTranspose2d(256, 128, 2, 2)
+         self.convT4 = nn.ConvTranspose2d(128, 64, 2, 2)
+
+         self.final = nn.Conv2d(64, out_channels, 1)
+
+         self.maxpool = nn.MaxPool2d(2)
+
+     def forward(self, x, t, context):
+         if x.ndim == 3:
+             x = x.unsqueeze(0)
+
+         # Encoder path: keep each resolution for the skip connections
+         x1 = self.convBD1(x, t, context)
+         x = self.maxpool(x1)
+
+         x2 = self.convBD2(x, t, context)
+         x = self.maxpool(x2)
+
+         x3 = self.convBD3(x, t, context)
+         x = self.maxpool(x3)
+
+         # x4 = self.convBD4(x, t, context)
+         # x = self.maxpool(x4)
+
+         x = self.bottle_neck(x, t, context)
+
+         # Decoder path: upsample, concatenate the skip, then convolve
+         # x = self.convT1(x)
+         # x = torch.cat([x4, x], dim=1)
+         # x = self.convBU1(x, t, context)
+
+         x = self.convT2(x)
+         x = torch.cat([x3, x], dim=1)
+         x = self.convBU2(x, t, context)
+
+         x = self.convT3(x)
+         x = torch.cat([x2, x], dim=1)
+         x = self.convBU3(x, t, context)
+
+         x = self.convT4(x)
+         x = torch.cat([x1, x], dim=1)
+         x = self.convBU4(x, t, context)
+
+         x = self.final(x)
+
+         return x
+
+
+ # Training setup, guarded so that importing model.py (e.g. when unpickling
+ # model.pt) does not rebuild a fresh model and optimizer as a side effect
+ if __name__ == "__main__":
+     from helper import DDPM
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     timesteps = 500
+     beta1 = 1e-4
+     beta2 = 0.02
+     betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
+     betas = betas.to(device)
+     alpha = 1.0 - betas
+     alpha_bar = torch.cumprod(alpha, dim=0).to(device)
+
+     model = UNET(3, 3, True)
+     model = model.to(device)
+     loss_fn = nn.MSELoss()
+     optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+     sampler = DDPM(betas)
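+
+     # Minimal sketch of one training step on a dummy batch; a real run would
+     # iterate a DataLoader of 16x16 sprite images with their class labels
+     images = torch.rand(8, 3, 16, 16, device=device)   # dummy images in [0, 1]
+     labels = torch.randint(0, 5, (8,), device=device)  # dummy labels for the 5 classes
+
+     # Pick a random timestep per sample and noise the images to that step
+     t = torch.randint(1, timesteps + 1, (8,), device=device)
+     noised, eps = sampler(images, t)
+
+     # The model regresses the injected noise; train with MSE against eps
+     pred = model(noised, t.float().unsqueeze(1), labels)
+     loss = loss_fn(pred, eps)
+     optim.zero_grad()
+     loss.backward()
+     optim.step()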