Spaces:
Sleeping
Sleeping
basil-ahmad
commited on
Commit
•
53ef34c
1
Parent(s):
6c82165
Upload 3 files
Browse files
app.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from helper import generate_img, DDPM
|
3 |
+
import torch
|
4 |
+
from cv2 import resize
|
5 |
+
|
6 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
7 |
+
timesteps = 500
|
8 |
+
beta1 = 1e-4
|
9 |
+
beta2 = 0.02
|
10 |
+
betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
|
11 |
+
betas = betas.to(device)
|
12 |
+
alpha = 1.0 - betas
|
13 |
+
alpha_bar = torch.cumprod(alpha, dim=0).to(device)
|
14 |
+
model = torch.load("model.pt", map_location=device)
|
15 |
+
sampler = DDPM(betas)
|
16 |
+
|
17 |
+
label_to_index = {l:i for i, l in enumerate(['hero', 'non-hero -not recommended-', 'food', 'spells & weapons', 'side-facing'])}
|
18 |
+
|
19 |
+
sampling_count = 300
|
20 |
+
batch_size = 1
|
21 |
+
context = st.radio('Pick one:',
|
22 |
+
label_to_index.keys()
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
if st.button("click"):
|
27 |
+
index = [label_to_index[context]]
|
28 |
+
img = generate_img(model, sampler,betas, alpha, alpha_bar, batch_size, sampling_count, context=index)
|
29 |
+
img = img.cpu().detach().permute(0, 2, 3, 1).numpy()[0]
|
30 |
+
img = resize(img, (320,320), interpolation=0)
|
31 |
+
st.write(context)
|
32 |
+
st.image(img, clamp=True)
|
helper.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
def plot_images(figure, imgs):
|
6 |
+
h, w = figure
|
7 |
+
|
8 |
+
assert(h*w == imgs.shape[0]), "figure grid doesn't match imgs amount"
|
9 |
+
|
10 |
+
_, axs = plt.subplots(w, h)
|
11 |
+
|
12 |
+
img_index = 0
|
13 |
+
for i in range(h):
|
14 |
+
for j in range(w):
|
15 |
+
axs[j, i].imshow(imgs[img_index])
|
16 |
+
axs[j, i].axis('off')
|
17 |
+
img_index = img_index + 1
|
18 |
+
|
19 |
+
def denoise_image(noised_image, predicted_noise, t, betas, alphas, alpha_bar):
|
20 |
+
z = torch.randn_like(noised_image)
|
21 |
+
noise = betas.sqrt()[t] * z
|
22 |
+
mean = (noised_image - predicted_noise * ((1 - alphas[t]) / (1 - alpha_bar[t]).sqrt())) / alphas[t].sqrt()
|
23 |
+
return mean + noise
|
24 |
+
|
25 |
+
class DDPM(nn.Module):
|
26 |
+
def __init__(self, betas):
|
27 |
+
super(DDPM, self).__init__()
|
28 |
+
self.betas = betas
|
29 |
+
self.alphas = 1.0 - betas
|
30 |
+
self.alpha_bars = torch.cumprod(self.alphas, dim=0)
|
31 |
+
|
32 |
+
def forward(self, x, t):
|
33 |
+
batch_size = x.shape[0]
|
34 |
+
device = x.device
|
35 |
+
|
36 |
+
# Get corresponding alpha_bar_t
|
37 |
+
alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1).to(device)
|
38 |
+
|
39 |
+
# Sample noise
|
40 |
+
noise = torch.randn_like(x)
|
41 |
+
|
42 |
+
# Compute the noised image
|
43 |
+
noised_image = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise
|
44 |
+
|
45 |
+
return noised_image, noise
|
46 |
+
|
47 |
+
def generate_img(model, sampler,betas, alpha, alpha_bar,batch_size, sampling_count, context=None, device=None):
|
48 |
+
if device is None:
|
49 |
+
device = torch.device("cpu")
|
50 |
+
model.eval()
|
51 |
+
if context is None:
|
52 |
+
context = [0 for _ in range(batch_size)]
|
53 |
+
context = torch.tensor(context, dtype=torch.int).to(device)
|
54 |
+
|
55 |
+
with torch.no_grad():
|
56 |
+
noised_img = sampler(torch.rand((batch_size, 3, 16, 16)).to(device),
|
57 |
+
torch.ones(batch_size, dtype=torch.int) * 200)[0]
|
58 |
+
|
59 |
+
for t in range(sampling_count, 0, -1):
|
60 |
+
_t = torch.tensor([[t for _ in range(noised_img.shape[0])]], dtype=torch.float32).to(device).T
|
61 |
+
noise = model(noised_img, _t, context)
|
62 |
+
noised_img = denoise_image(noised_img, noise, t, betas, alpha, alpha_bar)
|
63 |
+
return noised_img
|
model.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class ContextEmbedding(nn.Module):
|
5 |
+
def __init__(self, vocab_size, embedding_dim):
|
6 |
+
super(ContextEmbedding, self).__init__()
|
7 |
+
|
8 |
+
self.embedd_fc1 = nn.Embedding(vocab_size, embedding_dim)
|
9 |
+
self.fc1 = nn.Linear(embedding_dim, embedding_dim)
|
10 |
+
|
11 |
+
def forward(self, x):
|
12 |
+
x = self.embedd_fc1(x)
|
13 |
+
x = self.fc1(x)
|
14 |
+
return x
|
15 |
+
|
16 |
+
class TimeEbedding(nn.Module):
|
17 |
+
def __init__(self, time_dim, hidden_dim, out_dim):
|
18 |
+
super(TimeEbedding, self).__init__()
|
19 |
+
self.fc1 = nn.Linear(time_dim, hidden_dim)
|
20 |
+
self.gelu = nn.GELU()
|
21 |
+
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
x = self.fc1(x)
|
25 |
+
x = self.gelu(x)
|
26 |
+
x = self.fc2(x)
|
27 |
+
|
28 |
+
return x
|
29 |
+
|
30 |
+
class ConvBlock(nn.Module):
|
31 |
+
def __init__(self, in_channels, out_channels, vocab_size, embedding_dim , residual=False):
|
32 |
+
super(ConvBlock, self).__init__()
|
33 |
+
|
34 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, 2, padding="same")
|
35 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, 2, padding="same")
|
36 |
+
self.downsample = nn.Conv2d(in_channels, out_channels, 3, padding="same", bias=False)
|
37 |
+
self.relu = nn.ReLU()
|
38 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
39 |
+
self.te = TimeEbedding(1, out_channels, out_channels)
|
40 |
+
self.context_embedding = ContextEmbedding(vocab_size, embedding_dim)
|
41 |
+
self.__residual = residual
|
42 |
+
|
43 |
+
def forward(self, x, t, context):
|
44 |
+
x1 = self.downsample(x)
|
45 |
+
x = self.conv1(x)
|
46 |
+
x = self.relu(x)
|
47 |
+
x = self.conv2(x)
|
48 |
+
x = self.relu(x)
|
49 |
+
|
50 |
+
time_embedding = self.te(t)[:, :, None, None]
|
51 |
+
context = self.context_embedding(context)[:, :, None, None]
|
52 |
+
|
53 |
+
if self.__residual:
|
54 |
+
return self.bn(x + x1) * context + time_embedding
|
55 |
+
return x * context + time_embedding
|
56 |
+
|
57 |
+
class UNET(nn.Module):
|
58 |
+
def __init__(self, in_channels, out_channels, residual=False):
|
59 |
+
super(UNET, self).__init__()
|
60 |
+
|
61 |
+
self.convBD1 = ConvBlock(in_channels, 64, 5, 64,residual)
|
62 |
+
self.convBD2 = ConvBlock(64, 128, 5, 128,residual)
|
63 |
+
self.convBD3 = ConvBlock(128, 256, 5, 256,residual)
|
64 |
+
# self.convBD4 = ConvBlock(256, 512, 5, 512,residual)
|
65 |
+
|
66 |
+
self.bottle_neck = ConvBlock(256, 512, 5, 512,residual)
|
67 |
+
# self.bottle_neck = ConvBlock(512, 1024, 5, 1024,residual)
|
68 |
+
|
69 |
+
# self.convBU1 = ConvBlock(1024, 512, 5, 512,residual)
|
70 |
+
self.convBU2 = ConvBlock(512, 256, 5, 256,residual)
|
71 |
+
self.convBU3 = ConvBlock(256, 128, 5, 128,residual)
|
72 |
+
self.convBU4 = ConvBlock(128, 64, 5, 64,residual)
|
73 |
+
|
74 |
+
self.convT1 = nn.ConvTranspose2d(1024, 512, 2, 2)
|
75 |
+
self.convT2 = nn.ConvTranspose2d(512, 256, 2, 2)
|
76 |
+
self.convT3 = nn.ConvTranspose2d(256, 128, 2, 2)
|
77 |
+
self.convT4 = nn.ConvTranspose2d(128, 64, 2, 2)
|
78 |
+
|
79 |
+
self.final = nn.Conv2d(64, out_channels, 1)
|
80 |
+
|
81 |
+
self.maxpool = nn.MaxPool2d(2)
|
82 |
+
def forward(self, x, t, context):
|
83 |
+
if x.ndim == 3:
|
84 |
+
x = x.unsqueeze(0)
|
85 |
+
|
86 |
+
x1 = self.convBD1(x, t, context)
|
87 |
+
x = self.maxpool(x1)
|
88 |
+
|
89 |
+
x2 = self.convBD2(x, t, context)
|
90 |
+
x = self.maxpool(x2)
|
91 |
+
|
92 |
+
x3 = self.convBD3(x, t, context)
|
93 |
+
x = self.maxpool(x3)
|
94 |
+
|
95 |
+
# x4 = self.convBD4(x, t, context)
|
96 |
+
# x = self.maxpool(x4)
|
97 |
+
|
98 |
+
x = self.bottle_neck(x, t, context)
|
99 |
+
|
100 |
+
# x = self.convT1(x)
|
101 |
+
# x = torch.cat([x4, x], dim=1)
|
102 |
+
# x = self.convBU1(x, t, context)
|
103 |
+
|
104 |
+
x = self.convT2(x)
|
105 |
+
x = torch.cat([x3, x], dim=1)
|
106 |
+
x = self.convBU2(x, t, context)
|
107 |
+
|
108 |
+
x = self.convT3(x)
|
109 |
+
x = torch.cat([x2, x], dim=1)
|
110 |
+
x = self.convBU3(x, t, context)
|
111 |
+
|
112 |
+
x = self.convT4(x)
|
113 |
+
x = torch.cat([x1, x], dim=1)
|
114 |
+
x = self.convBU4(x, t, context)
|
115 |
+
|
116 |
+
x = self.final(x)
|
117 |
+
|
118 |
+
return x
|
119 |
+
|
120 |
+
import torch
|
121 |
+
import torch.nn as nn
|
122 |
+
from model import UNET
|
123 |
+
from helper import DDPM, denoise_image
|
124 |
+
|
125 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
126 |
+
timesteps = 500
|
127 |
+
beta1 = 1e-4
|
128 |
+
beta2 = 0.02
|
129 |
+
betas = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1) + beta1
|
130 |
+
betas = betas.to(device)
|
131 |
+
alpha = 1.0 - betas
|
132 |
+
alpha_bar = torch.cumprod(alpha, dim=0).to(device)
|
133 |
+
|
134 |
+
model = UNET(3, 3,True)
|
135 |
+
model = model.to(device)
|
136 |
+
loss_fn = nn.MSELoss()
|
137 |
+
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
|
138 |
+
sampler = DDPM(betas)
|