llvictorll committed: add gradio app

Files changed:
- Models/__init__.py +0 -0
- Models/models/__init__.py +0 -0
- Models/models/transformer.py +117 -0
- Models/models/vqgan.py +294 -0
- Models/modules/diffusionmodules/model.py +436 -0
- Models/modules/util.py +130 -0
- Models/modules/vqvae/quantize.py +335 -0
- Models/util.py +157 -0
- app.py +78 -0
- runner.py +221 -0
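The app.py (+78) and runner.py (+221) diffs are not reproduced in this view. As a hedged orientation only (not the author's code), a class-conditional sampler is typically exposed through Gradio along these lines; every name below is illustrative:

# Illustrative sketch, not the contents of app.py.
import gradio as gr
import numpy as np

def generate(class_id):
    # A real implementation would sample a token grid with the MaskTransformer
    # for this class and decode it to pixels with the VQGAN decoder.
    return np.zeros((256, 256, 3), dtype=np.uint8)  # placeholder image

demo = gr.Interface(fn=generate,
                    inputs=gr.Number(label="ImageNet class id"),
                    outputs=gr.Image(label="generated sample"))
demo.launch()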
Models/__init__.py
ADDED
File without changes
Models/models/__init__.py
ADDED
File without changes
Models/models/transformer.py
ADDED
@@ -0,0 +1,117 @@
# BERT architecture for the Masked Bidirectional Encoder Transformer
import torch
from torch import nn


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim, bias=True),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.):
        super(Attention, self).__init__()
        self.dim = embed_dim
        self.mha = nn.MultiheadAttention(embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True, bias=True)

    def forward(self, x):
        attention_value, attention_weight = self.mha(x, x, x)
        return attention_value, attention_weight


class TransformerEncoder(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        l_attn = []
        for attn, ff in self.layers:
            attention_value, attention_weight = attn(x)
            x = attention_value + x
            x = ff(x) + x
            l_attn.append(attention_weight)
        return x, l_attn


class MaskTransformer(nn.Module):
    def __init__(self, img_size=256, hidden_dim=768, codebook_size=1024, depth=24, heads=8, mlp_dim=3072, dropout=0.1, nclass=1000):
        super().__init__()

        self.nclass = nclass
        self.patch_size = img_size // 16
        self.codebook_size = codebook_size
        self.tok_emb = nn.Embedding(codebook_size+1+nclass+1, hidden_dim)  # +1 for the mask of the viz token, +1 for mask of the class
        # self.msk_emb = nn.Embedding(2, hidden_dim)
        self.pos_emb = nn.init.trunc_normal_(nn.Parameter(torch.zeros(1, (self.patch_size*self.patch_size)+1, hidden_dim)), 0., 0.02)
        self.first_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim, eps=1e-12),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim, eps=1e-12),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
        )

        self.transformer = TransformerEncoder(dim=hidden_dim, depth=depth, heads=heads, mlp_dim=mlp_dim, dropout=dropout)

        self.last_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim, eps=1e-12),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim, eps=1e-12),
        )

        self.bias = nn.Parameter(torch.zeros((self.patch_size*self.patch_size)+1, codebook_size+1+nclass+1))

    def forward(self, img_token, y=None, drop_label=None, return_attn=False):  # , masking_flag=None):
        b, w, h = img_token.size()

        cls_token = y.view(b, -1) + self.codebook_size + 1
        cls_token[drop_label] = self.codebook_size + 1 + self.nclass
        input = torch.cat([img_token.view(b, -1), cls_token.view(b, -1)], -1)
        tok_embeddings = self.tok_emb(input)
        pos_embeddings = self.pos_emb
        x = tok_embeddings + pos_embeddings

        # if masking_flag is not None:
        #     flag = torch.cat([masking_flag.view(b, -1), torch.zeros_like(cls_token.view(b, -1))], -1)
        #     x += self.msk_emb(flag)

        x = self.first_layer(x)
        x, attn = self.transformer(x)
        x = self.last_layer(x)

        logit = torch.matmul(x, self.tok_emb.weight.T) + self.bias

        if return_attn:
            return logit[:, :self.patch_size * self.patch_size, :self.codebook_size + 1], attn

        return logit[:, :self.patch_size*self.patch_size, :self.codebook_size+1]
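As a quick orientation (this note is not part of the commit), MaskTransformer maps a 16x16 grid of VQGAN token indices plus a class label to per-token logits over the visual codebook; index codebook_size acts as the mask token and index codebook_size+1+nclass as the dropped-label token. A minimal usage sketch, assuming the default sizes above:

# Sketch only; shapes assume the defaults above (img_size=256 -> 16x16 = 256 tokens).
import torch
from Models.models.transformer import MaskTransformer

model = MaskTransformer(img_size=256, hidden_dim=768, codebook_size=1024,
                        depth=24, heads=8, mlp_dim=3072, nclass=1000)
img_token = torch.randint(0, 1024, (2, 16, 16))       # VQGAN code indices (1024 would be the mask id)
y = torch.randint(0, 1000, (2,))                       # class labels
drop_label = torch.zeros(2, dtype=torch.bool)          # True where the label should be dropped
logit = model(img_token, y=y, drop_label=drop_label)   # -> (2, 256, codebook_size + 1)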
Models/models/vqgan.py
ADDED
@@ -0,0 +1,294 @@
import importlib
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


from Models.modules.diffusionmodules.model import Encoder, Decoder
from Models.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from Models.modules.vqvae.quantize import GumbelQuantize


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap, sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.image_key = image_key
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.get_codebook_entry(code_b.view(-1), (code_b.size(0), code_b.size(1), code_b.size(2), 256))
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float()

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class GumbelVQ(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 temperature_scheduler_config,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 kl_weight=1e-8,
                 remap=None,
                 ):

        z_channels = ddconfig["z_channels"]
        super().__init__(ddconfig,
                         lossconfig,
                         n_embed,
                         embed_dim,
                         ckpt_path=None,
                         ignore_keys=ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels,
                         monitor=monitor,
                         )

        # self.loss.n_classes = n_embed
        self.vocab_size = n_embed

        self.quantize = GumbelQuantize(z_channels, embed_dim,
                                       n_embed=n_embed,
                                       kl_weight=kl_weight, temp_init=1.0,
                                       remap=remap)

        # self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config)  # annealing of temp

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def temperature_scheduling(self):
        self.quantize.temperature = self.temperature_scheduler(self.global_step)

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode_code(self, code_b):
        quant_b = self.quantize.get_codebook_entry(code_b.view(-1), (code_b.size(0), 32, 32, 8192))
        dec = self.decode(quant_b)
        return dec

    def training_step(self, batch, batch_idx, optimizer_idx):
        self.temperature_scheduling()
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencoder
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        # encode
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, _, _ = self.quantize(h)
        # decode
        x_rec = self.decode(quant)
        log["inputs"] = x
        log["reconstructions"] = x_rec
        return log

    def reco(self, x):  # , batch, **kwargs):
        # log = dict()
        # x = self.get_input(batch, self.image_key)
        # x = x.to(self.device)
        # encode
        h = self.encoder(x)
        # print(h, h.size())
        h = self.quant_conv(h)
        quant, _, _ = self.quantize(h)
        print(quant, quant.size())
        exit()
        # decode
        x_rec = self.decode(quant)
        # log["inputs"] = x
        # log["reconstructions"] = x_rec
        return x_rec
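For reference (added to this listing, not part of the diff), a rough sketch of decoding a grid of code indices back to an image with the VQModel above; the ddconfig values below are hypothetical placeholders, since the actual config ships outside this commit:

# Sketch only; a hypothetical f=16 configuration, not taken from this commit.
import torch
from Models.models.vqgan import VQModel

ddconfig = dict(double_z=False, z_channels=256, resolution=256, in_channels=3, out_ch=3,
                ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, attn_resolutions=[16], dropout=0.0)
vqgan = VQModel(ddconfig=ddconfig, lossconfig=None, n_embed=1024, embed_dim=256).eval()

codes = torch.randint(0, 1024, (1, 16, 16))    # token grid, e.g. produced by the MaskTransformer
with torch.no_grad():
    img = vqgan.decode_code(codes)             # (1, 3, 256, 256), roughly in [-1, 1]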
Models/modules/diffusionmodules/model.py
ADDED
@@ -0,0 +1,436 @@
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)   # b,hw,c
        k = k.reshape(b, c, h*w)  # b,c,hw
        w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)      # b,c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x+h_


class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # print("Working with z of shape {} = {} dimensions.".format(
        #     self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class UpsampleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in,
                                             out_channels=block_out,
                                             temb_channels=self.temb_ch,
                                             dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
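A small orientation note (added here, not in the original file): the Encoder halves the spatial resolution len(ch_mult)-1 times, so with a hypothetical ch_mult of length 5 a 256x256 input yields a 16x16 latent grid:

# Sketch only; the configuration values are hypothetical, for shape illustration.
import torch
from Models.modules.diffusionmodules.model import Encoder

enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
              attn_resolutions=[16], in_channels=3, resolution=256,
              z_channels=256, double_z=False)
x = torch.randn(1, 3, 256, 256)
h = enc(x)
print(h.shape)  # torch.Size([1, 256, 16, 16]): 4 downsampling steps -> 256 / 2**4 = 16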
Models/modules/util.py
ADDED
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn


def count_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    return total_params


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height*width*torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class Labelator(AbstractEncoder):
    """Net2Net Interface for Class-Conditional Model"""
    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.n_classes = n_classes
        self.quantize_interface = quantize_interface

    def encode(self, c):
        c = c[:, None]
        if self.quantize_interface:
            return c, None, [None, None, c.long()]
        return c


class SOSProvider(AbstractEncoder):
    # for unconditional training
    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        # get batch size from data and replicate sos_token
        c = torch.ones(x.shape[0], 1)*self.sos_token
        c = c.long().to(x.device)
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c
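For context (this note is not part of the commit), ActNorm applies a per-channel affine normalization whose scale and offset are initialized from the first batch seen in training; a minimal sketch with hypothetical sizes:

# Sketch only; sizes are illustrative.
import torch
from Models.modules.util import ActNorm

norm = ActNorm(num_features=64)
norm.train()
x = torch.randn(8, 64, 32, 32) * 3.0 + 1.5
y = norm(x)                              # first call initializes loc/scale from this batch
print(y.mean().item(), y.std().item())   # roughly 0 and 1 after the data-dependent init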
Models/modules/vqvae/quantize.py
ADDED
@@ -0,0 +1,335 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange


class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
    # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
    # used wherever VectorQuantizer has been used before and is additionally
    # more efficient.
    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Inputs the output of the encoder network z and maps it to a discrete
        one-hot vector that is the index of the closest embedding vector e_j
        z (continuous) -> z_q (discrete)
        z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        ## could possible replace this here
        # #\start...
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # dtype min encodings: torch.float32
        # min_encodings shape: torch.Size([2048, 512])
        # min_encoding_indices.shape: torch.Size([2048, 1])

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # .........\end

        # with:
        # .........\start
        # min_encoding_indices = torch.argmin(d, dim=1)
        # z_q = self.embedding(min_encoding_indices)
        # ......\end......... (TODO)

        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
            torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        # TODO: check for more easy handling with nn.Embedding
        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
        min_encodings.scatter_(1, indices[:, None], 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:
            z_q = z_q.view(shape)

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class GumbelQuantize(nn.Module):
    """
    credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
    Gumbel Softmax trick quantizer
    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
    https://arxiv.org/abs/1611.01144
    """

    def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
                 kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
                 remap=None, unknown_index="random"):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        print(n_embed)
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight

        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)

        self.use_vqinterface = use_vqinterface

        self.remap = remap

        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, return_logits=False):
        # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        logits = self.proj(z)
        if self.remap is not None:
            # continue only with used logits
            full_zeros = torch.zeros_like(logits)
            logits = logits[:, self.used, ...]

        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
        if self.remap is not None:
            # go back to all entries but unused set to zero
            full_zeros[:, self.used, ...] = soft_one_hot
            soft_one_hot = full_zeros
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()

        ind = soft_one_hot.argmax(dim=1)
        if self.remap is not None:
            ind = self.remap_to_used(ind)
        if self.use_vqinterface:
            if return_logits:
                return z_q, diff, (None, None, ind), logits
            return z_q, diff, (None, None, ind)
        return z_q, diff, ind

    def get_codebook_entry(self, indices, shape):
        b, h, w, c = shape
        assert b * h * w == indices.shape[0]
        indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
        if self.remap is not None:
            indices = self.unmap_to_all(indices)
        one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
        # print(one_hot.size())
        # exit()
        z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)

        return z_q


class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        assert return_logits == False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
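As a usage note added for this listing (not part of the diff), VectorQuantizer2 snaps each spatial feature vector to its nearest codebook entry and passes gradients straight through the quantization step; a small sketch with hypothetical sizes:

# Sketch only; sizes are illustrative.
import torch
from Models.modules.vqvae.quantize import VectorQuantizer2

quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25, sane_index_shape=True)
z = torch.randn(2, 256, 16, 16, requires_grad=True)   # encoder features (B, e_dim, H, W)
z_q, loss, (_, _, indices) = quantizer(z)
print(z_q.shape, indices.shape)   # (2, 256, 16, 16) and (2, 16, 16) with sane_index_shape=True
loss.backward()                   # codebook/commitment loss; gradients reach z via the straight-through copy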
Models/util.py
ADDED
@@ -0,0 +1,157 @@
import os, hashlib
import requests
from tqdm import tqdm

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.

    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success


if __name__ == "__main__":
    config = {"keya": "a",
              "keyb": "b",
              "keyc":
                  {"cc1": 1,
                   "cc2": 2,
                   }
              }
    from omegaconf import OmegaConf
    config = OmegaConf.create(config)
    print(config)
    retrieve(config, "keya")
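The `retrieve` helper above walks a nested config with a path-like key and falls back to a default when the key is missing; the `__main__` block hints at the intended usage. A short hedged sketch along the same lines (the keys and values here are made up for illustration):

from Models.util import retrieve

config = {"model": {"params": {"embed_dim": 256}}}

embed_dim = retrieve(config, "model/params/embed_dim")                   # -> 256
n_layers = retrieve(config, "model/params/n_layers", default=12)         # -> 12, key absent
value, ok = retrieve(config, "model/params/n_layers", default=12,
                     pass_success=True)                                   # ok is False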
app.py
ADDED
@@ -0,0 +1,78 @@
import argparse
import torch
import gradio as gr
from torchvision import transforms
from runner import MaskGIT
import numpy as np
import random
import torchvision.utils as vutils


class Args(argparse.Namespace):
    data_folder = ""
    vqgan_folder = "pretrained_maskgit/VQGAN"
    writer_log = ""
    data = ""
    mask_value = 1024
    seed = 1
    channel = 3
    num_workers = 0
    iter = 0
    global_epoch = 0
    lr = 1e-4
    drop_label = 0.1
    resume = True
    device = "cpu"
    print(device)
    debug = True
    test_only = False
    is_master = True
    is_multi_gpus = False
    vit_size = "base"
    vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
    img_size = 256
    patch_size = 256 // 16


def set_seed(seed):
    if seed > 0:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.deterministic = True


args = Args()
maskgit = MaskGIT(args)


# Function to perform image synthesis
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
    # Perform image synthesis using the MaskGIT model
    set_seed(seed)
    with torch.no_grad():
        labels = [cls] * nb_img
        labels = torch.LongTensor(labels).to(args.device)
        gen_sample = maskgit.sample(nb_sample=labels.size(0), labels=labels, sm_temp=sm_temp, w=w,
                                    randomize="linear", r_temp=r_temp, sched_mode="arccos",
                                    step=step)[0]

    # Post-process the output: arrange the samples in a grid and convert to a PIL image
    output_image = transforms.ToPILImage()(vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True))

    return output_image


# Gradio Interface
app = gr.Interface(
    fn=synthesize_image,
    # positional widgets map onto synthesize_image(cls, sm_temp, w, r_temp, step, seed, nb_img)
    inputs=[gr.Number(31), gr.Number(1.3), gr.Number(25), gr.Number(4.5), gr.Number(16),
            gr.Slider(0, 1000, 60), gr.Number(1, maximum=4)],
    outputs=gr.Image(),
    title="Image Synthesis using MaskGIT",
)

# Launch the Gradio app
app.launch(share=True)
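The seven input widgets above are positional and unlabeled, so they are read in the order of `synthesize_image(cls, sm_temp, w, r_temp, step, seed, nb_img)`. If named fields are wanted in the UI, a hedged variant with the same defaults could look like this (the label texts are suggestions, not part of the original app):

inputs = [
    gr.Number(31, label="ImageNet class (0-999)"),
    gr.Number(1.3, label="softmax temperature (sm_temp)"),
    gr.Number(25, label="CFG scale (w)"),
    gr.Number(4.5, label="Gumbel temperature (r_temp)"),
    gr.Number(16, label="decoding steps"),
    gr.Slider(0, 1000, 60, label="seed"),
    gr.Number(1, maximum=4, label="number of images"),
]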
runner.py
ADDED
@@ -0,0 +1,221 @@
# Trainer for MaskGIT
import os
import random
import math

import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from Models.models.transformer import MaskTransformer
from Models.models.vqgan import VQModel


class MaskGIT(nn.Module):

    def __init__(self, args):
        """ Initialization of the model (VQGAN and Masked Transformer), optimizer, criterion, etc."""
        super().__init__()

        self.args = args                              # Main arguments, see main.py
        self.patch_size = self.args.img_size // 16    # Number of visual tokens per side
        self.scaler = torch.cuda.amp.GradScaler()     # Init Scaler for mixed precision
        self.vit = self.get_network("vit")            # Load Masked Bidirectional Transformer
        self.ae = self.get_network("autoencoder")     # Load VQGAN

    def get_network(self, archi):
        """ Return the network, load checkpoint if self.args.resume == True
            :param
                archi -> str: vit|autoencoder, the architecture to load
            :return
                model -> nn.Module: the network
        """
        if archi == "vit":
            if self.args.vit_size == "base":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=768, codebook_size=1024, depth=24, heads=16, mlp_dim=3072, dropout=0.1  # Base
                )
            elif self.args.vit_size == "big":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=32, heads=16, mlp_dim=3072, dropout=0.1  # Big
                )
            elif self.args.vit_size == "huge":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=48, heads=16, mlp_dim=3072, dropout=0.1  # Huge
                )

            if self.args.resume:
                ckpt = self.args.vit_folder
                ckpt += "current.pth" if os.path.isdir(self.args.vit_folder) else ""
                if self.args.is_master:
                    print("load ckpt from:", ckpt)
                # Read checkpoint file
                checkpoint = torch.load(ckpt, map_location='cpu')
                # Load network
                model.load_state_dict(checkpoint['model_state_dict'], strict=False)

            model = model.to(self.args.device)

            if self.args.is_multi_gpus:  # put model on multi GPUs if available
                model = DDP(model, device_ids=[self.args.device])

        elif archi == "autoencoder":
            # Load config
            config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml"))
            model = VQModel(**config.model.params)
            checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"]
            # Load network
            model.load_state_dict(checkpoint, strict=False)
            model = model.eval()
            model = model.to(self.args.device)

            if self.args.is_multi_gpus:  # put model on multi GPUs if available
                model = DDP(model, device_ids=[self.args.device])
                model = model.module
        else:
            model = None

        if self.args.is_master:
            print(f"Size of model {archi}: "
                  f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M")

        return model

    def adap_sche(self, step, mode="arccos", leave=False):
        """ Create a sampling scheduler
            :param
                step  -> int:  number of predictions during inference
                mode  -> str:  the rate at which tokens are unmasked
                leave -> bool: tqdm arg on whether to keep the bar or not
            :return
                scheduler -> torch.LongTensor(): the number of tokens to predict at each step
        """
        r = torch.linspace(1, 0, step)
        if mode == "root":              # root scheduler
            val_to_mask = 1 - (r ** .5)
        elif mode == "linear":          # linear scheduler
            val_to_mask = 1 - r
        elif mode == "square":          # square scheduler
            val_to_mask = 1 - (r ** 2)
        elif mode == "cosine":          # cosine scheduler
            val_to_mask = torch.cos(r * math.pi * 0.5)
        elif mode == "arccos":          # arc cosine scheduler
            val_to_mask = torch.arccos(r) / (math.pi * 0.5)
        else:
            return

        # fill the scheduler with the ratio of tokens to predict at each step
        sche = (val_to_mask / val_to_mask.sum()) * (self.patch_size * self.patch_size)
        sche = sche.round()
        sche[sche == 0] = 1                                           # predict at least 1 token per step
        sche[-1] += (self.patch_size * self.patch_size) - sche.sum()  # counts must sum to the total number of codes
        return tqdm(sche.int(), leave=leave)

    def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
               randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
        """ Generate samples with the MaskGIT model
            :param
                init_code  -> torch.LongTensor: nb_sample x 16 x 16, the starting initialization code
                nb_sample  -> int:              the number of images to generate
                labels     -> torch.LongTensor: the list of classes to generate
                sm_temp    -> float:            the temperature before softmax
                w          -> float:            scale for the classifier free guidance
                randomize  -> str:              linear|warm_up|random|no, whether to add randomness
                r_temp     -> float:            temperature for the randomness
                sched_mode -> str:              root|linear|square|cosine|arccos, the shape of the scheduler
                step       -> int:              number of steps for the decoding
            :return
                x    -> torch.FloatTensor: nb_sample x 3 x 256 x 256, the generated images
                code -> torch.LongTensor:  nb_sample x step x 16 x 16, the codes corresponding to the generated images
        """
        self.vit.eval()
        l_codes = []  # Save the intermediate codes predicted
        l_mask = []   # Save the intermediate masks
        with torch.no_grad():
            if labels is None:  # Default classes generated
                # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random
                labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)] * (nb_sample // 10)
                labels = torch.LongTensor(labels).to(self.args.device)

            drop = torch.ones(nb_sample, dtype=torch.bool).to(self.args.device)
            if init_code is not None:  # Start with a pre-defined code
                code = init_code
                mask = (init_code == 1024).float().view(nb_sample, self.patch_size*self.patch_size)
            else:  # Initialize a code
                if self.args.mask_value < 0:  # Code initialized with random tokens
                    code = torch.randint(0, 1024, (nb_sample, self.patch_size, self.patch_size)).to(self.args.device)
                else:  # Code initialized with masked tokens
                    code = torch.full((nb_sample, self.patch_size, self.patch_size), self.args.mask_value).to(self.args.device)
                mask = torch.ones(nb_sample, self.patch_size*self.patch_size).to(self.args.device)

            # Instantiate scheduler
            if isinstance(sched_mode, str):  # Standard ones
                scheduler = self.adap_sche(step, mode=sched_mode)
            else:  # Custom one
                scheduler = sched_mode

            # Beginning of sampling, t = number of tokens to predict at step "indice"
            for indice, t in enumerate(scheduler):
                if mask.sum() < t:  # Cannot predict more tokens than 16*16 or 32*32
                    t = int(mask.sum().item())

                if mask.sum() == 0:  # Break if code is fully predicted
                    break

                with torch.cuda.amp.autocast():  # half precision
                    if w != 0:
                        # Model Prediction
                        logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
                                         torch.cat([labels, labels], dim=0),
                                         torch.cat([~drop, drop], dim=0))
                        logit_c, logit_u = torch.chunk(logit, 2, dim=0)
                        _w = w * (indice / (len(scheduler)-1))
                        # Classifier Free Guidance
                        logit = (1 + _w) * logit_c - _w * logit_u
                    else:
                        logit = self.vit(code.clone(), labels, drop_label=~drop)

                prob = torch.softmax(logit * sm_temp, -1)
                # Sample the code from the softmax prediction
                distri = torch.distributions.Categorical(probs=prob)
                pred_code = distri.sample()

                conf = torch.gather(prob, 2, pred_code.view(nb_sample, self.patch_size*self.patch_size, 1))

                if randomize == "linear":  # add gumbel noise decreasing over the sampling process
                    ratio = (indice / len(scheduler))
                    rand = r_temp * np.random.gumbel(size=(nb_sample, self.patch_size*self.patch_size)) * (1 - ratio)
                    conf = torch.log(conf.squeeze()) + torch.from_numpy(rand).to(self.args.device)
                elif randomize == "warm_up":  # choose random samples for the first 2 steps
                    conf = torch.rand_like(conf) if indice < 2 else conf
                elif randomize == "random":  # choose random predictions at each step
                    conf = torch.rand_like(conf)

                # do not predict on already predicted tokens
                conf[~mask.bool()] = -math.inf

                # choose the predicted tokens with the highest confidence
                tresh_conf, indice_mask = torch.topk(conf.view(nb_sample, -1), k=t, dim=-1)
                tresh_conf = tresh_conf[:, -1]

                # replace the chosen tokens
                conf = (conf >= tresh_conf.unsqueeze(-1)).view(nb_sample, self.patch_size, self.patch_size)
                f_mask = (mask.view(nb_sample, self.patch_size, self.patch_size).float() * conf.view(nb_sample, self.patch_size, self.patch_size).float()).bool()
                code[f_mask] = pred_code.view(nb_sample, self.patch_size, self.patch_size)[f_mask]

                # update the mask
                for i_mask, ind_mask in enumerate(indice_mask):
                    mask[i_mask, ind_mask] = 0
                l_codes.append(pred_code.view(nb_sample, self.patch_size, self.patch_size).clone())
                l_mask.append(mask.view(nb_sample, self.patch_size, self.patch_size).clone())

            # decode the final prediction
            _code = torch.clamp(code, 0, 1023)  # VQGAN has only 1024 codebook entries
            x = self.ae.decode_code(_code)
            x = (torch.clamp(x, -1, 1) + 1) / 2
        self.vit.train()
        return x, l_codes, l_mask
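A quick sanity check on `adap_sche`: for the default 16x16 grid (256 tokens), the chosen mode produces per-step ratios that are normalized, scaled to 256, rounded, forced to at least one token per step, and corrected on the last step so the counts sum to exactly 256. A hedged standalone sketch of that computation for the arccos mode (it mirrors the method above without the tqdm wrapper; the function name is illustrative):

import math
import torch

def token_schedule(step=8, n_tokens=256):
    r = torch.linspace(1, 0, step)
    val_to_mask = torch.arccos(r) / (math.pi * 0.5)       # arccos schedule
    sche = (val_to_mask / val_to_mask.sum()) * n_tokens   # ratio of tokens per step
    sche = sche.round()
    sche[sche == 0] = 1                                   # predict at least 1 token per step
    sche[-1] += n_tokens - sche.sum()                     # counts must sum to exactly n_tokens
    return sche.int()

print(token_schedule())        # few tokens unmasked early, many at the end
print(token_schedule().sum())  # tensor(256)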