llvictorll committed (verified)
Commit 8513f87 · 1 Parent(s): 40d1980

add gradio app

Models/__init__.py ADDED
File without changes
Models/models/__init__.py ADDED
File without changes
Models/models/transformer.py ADDED
@@ -0,0 +1,117 @@
1
+ # BERT architecture for the Masked Bidirectional Encoder Transformer
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class PreNorm(nn.Module):
7
+ def __init__(self, dim, fn):
8
+ super().__init__()
9
+ self.norm = nn.LayerNorm(dim)
10
+ self.fn = fn
11
+
12
+ def forward(self, x, **kwargs):
13
+ return self.fn(self.norm(x), **kwargs)
14
+
15
+
16
+ class FeedForward(nn.Module):
17
+ def __init__(self, dim, hidden_dim, dropout=0.):
18
+ super().__init__()
19
+ self.net = nn.Sequential(
20
+ nn.Linear(dim, hidden_dim, bias=True),
21
+ nn.GELU(),
22
+ nn.Dropout(dropout),
23
+ nn.Linear(hidden_dim, dim, bias=True),
24
+ nn.Dropout(dropout)
25
+ )
26
+
27
+ def forward(self, x):
28
+ return self.net(x)
29
+
30
+
31
+ class Attention(nn.Module):
32
+ def __init__(self, embed_dim, num_heads, dropout=0.):
33
+ super(Attention, self).__init__()
34
+ self.dim = embed_dim
35
+ self.mha = nn.MultiheadAttention(embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True, bias=True)
36
+
37
+ def forward(self, x):
38
+ attention_value, attention_weight = self.mha(x, x, x)
39
+ return attention_value, attention_weight
40
+
41
+
42
+ class TransformerEncoder(nn.Module):
43
+ def __init__(self, dim, depth, heads, mlp_dim, dropout=0.):
44
+ super().__init__()
45
+ self.layers = nn.ModuleList([])
46
+ for _ in range(depth):
47
+ self.layers.append(nn.ModuleList([
48
+ PreNorm(dim, Attention(dim, heads, dropout=dropout)),
49
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
50
+ ]))
51
+
52
+ def forward(self, x):
53
+ l_attn = []
54
+ for attn, ff in self.layers:
55
+ attention_value, attention_weight = attn(x)
56
+ x = attention_value + x
57
+ x = ff(x) + x
58
+ l_attn.append(attention_weight)
59
+ return x, l_attn
60
+
61
+
62
+ class MaskTransformer(nn.Module):
63
+ def __init__(self, img_size=256, hidden_dim=768, codebook_size=1024, depth=24, heads=8, mlp_dim=3072, dropout=0.1, nclass=1000):
64
+ super().__init__()
65
+
66
+ self.nclass = nclass
67
+ self.patch_size = img_size // 16
68
+ self.codebook_size = codebook_size
69
+ self.tok_emb = nn.Embedding(codebook_size+1+nclass+1, hidden_dim) # +1 for the visual-token mask, +1 for the dropped (unconditional) class label
70
+ # self.msk_emb = nn.Embedding(2, hidden_dim)
71
+ self.pos_emb = nn.init.trunc_normal_(nn.Parameter(torch.zeros(1, (self.patch_size*self.patch_size)+1, hidden_dim)), 0., 0.02)
72
+ self.first_layer = nn.Sequential(
73
+ nn.LayerNorm(hidden_dim, eps=1e-12),
74
+ nn.Dropout(p=dropout),
75
+ nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
76
+ nn.GELU(),
77
+ nn.LayerNorm(hidden_dim, eps=1e-12),
78
+ nn.Dropout(p=dropout),
79
+ nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
80
+ )
81
+
82
+ self.transformer = TransformerEncoder(dim=hidden_dim, depth=depth, heads=heads, mlp_dim=mlp_dim, dropout=dropout)
83
+
84
+ self.last_layer = nn.Sequential(
85
+ nn.LayerNorm(hidden_dim, eps=1e-12),
86
+ nn.Dropout(p=dropout),
87
+ nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
88
+ nn.GELU(),
89
+ nn.LayerNorm(hidden_dim, eps=1e-12),
90
+ )
91
+
92
+ self.bias = nn.Parameter(torch.zeros((self.patch_size*self.patch_size)+1, codebook_size+1+nclass+1))
93
+
94
+ def forward(self, img_token, y=None, drop_label=None, return_attn=False): # , masking_flag=None):
95
+ b, w, h = img_token.size()
96
+
97
+ cls_token = y.view(b, -1) + self.codebook_size + 1
98
+ cls_token[drop_label] = self.codebook_size + 1 + self.nclass
99
+ input = torch.cat([img_token.view(b, -1), cls_token.view(b, -1)], -1)
100
+ tok_embeddings = self.tok_emb(input)
101
+ pos_embeddings = self.pos_emb
102
+ x = tok_embeddings + pos_embeddings
103
+
104
+ # if masking_flag is not None:
105
+ # flag = torch.cat([masking_flag.view(b, -1), torch.zeros_like(cls_token.view(b, -1))], -1)
106
+ # x += self.msk_emb(flag)
107
+
108
+ x = self.first_layer(x)
109
+ x, attn = self.transformer(x)
110
+ x = self.last_layer(x)
111
+
112
+ logit = torch.matmul(x, self.tok_emb.weight.T) + self.bias
113
+
114
+ if return_attn:
115
+ return logit[:, :self.patch_size * self.patch_size, :self.codebook_size + 1], attn
116
+
117
+ return logit[:, :self.patch_size*self.patch_size, :self.codebook_size+1]
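
For reference, a minimal usage sketch of MaskTransformer (illustrative only, not part of the committed files). Shapes assume the defaults above: a 16x16 grid of visual tokens for img_size=256, codebook_size=1024 and nclass=1000; the shallow depth is only to keep the sketch cheap.

    import torch
    from Models.models.transformer import MaskTransformer

    model = MaskTransformer(img_size=256, hidden_dim=768, codebook_size=1024,
                            depth=2, heads=8, mlp_dim=3072, dropout=0.1, nclass=1000)
    img_token = torch.randint(0, 1024, (2, 16, 16))       # quantized VQGAN indices
    labels = torch.randint(0, 1000, (2,))                 # one ImageNet class per sample
    drop = torch.zeros(2, dtype=torch.bool)               # keep the class token (no CFG dropout)
    logits = model(img_token, y=labels, drop_label=drop)  # -> (2, 256, 1025): codebook + mask logits
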
Models/models/vqgan.py ADDED
@@ -0,0 +1,294 @@
1
+ import importlib
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import pytorch_lightning as pl
5
+
6
+
7
+ from Models.modules.diffusionmodules.model import Encoder, Decoder
8
+ from Models.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
9
+ from Models.modules.vqvae.quantize import GumbelQuantize
10
+
11
+
12
+ def get_obj_from_str(string, reload=False):
13
+ module, cls = string.rsplit(".", 1)
14
+ if reload:
15
+ module_imp = importlib.import_module(module)
16
+ importlib.reload(module_imp)
17
+ return getattr(importlib.import_module(module, package=None), cls)
18
+
19
+
20
+ def instantiate_from_config(config):
21
+ if not "target" in config:
22
+ raise KeyError("Expected key `target` to instantiate.")
23
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
24
+
25
+
26
+ class VQModel(pl.LightningModule):
27
+ def __init__(self,
28
+ ddconfig,
29
+ lossconfig,
30
+ n_embed,
31
+ embed_dim,
32
+ ckpt_path=None,
33
+ ignore_keys=[],
34
+ image_key="image",
35
+ colorize_nlabels=None,
36
+ monitor=None,
37
+ remap=None,
38
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
39
+ ):
40
+ super().__init__()
41
+ self.image_key = image_key
42
+ self.encoder = Encoder(**ddconfig)
43
+ self.decoder = Decoder(**ddconfig)
44
+ # self.loss = instantiate_from_config(lossconfig)
45
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
46
+ remap=remap, sane_index_shape=sane_index_shape)
47
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
48
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
49
+ if ckpt_path is not None:
50
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51
+ self.image_key = image_key
52
+ if colorize_nlabels is not None:
53
+ assert type(colorize_nlabels) == int
54
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
55
+ if monitor is not None:
56
+ self.monitor = monitor
57
+
58
+ def init_from_ckpt(self, path, ignore_keys=list()):
59
+ sd = torch.load(path, map_location="cpu")["state_dict"]
60
+ keys = list(sd.keys())
61
+ for k in keys:
62
+ for ik in ignore_keys:
63
+ if k.startswith(ik):
64
+ print("Deleting key {} from state_dict.".format(k))
65
+ del sd[k]
66
+ self.load_state_dict(sd, strict=False)
67
+ print(f"Restored from {path}")
68
+
69
+ def encode(self, x):
70
+ h = self.encoder(x)
71
+ h = self.quant_conv(h)
72
+ quant, emb_loss, info = self.quantize(h)
73
+ return quant, emb_loss, info
74
+
75
+ def decode(self, quant):
76
+ quant = self.post_quant_conv(quant)
77
+ dec = self.decoder(quant)
78
+ return dec
79
+
80
+ def decode_code(self, code_b):
81
+ quant_b = self.quantize.get_codebook_entry(code_b.view(-1), (code_b.size(0), code_b.size(1), code_b.size(2), 256))
82
+ dec = self.decode(quant_b)
83
+ return dec
84
+
85
+ def forward(self, input):
86
+ quant, diff, _ = self.encode(input)
87
+ dec = self.decode(quant)
88
+ return dec, diff
89
+
90
+ def get_input(self, batch, k):
91
+ x = batch[k]
92
+ if len(x.shape) == 3:
93
+ x = x[..., None]
94
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
95
+ return x.float()
96
+
97
+ def training_step(self, batch, batch_idx, optimizer_idx):
98
+ x = self.get_input(batch, self.image_key)
99
+ xrec, qloss = self(x)
100
+
101
+ if optimizer_idx == 0:
102
+ # autoencode
103
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
104
+ last_layer=self.get_last_layer(), split="train")
105
+
106
+ self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
107
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
108
+ return aeloss
109
+
110
+ if optimizer_idx == 1:
111
+ # discriminator
112
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
113
+ last_layer=self.get_last_layer(), split="train")
114
+ self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
115
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
116
+ return discloss
117
+
118
+ def validation_step(self, batch, batch_idx):
119
+ x = self.get_input(batch, self.image_key)
120
+ xrec, qloss = self(x)
121
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
122
+ last_layer=self.get_last_layer(), split="val")
123
+
124
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
125
+ last_layer=self.get_last_layer(), split="val")
126
+ rec_loss = log_dict_ae["val/rec_loss"]
127
+ self.log("val/rec_loss", rec_loss,
128
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
129
+ self.log("val/aeloss", aeloss,
130
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
131
+ self.log_dict(log_dict_ae)
132
+ self.log_dict(log_dict_disc)
133
+ return self.log_dict
134
+
135
+ def configure_optimizers(self):
136
+ lr = self.learning_rate
137
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
138
+ list(self.decoder.parameters())+
139
+ list(self.quantize.parameters())+
140
+ list(self.quant_conv.parameters())+
141
+ list(self.post_quant_conv.parameters()),
142
+ lr=lr, betas=(0.5, 0.9))
143
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
144
+ lr=lr, betas=(0.5, 0.9))
145
+ return [opt_ae, opt_disc], []
146
+
147
+ def get_last_layer(self):
148
+ return self.decoder.conv_out.weight
149
+
150
+ def log_images(self, batch, **kwargs):
151
+ log = dict()
152
+ x = self.get_input(batch, self.image_key)
153
+ x = x.to(self.device)
154
+ xrec, _ = self(x)
155
+ if x.shape[1] > 3:
156
+ # colorize with random projection
157
+ assert xrec.shape[1] > 3
158
+ x = self.to_rgb(x)
159
+ xrec = self.to_rgb(xrec)
160
+ log["inputs"] = x
161
+ log["reconstructions"] = xrec
162
+ return log
163
+
164
+ def to_rgb(self, x):
165
+ assert self.image_key == "segmentation"
166
+ if not hasattr(self, "colorize"):
167
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
168
+ x = F.conv2d(x, weight=self.colorize)
169
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
170
+ return x
171
+
172
+
173
+ class GumbelVQ(VQModel):
174
+ def __init__(self,
175
+ ddconfig,
176
+ lossconfig,
177
+ n_embed,
178
+ embed_dim,
179
+ temperature_scheduler_config,
180
+ ckpt_path=None,
181
+ ignore_keys=[],
182
+ image_key="image",
183
+ colorize_nlabels=None,
184
+ monitor=None,
185
+ kl_weight=1e-8,
186
+ remap=None,
187
+ ):
188
+
189
+ z_channels = ddconfig["z_channels"]
190
+ super().__init__(ddconfig,
191
+ lossconfig,
192
+ n_embed,
193
+ embed_dim,
194
+ ckpt_path=None,
195
+ ignore_keys=ignore_keys,
196
+ image_key=image_key,
197
+ colorize_nlabels=colorize_nlabels,
198
+ monitor=monitor,
199
+ )
200
+
201
+ # self.loss.n_classes = n_embed
202
+ self.vocab_size = n_embed
203
+
204
+ self.quantize = GumbelQuantize(z_channels, embed_dim,
205
+ n_embed=n_embed,
206
+ kl_weight=kl_weight, temp_init=1.0,
207
+ remap=remap)
208
+
209
+ # self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
210
+
211
+ if ckpt_path is not None:
212
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
213
+
214
+ def temperature_scheduling(self):
215
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
216
+
217
+ def encode_to_prequant(self, x):
218
+ h = self.encoder(x)
219
+ h = self.quant_conv(h)
220
+ return h
221
+
222
+ def decode_code(self, code_b):
223
+ quant_b = self.quantize.get_codebook_entry(code_b.view(-1), (code_b.size(0), 32, 32, 8192))
224
+ dec = self.decode(quant_b)
225
+ return dec
226
+
227
+ def training_step(self, batch, batch_idx, optimizer_idx):
228
+ self.temperature_scheduling()
229
+ x = self.get_input(batch, self.image_key)
230
+ xrec, qloss = self(x)
231
+
232
+ if optimizer_idx == 0:
233
+ # autoencoder
234
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
235
+ last_layer=self.get_last_layer(), split="train")
236
+
237
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
238
+ self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
239
+ return aeloss
240
+
241
+ if optimizer_idx == 1:
242
+ # discriminator
243
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
244
+ last_layer=self.get_last_layer(), split="train")
245
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
246
+ return discloss
247
+
248
+ def validation_step(self, batch, batch_idx):
249
+ x = self.get_input(batch, self.image_key)
250
+ xrec, qloss = self(x, return_pred_indices=True)
251
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
252
+ last_layer=self.get_last_layer(), split="val")
253
+
254
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
255
+ last_layer=self.get_last_layer(), split="val")
256
+ rec_loss = log_dict_ae["val/rec_loss"]
257
+ self.log("val/rec_loss", rec_loss,
258
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
259
+ self.log("val/aeloss", aeloss,
260
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
261
+ self.log_dict(log_dict_ae)
262
+ self.log_dict(log_dict_disc)
263
+ return self.log_dict
264
+
265
+ def log_images(self, batch, **kwargs):
266
+ log = dict()
267
+ x = self.get_input(batch, self.image_key)
268
+ x = x.to(self.device)
269
+ # encode
270
+ h = self.encoder(x)
271
+ h = self.quant_conv(h)
272
+ quant, _, _ = self.quantize(h)
273
+ # decode
274
+ x_rec = self.decode(quant)
275
+ log["inputs"] = x
276
+ log["reconstructions"] = x_rec
277
+ return log
278
+
279
+ def reco(self, x): # , batch, **kwargs):
280
+ # log = dict()
281
+ # x = self.get_input(batch, self.image_key)
282
+ # x = x.to(self.device)
283
+ # encode
284
+ h = self.encoder(x)
285
+ # print(h, h.size())
286
+ h = self.quant_conv(h)
287
+ quant, _, _ = self.quantize(h)
288
+ # print(quant, quant.size())  # debug output disabled
289
+ # exit()  # disabled: exiting here would stop reco() before it can decode and return
290
+ # decode
291
+ x_rec = self.decode(quant)
292
+ # log["inputs"] = x
293
+ # log["reconstructions"] = x_rec
294
+ return x_rec
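
For reference, a rough sketch of round-tripping a tensor through VQModel (illustrative only, not part of the committed files). The ddconfig values below are assumptions in the style of a standard f16/1024 VQGAN configuration; the real values come from the VQGAN's model.yaml, which runner.py loads with OmegaConf. lossconfig can be None here because the loss instantiation above is commented out.

    import torch
    from Models.models.vqgan import VQModel

    ddconfig = dict(double_z=False, z_channels=256, resolution=256, in_channels=3, out_ch=3,
                    ch=128, ch_mult=[1, 1, 2, 2, 4], num_res_blocks=2, attn_resolutions=[16],
                    dropout=0.0)
    vqgan = VQModel(ddconfig=ddconfig, lossconfig=None, n_embed=1024, embed_dim=256).eval()

    x = torch.randn(1, 3, 256, 256)                  # stand-in image in [-1, 1]
    with torch.no_grad():
        quant, _, (_, _, indices) = vqgan.encode(x)  # 1x256x16x16 latents + flat codebook indices
        recon = vqgan.decode(quant)                  # 1x3x256x256 reconstruction
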
Models/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,436 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ def get_timestep_embedding(timesteps, embedding_dim):
8
+ """
9
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
10
+ From Fairseq.
11
+ Build sinusoidal embeddings.
12
+ This matches the implementation in tensor2tensor, but differs slightly
13
+ from the description in Section 3.5 of "Attention Is All You Need".
14
+ """
15
+ assert len(timesteps.shape) == 1
16
+
17
+ half_dim = embedding_dim // 2
18
+ emb = math.log(10000) / (half_dim - 1)
19
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
20
+ emb = emb.to(device=timesteps.device)
21
+ emb = timesteps.float()[:, None] * emb[None, :]
22
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
23
+ if embedding_dim % 2 == 1: # zero pad
24
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
25
+ return emb
26
+
27
+
28
+ def nonlinearity(x):
29
+ # swish
30
+ return x*torch.sigmoid(x)
31
+
32
+
33
+ def Normalize(in_channels):
34
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
35
+
36
+
37
+ class Upsample(nn.Module):
38
+ def __init__(self, in_channels, with_conv):
39
+ super().__init__()
40
+ self.with_conv = with_conv
41
+ if self.with_conv:
42
+ self.conv = torch.nn.Conv2d(in_channels,
43
+ in_channels,
44
+ kernel_size=3,
45
+ stride=1,
46
+ padding=1)
47
+
48
+ def forward(self, x):
49
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
50
+ if self.with_conv:
51
+ x = self.conv(x)
52
+ return x
53
+
54
+
55
+ class Downsample(nn.Module):
56
+ def __init__(self, in_channels, with_conv):
57
+ super().__init__()
58
+ self.with_conv = with_conv
59
+ if self.with_conv:
60
+ # no asymmetric padding in torch conv, must do it ourselves
61
+ self.conv = torch.nn.Conv2d(in_channels,
62
+ in_channels,
63
+ kernel_size=3,
64
+ stride=2,
65
+ padding=0)
66
+
67
+ def forward(self, x):
68
+ if self.with_conv:
69
+ pad = (0,1,0,1)
70
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
71
+ x = self.conv(x)
72
+ else:
73
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
74
+ return x
75
+
76
+
77
+ class ResnetBlock(nn.Module):
78
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
79
+ dropout, temb_channels=512):
80
+ super().__init__()
81
+ self.in_channels = in_channels
82
+ out_channels = in_channels if out_channels is None else out_channels
83
+ self.out_channels = out_channels
84
+ self.use_conv_shortcut = conv_shortcut
85
+
86
+ self.norm1 = Normalize(in_channels)
87
+ self.conv1 = torch.nn.Conv2d(in_channels,
88
+ out_channels,
89
+ kernel_size=3,
90
+ stride=1,
91
+ padding=1)
92
+ if temb_channels > 0:
93
+ self.temb_proj = torch.nn.Linear(temb_channels,
94
+ out_channels)
95
+ self.norm2 = Normalize(out_channels)
96
+ self.dropout = torch.nn.Dropout(dropout)
97
+ self.conv2 = torch.nn.Conv2d(out_channels,
98
+ out_channels,
99
+ kernel_size=3,
100
+ stride=1,
101
+ padding=1)
102
+ if self.in_channels != self.out_channels:
103
+ if self.use_conv_shortcut:
104
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
105
+ out_channels,
106
+ kernel_size=3,
107
+ stride=1,
108
+ padding=1)
109
+ else:
110
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
111
+ out_channels,
112
+ kernel_size=1,
113
+ stride=1,
114
+ padding=0)
115
+
116
+ def forward(self, x, temb):
117
+ h = x
118
+ h = self.norm1(h)
119
+ h = nonlinearity(h)
120
+ h = self.conv1(h)
121
+
122
+ if temb is not None:
123
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
124
+
125
+ h = self.norm2(h)
126
+ h = nonlinearity(h)
127
+ h = self.dropout(h)
128
+ h = self.conv2(h)
129
+
130
+ if self.in_channels != self.out_channels:
131
+ if self.use_conv_shortcut:
132
+ x = self.conv_shortcut(x)
133
+ else:
134
+ x = self.nin_shortcut(x)
135
+
136
+ return x+h
137
+
138
+
139
+ class AttnBlock(nn.Module):
140
+ def __init__(self, in_channels):
141
+ super().__init__()
142
+ self.in_channels = in_channels
143
+
144
+ self.norm = Normalize(in_channels)
145
+ self.q = torch.nn.Conv2d(in_channels,
146
+ in_channels,
147
+ kernel_size=1,
148
+ stride=1,
149
+ padding=0)
150
+ self.k = torch.nn.Conv2d(in_channels,
151
+ in_channels,
152
+ kernel_size=1,
153
+ stride=1,
154
+ padding=0)
155
+ self.v = torch.nn.Conv2d(in_channels,
156
+ in_channels,
157
+ kernel_size=1,
158
+ stride=1,
159
+ padding=0)
160
+ self.proj_out = torch.nn.Conv2d(in_channels,
161
+ in_channels,
162
+ kernel_size=1,
163
+ stride=1,
164
+ padding=0)
165
+
166
+
167
+ def forward(self, x):
168
+ h_ = x
169
+ h_ = self.norm(h_)
170
+ q = self.q(h_)
171
+ k = self.k(h_)
172
+ v = self.v(h_)
173
+
174
+ # compute attention
175
+ b,c,h,w = q.shape
176
+ q = q.reshape(b,c,h*w)
177
+ q = q.permute(0,2,1) # b,hw,c
178
+ k = k.reshape(b,c,h*w) # b,c,hw
179
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
180
+ w_ = w_ * (int(c)**(-0.5))
181
+ w_ = torch.nn.functional.softmax(w_, dim=2)
182
+
183
+ # attend to values
184
+ v = v.reshape(b,c,h*w)
185
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
186
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
187
+ h_ = h_.reshape(b,c,h,w)
188
+
189
+ h_ = self.proj_out(h_)
190
+
191
+ return x+h_
192
+
193
+
194
+ class Encoder(nn.Module):
195
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
196
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
197
+ resolution, z_channels, double_z=True, **ignore_kwargs):
198
+ super().__init__()
199
+ self.ch = ch
200
+ self.temb_ch = 0
201
+ self.num_resolutions = len(ch_mult)
202
+ self.num_res_blocks = num_res_blocks
203
+ self.resolution = resolution
204
+ self.in_channels = in_channels
205
+
206
+ # downsampling
207
+ self.conv_in = torch.nn.Conv2d(in_channels,
208
+ self.ch,
209
+ kernel_size=3,
210
+ stride=1,
211
+ padding=1)
212
+
213
+ curr_res = resolution
214
+ in_ch_mult = (1,)+tuple(ch_mult)
215
+ self.down = nn.ModuleList()
216
+ for i_level in range(self.num_resolutions):
217
+ block = nn.ModuleList()
218
+ attn = nn.ModuleList()
219
+ block_in = ch*in_ch_mult[i_level]
220
+ block_out = ch*ch_mult[i_level]
221
+ for i_block in range(self.num_res_blocks):
222
+ block.append(ResnetBlock(in_channels=block_in,
223
+ out_channels=block_out,
224
+ temb_channels=self.temb_ch,
225
+ dropout=dropout))
226
+ block_in = block_out
227
+ if curr_res in attn_resolutions:
228
+ attn.append(AttnBlock(block_in))
229
+ down = nn.Module()
230
+ down.block = block
231
+ down.attn = attn
232
+ if i_level != self.num_resolutions-1:
233
+ down.downsample = Downsample(block_in, resamp_with_conv)
234
+ curr_res = curr_res // 2
235
+ self.down.append(down)
236
+
237
+ # middle
238
+ self.mid = nn.Module()
239
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
240
+ out_channels=block_in,
241
+ temb_channels=self.temb_ch,
242
+ dropout=dropout)
243
+ self.mid.attn_1 = AttnBlock(block_in)
244
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
245
+ out_channels=block_in,
246
+ temb_channels=self.temb_ch,
247
+ dropout=dropout)
248
+
249
+ # end
250
+ self.norm_out = Normalize(block_in)
251
+ self.conv_out = torch.nn.Conv2d(block_in,
252
+ 2*z_channels if double_z else z_channels,
253
+ kernel_size=3,
254
+ stride=1,
255
+ padding=1)
256
+
257
+
258
+ def forward(self, x):
259
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
260
+
261
+ # timestep embedding
262
+ temb = None
263
+
264
+ # downsampling
265
+ hs = [self.conv_in(x)]
266
+ for i_level in range(self.num_resolutions):
267
+ for i_block in range(self.num_res_blocks):
268
+ h = self.down[i_level].block[i_block](hs[-1], temb)
269
+ if len(self.down[i_level].attn) > 0:
270
+ h = self.down[i_level].attn[i_block](h)
271
+ hs.append(h)
272
+ if i_level != self.num_resolutions-1:
273
+ hs.append(self.down[i_level].downsample(hs[-1]))
274
+
275
+ # middle
276
+ h = hs[-1]
277
+ h = self.mid.block_1(h, temb)
278
+ h = self.mid.attn_1(h)
279
+ h = self.mid.block_2(h, temb)
280
+
281
+ # end
282
+ h = self.norm_out(h)
283
+ h = nonlinearity(h)
284
+ h = self.conv_out(h)
285
+ return h
286
+
287
+
288
+ class Decoder(nn.Module):
289
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
290
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
291
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
292
+ super().__init__()
293
+ self.ch = ch
294
+ self.temb_ch = 0
295
+ self.num_resolutions = len(ch_mult)
296
+ self.num_res_blocks = num_res_blocks
297
+ self.resolution = resolution
298
+ self.in_channels = in_channels
299
+ self.give_pre_end = give_pre_end
300
+
301
+ # compute in_ch_mult, block_in and curr_res at lowest res
302
+ in_ch_mult = (1,)+tuple(ch_mult)
303
+ block_in = ch*ch_mult[self.num_resolutions-1]
304
+ curr_res = resolution // 2**(self.num_resolutions-1)
305
+ self.z_shape = (1,z_channels,curr_res,curr_res)
306
+ # print("Working with z of shape {} = {} dimensions.".format(
307
+ # self.z_shape, np.prod(self.z_shape)))
308
+
309
+ # z to block_in
310
+ self.conv_in = torch.nn.Conv2d(z_channels,
311
+ block_in,
312
+ kernel_size=3,
313
+ stride=1,
314
+ padding=1)
315
+
316
+ # middle
317
+ self.mid = nn.Module()
318
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
319
+ out_channels=block_in,
320
+ temb_channels=self.temb_ch,
321
+ dropout=dropout)
322
+ self.mid.attn_1 = AttnBlock(block_in)
323
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
324
+ out_channels=block_in,
325
+ temb_channels=self.temb_ch,
326
+ dropout=dropout)
327
+
328
+ # upsampling
329
+ self.up = nn.ModuleList()
330
+ for i_level in reversed(range(self.num_resolutions)):
331
+ block = nn.ModuleList()
332
+ attn = nn.ModuleList()
333
+ block_out = ch*ch_mult[i_level]
334
+ for i_block in range(self.num_res_blocks+1):
335
+ block.append(ResnetBlock(in_channels=block_in,
336
+ out_channels=block_out,
337
+ temb_channels=self.temb_ch,
338
+ dropout=dropout))
339
+ block_in = block_out
340
+ if curr_res in attn_resolutions:
341
+ attn.append(AttnBlock(block_in))
342
+ up = nn.Module()
343
+ up.block = block
344
+ up.attn = attn
345
+ if i_level != 0:
346
+ up.upsample = Upsample(block_in, resamp_with_conv)
347
+ curr_res = curr_res * 2
348
+ self.up.insert(0, up) # prepend to get consistent order
349
+
350
+ # end
351
+ self.norm_out = Normalize(block_in)
352
+ self.conv_out = torch.nn.Conv2d(block_in,
353
+ out_ch,
354
+ kernel_size=3,
355
+ stride=1,
356
+ padding=1)
357
+
358
+ def forward(self, z):
359
+ self.last_z_shape = z.shape
360
+
361
+ # timestep embedding
362
+ temb = None
363
+
364
+ # z to block_in
365
+ h = self.conv_in(z)
366
+
367
+ # middle
368
+ h = self.mid.block_1(h, temb)
369
+ h = self.mid.attn_1(h)
370
+ h = self.mid.block_2(h, temb)
371
+
372
+ # upsampling
373
+ for i_level in reversed(range(self.num_resolutions)):
374
+ for i_block in range(self.num_res_blocks+1):
375
+ h = self.up[i_level].block[i_block](h, temb)
376
+ if len(self.up[i_level].attn) > 0:
377
+ h = self.up[i_level].attn[i_block](h)
378
+ if i_level != 0:
379
+ h = self.up[i_level].upsample(h)
380
+
381
+ # end
382
+ if self.give_pre_end:
383
+ return h
384
+
385
+ h = self.norm_out(h)
386
+ h = nonlinearity(h)
387
+ h = self.conv_out(h)
388
+ return h
389
+
390
+
391
+ class UpsampleDecoder(nn.Module):
392
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
393
+ super().__init__()
394
+ # upsampling
395
+ self.temb_ch = 0
396
+ self.num_resolutions = len(ch_mult)
397
+ self.num_res_blocks = num_res_blocks
398
+ block_in = in_channels
399
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
400
+ self.res_blocks = nn.ModuleList()
401
+ self.upsample_blocks = nn.ModuleList()
402
+ for i_level in range(self.num_resolutions):
403
+ res_block = []
404
+ block_out = ch * ch_mult[i_level]
405
+ for i_block in range(self.num_res_blocks + 1):
406
+ res_block.append(ResnetBlock(in_channels=block_in,
407
+ out_channels=block_out,
408
+ temb_channels=self.temb_ch,
409
+ dropout=dropout))
410
+ block_in = block_out
411
+ self.res_blocks.append(nn.ModuleList(res_block))
412
+ if i_level != self.num_resolutions - 1:
413
+ self.upsample_blocks.append(Upsample(block_in, True))
414
+ curr_res = curr_res * 2
415
+
416
+ # end
417
+ self.norm_out = Normalize(block_in)
418
+ self.conv_out = torch.nn.Conv2d(block_in,
419
+ out_channels,
420
+ kernel_size=3,
421
+ stride=1,
422
+ padding=1)
423
+
424
+ def forward(self, x):
425
+ # upsampling
426
+ h = x
427
+ for k, i_level in enumerate(range(self.num_resolutions)):
428
+ for i_block in range(self.num_res_blocks + 1):
429
+ h = self.res_blocks[i_level][i_block](h, None)
430
+ if i_level != self.num_resolutions - 1:
431
+ h = self.upsample_blocks[k](h)
432
+ h = self.norm_out(h)
433
+ h = nonlinearity(h)
434
+ h = self.conv_out(h)
435
+ return h
436
+
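
A quick sanity check of the sinusoidal embedding helper defined above (illustrative only, not part of the committed files):

    import torch
    from Models.modules.diffusionmodules.model import get_timestep_embedding

    t = torch.arange(4)                   # four integer timesteps
    emb = get_timestep_embedding(t, 128)  # -> (4, 128): sin half followed by cos half
    print(emb.shape)
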
Models/modules/util.py ADDED
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def count_params(model):
6
+ total_params = sum(p.numel() for p in model.parameters())
7
+ return total_params
8
+
9
+
10
+ class ActNorm(nn.Module):
11
+ def __init__(self, num_features, logdet=False, affine=True,
12
+ allow_reverse_init=False):
13
+ assert affine
14
+ super().__init__()
15
+ self.logdet = logdet
16
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18
+ self.allow_reverse_init = allow_reverse_init
19
+
20
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21
+
22
+ def initialize(self, input):
23
+ with torch.no_grad():
24
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25
+ mean = (
26
+ flatten.mean(1)
27
+ .unsqueeze(1)
28
+ .unsqueeze(2)
29
+ .unsqueeze(3)
30
+ .permute(1, 0, 2, 3)
31
+ )
32
+ std = (
33
+ flatten.std(1)
34
+ .unsqueeze(1)
35
+ .unsqueeze(2)
36
+ .unsqueeze(3)
37
+ .permute(1, 0, 2, 3)
38
+ )
39
+
40
+ self.loc.data.copy_(-mean)
41
+ self.scale.data.copy_(1 / (std + 1e-6))
42
+
43
+ def forward(self, input, reverse=False):
44
+ if reverse:
45
+ return self.reverse(input)
46
+ if len(input.shape) == 2:
47
+ input = input[:,:,None,None]
48
+ squeeze = True
49
+ else:
50
+ squeeze = False
51
+
52
+ _, _, height, width = input.shape
53
+
54
+ if self.training and self.initialized.item() == 0:
55
+ self.initialize(input)
56
+ self.initialized.fill_(1)
57
+
58
+ h = self.scale * (input + self.loc)
59
+
60
+ if squeeze:
61
+ h = h.squeeze(-1).squeeze(-1)
62
+
63
+ if self.logdet:
64
+ log_abs = torch.log(torch.abs(self.scale))
65
+ logdet = height*width*torch.sum(log_abs)
66
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
67
+ return h, logdet
68
+
69
+ return h
70
+
71
+ def reverse(self, output):
72
+ if self.training and self.initialized.item() == 0:
73
+ if not self.allow_reverse_init:
74
+ raise RuntimeError(
75
+ "Initializing ActNorm in reverse direction is "
76
+ "disabled by default. Use allow_reverse_init=True to enable."
77
+ )
78
+ else:
79
+ self.initialize(output)
80
+ self.initialized.fill_(1)
81
+
82
+ if len(output.shape) == 2:
83
+ output = output[:,:,None,None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ h = output / self.scale - self.loc
89
+
90
+ if squeeze:
91
+ h = h.squeeze(-1).squeeze(-1)
92
+ return h
93
+
94
+
95
+ class AbstractEncoder(nn.Module):
96
+ def __init__(self):
97
+ super().__init__()
98
+
99
+ def encode(self, *args, **kwargs):
100
+ raise NotImplementedError
101
+
102
+
103
+ class Labelator(AbstractEncoder):
104
+ """Net2Net Interface for Class-Conditional Model"""
105
+ def __init__(self, n_classes, quantize_interface=True):
106
+ super().__init__()
107
+ self.n_classes = n_classes
108
+ self.quantize_interface = quantize_interface
109
+
110
+ def encode(self, c):
111
+ c = c[:,None]
112
+ if self.quantize_interface:
113
+ return c, None, [None, None, c.long()]
114
+ return c
115
+
116
+
117
+ class SOSProvider(AbstractEncoder):
118
+ # for unconditional training
119
+ def __init__(self, sos_token, quantize_interface=True):
120
+ super().__init__()
121
+ self.sos_token = sos_token
122
+ self.quantize_interface = quantize_interface
123
+
124
+ def encode(self, x):
125
+ # get batch size from data and replicate sos_token
126
+ c = torch.ones(x.shape[0], 1)*self.sos_token
127
+ c = c.long().to(x.device)
128
+ if self.quantize_interface:
129
+ return c, None, [None, None, c]
130
+ return c
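
A minimal ActNorm sketch (illustrative only, not part of the committed files); the first forward pass in training mode performs the data-dependent initialisation of loc and scale from the batch statistics:

    import torch
    from Models.modules.util import ActNorm

    norm = ActNorm(num_features=64, logdet=True)
    x = torch.randn(8, 64, 32, 32)
    y, logdet = norm(x)              # loc/scale initialised on this first call; logdet has shape (8,)
    x_back = norm(y, reverse=True)   # inverse of the affine transform
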
Models/modules/vqvae/quantize.py ADDED
@@ -0,0 +1,335 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from torch import einsum
6
+ from einops import rearrange
7
+
8
+
9
+ class VectorQuantizer(nn.Module):
10
+ """
11
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
12
+ ____________________________________________
13
+ Discretization bottleneck part of the VQ-VAE.
14
+ Inputs:
15
+ - n_e : number of embeddings
16
+ - e_dim : dimension of embedding
17
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
18
+ _____________________________________________
19
+ """
20
+
21
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
22
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
23
+ # used wherever VectorQuantizer has been used before and is additionally
24
+ # more efficient.
25
+ def __init__(self, n_e, e_dim, beta):
26
+ super(VectorQuantizer, self).__init__()
27
+ self.n_e = n_e
28
+ self.e_dim = e_dim
29
+ self.beta = beta
30
+
31
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
32
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
33
+
34
+ def forward(self, z):
35
+ """
36
+ Inputs the output of the encoder network z and maps it to a discrete
37
+ one-hot vector that is the index of the closest embedding vector e_j
38
+ z (continuous) -> z_q (discrete)
39
+ z.shape = (batch, channel, height, width)
40
+ quantization pipeline:
41
+ 1. get encoder input (B,C,H,W)
42
+ 2. flatten input to (B*H*W,C)
43
+ """
44
+ # reshape z -> (batch, height, width, channel) and flatten
45
+ z = z.permute(0, 2, 3, 1).contiguous()
46
+ z_flattened = z.view(-1, self.e_dim)
47
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
48
+
49
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
50
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
51
+ torch.matmul(z_flattened, self.embedding.weight.t())
52
+
53
+ ## could possibly replace this here
54
+ # #\start...
55
+ # find closest encodings
56
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
57
+
58
+ min_encodings = torch.zeros(
59
+ min_encoding_indices.shape[0], self.n_e).to(z)
60
+ min_encodings.scatter_(1, min_encoding_indices, 1)
61
+
62
+ # dtype min encodings: torch.float32
63
+ # min_encodings shape: torch.Size([2048, 512])
64
+ # min_encoding_indices.shape: torch.Size([2048, 1])
65
+
66
+ # get quantized latent vectors
67
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
68
+ # .........\end
69
+
70
+ # with:
71
+ # .........\start
72
+ # min_encoding_indices = torch.argmin(d, dim=1)
73
+ # z_q = self.embedding(min_encoding_indices)
74
+ # ......\end......... (TODO)
75
+
76
+ # compute loss for embedding
77
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
78
+ torch.mean((z_q - z.detach()) ** 2)
79
+
80
+ # preserve gradients
81
+ z_q = z + (z_q - z).detach()
82
+
83
+ # perplexity
84
+ e_mean = torch.mean(min_encodings, dim=0)
85
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
86
+
87
+ # reshape back to match original input shape
88
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
89
+
90
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
91
+
92
+ def get_codebook_entry(self, indices, shape):
93
+ # shape specifying (batch, height, width, channel)
94
+ # TODO: check for more easy handling with nn.Embedding
95
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
96
+ min_encodings.scatter_(1, indices[:, None], 1)
97
+
98
+ # get quantized latent vectors
99
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
100
+
101
+ if shape is not None:
102
+ z_q = z_q.view(shape)
103
+
104
+ # reshape back to match original input shape
105
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
106
+
107
+ return z_q
108
+
109
+
110
+ class GumbelQuantize(nn.Module):
111
+ """
112
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
113
+ Gumbel Softmax trick quantizer
114
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
115
+ https://arxiv.org/abs/1611.01144
116
+ """
117
+
118
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
119
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
120
+ remap=None, unknown_index="random"):
121
+ super().__init__()
122
+
123
+ self.embedding_dim = embedding_dim
124
+ self.n_embed = n_embed
125
+ print(n_embed)
126
+ self.straight_through = straight_through
127
+ self.temperature = temp_init
128
+ self.kl_weight = kl_weight
129
+
130
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
131
+ self.embed = nn.Embedding(n_embed, embedding_dim)
132
+
133
+ self.use_vqinterface = use_vqinterface
134
+
135
+ self.remap = remap
136
+
137
+ if self.remap is not None:
138
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
139
+ self.re_embed = self.used.shape[0]
140
+ self.unknown_index = unknown_index # "random" or "extra" or integer
141
+ if self.unknown_index == "extra":
142
+ self.unknown_index = self.re_embed
143
+ self.re_embed = self.re_embed + 1
144
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
145
+ f"Using {self.unknown_index} for unknown indices.")
146
+ else:
147
+ self.re_embed = n_embed
148
+
149
+ def remap_to_used(self, inds):
150
+ ishape = inds.shape
151
+ assert len(ishape) > 1
152
+ inds = inds.reshape(ishape[0], -1)
153
+ used = self.used.to(inds)
154
+ match = (inds[:, :, None] == used[None, None, ...]).long()
155
+ new = match.argmax(-1)
156
+ unknown = match.sum(2) < 1
157
+ if self.unknown_index == "random":
158
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
159
+ else:
160
+ new[unknown] = self.unknown_index
161
+ return new.reshape(ishape)
162
+
163
+ def unmap_to_all(self, inds):
164
+ ishape = inds.shape
165
+ assert len(ishape) > 1
166
+ inds = inds.reshape(ishape[0], -1)
167
+ used = self.used.to(inds)
168
+ if self.re_embed > self.used.shape[0]: # extra token
169
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
170
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
171
+ return back.reshape(ishape)
172
+
173
+ def forward(self, z, temp=None, return_logits=False):
174
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
175
+ hard = self.straight_through if self.training else True
176
+ temp = self.temperature if temp is None else temp
177
+
178
+ logits = self.proj(z)
179
+ if self.remap is not None:
180
+ # continue only with used logits
181
+ full_zeros = torch.zeros_like(logits)
182
+ logits = logits[:, self.used, ...]
183
+
184
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
185
+ if self.remap is not None:
186
+ # go back to all entries but unused set to zero
187
+ full_zeros[:, self.used, ...] = soft_one_hot
188
+ soft_one_hot = full_zeros
189
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
190
+
191
+ # + kl divergence to the prior loss
192
+ qy = F.softmax(logits, dim=1)
193
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
194
+
195
+ ind = soft_one_hot.argmax(dim=1)
196
+ if self.remap is not None:
197
+ ind = self.remap_to_used(ind)
198
+ if self.use_vqinterface:
199
+ if return_logits:
200
+ return z_q, diff, (None, None, ind), logits
201
+ return z_q, diff, (None, None, ind)
202
+ return z_q, diff, ind
203
+
204
+ def get_codebook_entry(self, indices, shape):
205
+ b, h, w, c = shape
206
+ assert b * h * w == indices.shape[0]
207
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
208
+ if self.remap is not None:
209
+ indices = self.unmap_to_all(indices)
210
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
211
+ # print(one_hot.size())
212
+ # exit()
213
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
214
+
215
+ return z_q
216
+
217
+
218
+ class VectorQuantizer2(nn.Module):
219
+ """
220
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
221
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
222
+ """
223
+
224
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
225
+ # backwards compatibility we use the buggy version by default, but you can
226
+ # specify legacy=False to fix it.
227
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
228
+ sane_index_shape=False, legacy=True):
229
+ super().__init__()
230
+ self.n_e = n_e
231
+ self.e_dim = e_dim
232
+ self.beta = beta
233
+ self.legacy = legacy
234
+
235
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
236
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
237
+
238
+ self.remap = remap
239
+ if self.remap is not None:
240
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
241
+ self.re_embed = self.used.shape[0]
242
+ self.unknown_index = unknown_index # "random" or "extra" or integer
243
+ if self.unknown_index == "extra":
244
+ self.unknown_index = self.re_embed
245
+ self.re_embed = self.re_embed + 1
246
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
247
+ f"Using {self.unknown_index} for unknown indices.")
248
+ else:
249
+ self.re_embed = n_e
250
+
251
+ self.sane_index_shape = sane_index_shape
252
+
253
+ def remap_to_used(self, inds):
254
+ ishape = inds.shape
255
+ assert len(ishape) > 1
256
+ inds = inds.reshape(ishape[0], -1)
257
+ used = self.used.to(inds)
258
+ match = (inds[:, :, None] == used[None, None, ...]).long()
259
+ new = match.argmax(-1)
260
+ unknown = match.sum(2) < 1
261
+ if self.unknown_index == "random":
262
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
263
+ else:
264
+ new[unknown] = self.unknown_index
265
+ return new.reshape(ishape)
266
+
267
+ def unmap_to_all(self, inds):
268
+ ishape = inds.shape
269
+ assert len(ishape) > 1
270
+ inds = inds.reshape(ishape[0], -1)
271
+ used = self.used.to(inds)
272
+ if self.re_embed > self.used.shape[0]: # extra token
273
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
274
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
275
+ return back.reshape(ishape)
276
+
277
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
278
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
279
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
280
+ assert return_logits == False, "Only for interface compatible with Gumbel"
281
+ # reshape z -> (batch, height, width, channel) and flatten
282
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
283
+ z_flattened = z.view(-1, self.e_dim)
284
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
285
+
286
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
287
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
288
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
289
+
290
+ min_encoding_indices = torch.argmin(d, dim=1)
291
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
292
+ perplexity = None
293
+ min_encodings = None
294
+
295
+ # compute loss for embedding
296
+ if not self.legacy:
297
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
298
+ torch.mean((z_q - z.detach()) ** 2)
299
+ else:
300
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
301
+ torch.mean((z_q - z.detach()) ** 2)
302
+
303
+ # preserve gradients
304
+ z_q = z + (z_q - z).detach()
305
+
306
+ # reshape back to match original input shape
307
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
308
+
309
+ if self.remap is not None:
310
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
311
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
312
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
313
+
314
+ if self.sane_index_shape:
315
+ min_encoding_indices = min_encoding_indices.reshape(
316
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
317
+
318
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
319
+
320
+ def get_codebook_entry(self, indices, shape):
321
+ # shape specifying (batch, height, width, channel)
322
+ if self.remap is not None:
323
+ indices = indices.reshape(shape[0], -1) # add batch axis
324
+ indices = self.unmap_to_all(indices)
325
+ indices = indices.reshape(-1) # flatten again
326
+
327
+ # get quantized latent vectors
328
+ z_q = self.embedding(indices)
329
+
330
+ if shape is not None:
331
+ z_q = z_q.view(shape)
332
+ # reshape back to match original input shape
333
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
334
+
335
+ return z_q
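
And a short sketch of the straight-through quantisation step of VectorQuantizer2 in isolation (illustrative only; the 16x16 latent grid matches what the VQGAN in this commit produces for 256x256 images):

    import torch
    from Models.modules.vqvae.quantize import VectorQuantizer2

    quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25, sane_index_shape=True)
    z = torch.randn(2, 256, 16, 16)        # encoder output after quant_conv
    z_q, loss, (_, _, idx) = quantizer(z)  # z_q: 2x256x16x16, idx: 2x16x16 codebook indices
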
Models/util.py ADDED
@@ -0,0 +1,157 @@
1
+ import os, hashlib
2
+ import requests
3
+ from tqdm import tqdm
4
+
5
+ URL_MAP = {
6
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
7
+ }
8
+
9
+ CKPT_MAP = {
10
+ "vgg_lpips": "vgg.pth"
11
+ }
12
+
13
+ MD5_MAP = {
14
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
15
+ }
16
+
17
+
18
+ def download(url, local_path, chunk_size=1024):
19
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
20
+ with requests.get(url, stream=True) as r:
21
+ total_size = int(r.headers.get("content-length", 0))
22
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
23
+ with open(local_path, "wb") as f:
24
+ for data in r.iter_content(chunk_size=chunk_size):
25
+ if data:
26
+ f.write(data)
27
+ pbar.update(chunk_size)
28
+
29
+
30
+ def md5_hash(path):
31
+ with open(path, "rb") as f:
32
+ content = f.read()
33
+ return hashlib.md5(content).hexdigest()
34
+
35
+
36
+ def get_ckpt_path(name, root, check=False):
37
+ assert name in URL_MAP
38
+ path = os.path.join(root, CKPT_MAP[name])
39
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
40
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
41
+ download(URL_MAP[name], path)
42
+ md5 = md5_hash(path)
43
+ assert md5 == MD5_MAP[name], md5
44
+ return path
45
+
46
+
47
+ class KeyNotFoundError(Exception):
48
+ def __init__(self, cause, keys=None, visited=None):
49
+ self.cause = cause
50
+ self.keys = keys
51
+ self.visited = visited
52
+ messages = list()
53
+ if keys is not None:
54
+ messages.append("Key not found: {}".format(keys))
55
+ if visited is not None:
56
+ messages.append("Visited: {}".format(visited))
57
+ messages.append("Cause:\n{}".format(cause))
58
+ message = "\n".join(messages)
59
+ super().__init__(message)
60
+
61
+
62
+ def retrieve(
63
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
64
+ ):
65
+ """Given a nested list or dict return the desired value at key expanding
66
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
67
+ is done in-place.
68
+
69
+ Parameters
70
+ ----------
71
+ list_or_dict : list or dict
72
+ Possibly nested list or dictionary.
73
+ key : str
74
+ key/to/value, path like string describing all keys necessary to
75
+ consider to get to the desired value. List indices can also be
76
+ passed here.
77
+ splitval : str
78
+ String that defines the delimiter between keys of the
79
+ different depth levels in `key`.
80
+ default : obj
81
+ Value returned if :attr:`key` is not found.
82
+ expand : bool
83
+ Whether to expand callable nodes on the path or not.
84
+
85
+ Returns
86
+ -------
87
+ The desired value or if :attr:`default` is not ``None`` and the
88
+ :attr:`key` is not found returns ``default``.
89
+
90
+ Raises
91
+ ------
92
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
93
+ ``None``.
94
+ """
95
+
96
+ keys = key.split(splitval)
97
+
98
+ success = True
99
+ try:
100
+ visited = []
101
+ parent = None
102
+ last_key = None
103
+ for key in keys:
104
+ if callable(list_or_dict):
105
+ if not expand:
106
+ raise KeyNotFoundError(
107
+ ValueError(
108
+ "Trying to get past callable node with expand=False."
109
+ ),
110
+ keys=keys,
111
+ visited=visited,
112
+ )
113
+ list_or_dict = list_or_dict()
114
+ parent[last_key] = list_or_dict
115
+
116
+ last_key = key
117
+ parent = list_or_dict
118
+
119
+ try:
120
+ if isinstance(list_or_dict, dict):
121
+ list_or_dict = list_or_dict[key]
122
+ else:
123
+ list_or_dict = list_or_dict[int(key)]
124
+ except (KeyError, IndexError, ValueError) as e:
125
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
126
+
127
+ visited += [key]
128
+ # final expansion of retrieved value
129
+ if expand and callable(list_or_dict):
130
+ list_or_dict = list_or_dict()
131
+ parent[last_key] = list_or_dict
132
+ except KeyNotFoundError as e:
133
+ if default is None:
134
+ raise e
135
+ else:
136
+ list_or_dict = default
137
+ success = False
138
+
139
+ if not pass_success:
140
+ return list_or_dict
141
+ else:
142
+ return list_or_dict, success
143
+
144
+
145
+ if __name__ == "__main__":
146
+ config = {"keya": "a",
147
+ "keyb": "b",
148
+ "keyc":
149
+ {"cc1": 1,
150
+ "cc2": 2,
151
+ }
152
+ }
153
+ from omegaconf import OmegaConf
154
+ config = OmegaConf.create(config)
155
+ print(config)
156
+ retrieve(config, "keya")
157
+
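
retrieve() walks nested configs with "/"-separated keys, as the __main__ block above hints; a slightly fuller sketch (illustrative only, not part of the committed files):

    from Models.util import retrieve

    cfg = {"model": {"params": {"embed_dim": 256}}}
    print(retrieve(cfg, "model/params/embed_dim"))           # -> 256
    print(retrieve(cfg, "model/params/missing", default=0))  # -> 0: falls back instead of raising
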
app.py ADDED
@@ -0,0 +1,78 @@
1
+ import argparse
2
+ import torch
3
+ import gradio as gr
4
+ from torchvision import transforms
5
+ from runner import MaskGIT
6
+ import numpy as np
7
+ import random
8
+ import torchvision.utils as vutils
9
+
10
+
11
+ class Args(argparse.Namespace):
12
+ data_folder = ""
13
+ vqgan_folder = "pretrained_maskgit/VQGAN"
14
+ writer_log = ""
15
+ data = ""
16
+ mask_value = 1024
17
+ seed = 1
18
+ channel = 3
19
+ num_workers = 0
20
+ iter = 0
21
+ global_epoch = 0
22
+ lr = 1e-4
23
+ drop_label = 0.1
24
+ resume = True
25
+ device = "cpu"
26
+ print(device)
27
+ debug = True
28
+ test_only = False
29
+ is_master = True
30
+ is_multi_gpus = False
31
+ vit_size = "base"
32
+ vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
33
+ img_size = 256
34
+ patch_size = 256 // 16
35
+
36
+
37
+ def set_seed(seed):
38
+ if seed > 0:
39
+ torch.manual_seed(seed)
40
+ torch.cuda.manual_seed(seed)
41
+ np.random.seed(seed)
42
+ random.seed(seed)
43
+ torch.backends.cudnn.enable = False
44
+ torch.backends.cudnn.deterministic = True
45
+
46
+ args = Args()
47
+ maskgit = MaskGIT(args)
48
+
49
+
50
+ # Function to perform image synthesis
51
+ def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
52
+ # Perform image synthesis using your model
53
+ set_seed(seed)
54
+ with torch.no_grad():
55
+ labels = [cls] * nb_img
56
+ labels = torch.LongTensor(labels).to(args.device)
57
+ gen_sample = maskgit.sample(nb_sample=labels.size(0), labels=labels, sm_temp=sm_temp, w=w,
58
+ randomize="linear", r_temp=r_temp, sched_mode="arccos",
59
+ step=step)[0]
60
+
61
+ # Post-process the output image (adjust based on your needs)
62
+ output_image = transforms.ToPILImage()(vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True))
63
+
64
+ return output_image
65
+
66
+
67
+ # Gradio Interface
68
+ app = gr.Interface(
69
+ fn=synthesize_image,
70
+ inputs=[gr.Number(31), gr.Number(1.3), gr.Number(25), gr.Number(4.5), gr.Number(16),
71
+ gr.Slider(0, 1000, 60), gr.Number(1, maximum=4)],
72
+ outputs=gr.Image(),
73
+ title="Image Synthesis using MaskGIT",
74
+ )
75
+
76
+ # Launch the Gradio app
77
+ app.launch(share=True)
78
+
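
The same generation path can be exercised without the Gradio UI; a sketch (assumes the pretrained_maskgit checkpoints referenced in Args are available locally, since MaskGIT(args) loads them when the module is imported):

    sample = synthesize_image(cls=31, sm_temp=1.3, w=3, r_temp=4.5, step=8, seed=42, nb_img=1)
    sample.save("maskgit_sample.png")
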
runner.py ADDED
@@ -0,0 +1,221 @@
1
+ # Trainer for MaskGIT
2
+ import os
3
+ import random
4
+ import math
5
+
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from omegaconf import OmegaConf
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.nn.parallel import DistributedDataParallel as DDP
13
+
14
+ from Models.models.transformer import MaskTransformer
15
+ from Models.models.vqgan import VQModel
16
+
17
+
18
+ class MaskGIT(nn.Module):
19
+
20
+ def __init__(self, args):
21
+ """ Initialization of the model (VQGAN and Masked Transformer), optimizer, criterion, etc."""
22
+ super().__init__()
23
+
24
+ self.args = args # Main argument see main.py
25
+ self.patch_size = self.args.img_size // 16 # Side of the visual-token grid (16 for 256x256 images); the class token is appended separately
26
+ self.scaler = torch.cuda.amp.GradScaler() # Init Scaler for multi GPUs
27
+ self.vit = self.get_network("vit") # Load Masked Bidirectional Transformer
28
+ self.ae = self.get_network("autoencoder") # Load VQGAN
29
+
30
+ def get_network(self, archi):
31
+ """ return the network, load checkpoint if self.args.resume == True
32
+ :param
33
+ archi -> str: vit|autoencoder, the architecture to load
34
+ :return
35
+ model -> nn.Module: the network
36
+ """
37
+ if archi == "vit":
38
+ if self.args.vit_size == "base":
39
+ model = MaskTransformer(
40
+ img_size=self.args.img_size, hidden_dim=768, codebook_size=1024, depth=24, heads=16, mlp_dim=3072, dropout=0.1 # Small
41
+ )
42
+ elif self.args.vit_size == "big":
43
+ model = MaskTransformer(
44
+ img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=32, heads=16, mlp_dim=3072, dropout=0.1 # Big
45
+ )
46
+ elif self.args.vit_size == "huge":
47
+ model = MaskTransformer(
48
+ img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=48, heads=16, mlp_dim=3072, dropout=0.1 # Huge
49
+ )
50
+
51
+ if self.args.resume:
52
+ ckpt = self.args.vit_folder
53
+ ckpt += "current.pth" if os.path.isdir(self.args.vit_folder) else ""
54
+ if self.args.is_master:
55
+ print("load ckpt from:", ckpt)
56
+ # Read checkpoint file
57
+ checkpoint = torch.load(ckpt, map_location='cpu')
58
+ # Load network
59
+ model.load_state_dict(checkpoint['model_state_dict'], strict=False)
60
+
61
+ model = model.to(self.args.device)
62
+
63
+ if self.args.is_multi_gpus: # put model on multi GPUs if available
64
+ model = DDP(model, device_ids=[self.args.device])
65
+
66
+ elif archi == "autoencoder":
67
+ # Load config
68
+ config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml"))
69
+ model = VQModel(**config.model.params)
70
+ checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"]
71
+ # Load network
72
+ model.load_state_dict(checkpoint, strict=False)
73
+ model = model.eval()
74
+ model = model.to(self.args.device)
75
+
76
+ if self.args.is_multi_gpus: # put model on multi GPUs if available
77
+ model = DDP(model, device_ids=[self.args.device])
78
+ model = model.module
79
+ else:
80
+ model = None
81
+
82
+ if self.args.is_master:
83
+ print(f"Size of model {archi}: "
84
+ f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M")
85
+
86
+ return model
87
+
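MaskGIT only reads a handful of fields from args (device, checkpoint folders, model size, distributed flags). A minimal sketch of a compatible single-GPU configuration, with purely illustrative paths and values (the real ones come from main.py, which is not part of this diff):

from argparse import Namespace
import torch

# All values below are assumptions for illustration only.
args = Namespace(
    img_size=256,                        # gives a 16x16 latent token grid
    vit_size="base",                     # "base" | "big" | "huge"
    vit_folder="./pretrained_maskgit/",  # hypothetical transformer checkpoint folder
    vqgan_folder="./pretrained_vqgan/",  # hypothetical folder with model.yaml and last.ckpt
    resume=True,                         # load the transformer checkpoint
    mask_value=1024,                     # id used for masked tokens (<0 -> random init)
    device="cuda" if torch.cuda.is_available() else "cpu",
    is_master=True,                      # single process
    is_multi_gpus=False,                 # no DDP wrapping
)
maskgit = MaskGIT(args)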
88
+ def adap_sche(self, step, mode="arccos", leave=False):
89
+ """ Create a sampling scheduler
90
+ :param
91
+ step -> int: the number of prediction steps during inference
92
+ mode -> str: the schedule shape controlling how many tokens are unmasked per step
93
+ leave -> bool: tqdm arg on whether to keep the progress bar
94
+ :return
95
+ scheduler -> torch.LongTensor(): the number of tokens to predict at each step
96
+ """
97
+ r = torch.linspace(1, 0, step)
98
+ if mode == "root": # root scheduler
99
+ val_to_mask = 1 - (r ** .5)
100
+ elif mode == "linear": # linear scheduler
101
+ val_to_mask = 1 - r
102
+ elif mode == "square": # square scheduler
103
+ val_to_mask = 1 - (r ** 2)
104
+ elif mode == "cosine": # cosine scheduler
105
+ val_to_mask = torch.cos(r * math.pi * 0.5)
106
+ elif mode == "arccos": # arc cosine scheduler
107
+ val_to_mask = torch.arccos(r) / (math.pi * 0.5)
108
+ else:
109
+ return
110
+
111
+ # fill the scheduler by the ratio of tokens to predict at each step
112
+ sche = (val_to_mask / val_to_mask.sum()) * (self.patch_size * self.patch_size)
113
+ sche = sche.round()
114
+ sche[sche == 0] = 1 # predict at least 1 token per step
115
+ sche[-1] += (self.patch_size * self.patch_size) - sche.sum() # adjust the last step so the schedule sums to the total number of tokens
116
+ return tqdm(sche.int(), leave=leave)
117
+
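The arccos schedule predicts few tokens in the early steps and many in the late ones. A standalone sketch that mirrors adap_sche for mode="arccos" (tqdm wrapper dropped; a 16x16 grid is assumed):

import math
import torch

def arccos_schedule(step=8, n_tokens=16 * 16):
    # Same computation as adap_sche with mode="arccos"
    r = torch.linspace(1, 0, step)
    val_to_mask = torch.arccos(r) / (math.pi * 0.5)
    sche = (val_to_mask / val_to_mask.sum()) * n_tokens
    sche = sche.round()
    sche[sche == 0] = 1                 # at least one token per step
    sche[-1] += n_tokens - sche.sum()   # make the schedule sum to n_tokens
    return sche.int()

print(arccos_schedule())  # roughly [1, 18, 26, 32, 38, 43, 48, 50], summing to 256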
118
+ def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
119
+ randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
120
+ """ Generate sample with the MaskGIT model
121
+ :param
122
+ init_code -> torch.LongTensor: nb_sample x 16 x 16, the starting initialization code
123
+ nb_sample -> int: the number of images to generate
124
+ labels -> torch.LongTensor: the list of classes to generate
125
+ sm_temp -> float: the temperature before softmax
126
+ w -> float: scale for the classifier free guidance
127
+ randomize -> str: linear|warm_up|random|no, whether and how to add randomness to the confidence scores
128
+ r_temp -> float: temperature for the randomness
129
+ sched_mode -> str: root|linear|square|cosine|arccos, the shape of the scheduler
130
+ step -> int: the number of decoding steps
131
+ :return
132
+ x -> torch.FloatTensor: nb_sample x 3 x 256 x 256, the generated images
133
+ code -> torch.LongTensor: nb_sample x step x 16 x 16, the code corresponding to the generated images
134
+ """
135
+ self.vit.eval()
136
+ l_codes = [] # Save the intermediate codes predicted
137
+ l_mask = [] # Save the intermediate masks
138
+ with torch.no_grad():
139
+ if labels is None: # Default classes generated
140
+ # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random
141
+ labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)] * (nb_sample // 10)
142
+ labels = torch.LongTensor(labels).to(self.args.device)
143
+
144
+ drop = torch.ones(nb_sample, dtype=torch.bool).to(self.args.device)
145
+ if init_code is not None: # Start with a pre-defined code
146
+ code = init_code
147
+ mask = (init_code == 1024).float().view(nb_sample, self.patch_size*self.patch_size)
148
+ else: # Initialize a code
149
+ if self.args.mask_value < 0: # Initialize the code with random tokens
150
+ code = torch.randint(0, 1024, (nb_sample, self.patch_size, self.patch_size)).to(self.args.device)
151
+ else: # Initialize the code with masked tokens
152
+ code = torch.full((nb_sample, self.patch_size, self.patch_size), self.args.mask_value).to(self.args.device)
153
+ mask = torch.ones(nb_sample, self.patch_size*self.patch_size).to(self.args.device)
154
+
155
+ # Instantiate scheduler
156
+ if isinstance(sched_mode, str): # Standard ones
157
+ scheduler = self.adap_sche(step, mode=sched_mode)
158
+ else: # Custom one
159
+ scheduler = sched_mode
160
+
161
+ # Beginning of sampling; t = number of tokens to predict at step "indice"
162
+ for indice, t in enumerate(scheduler):
163
+ if mask.sum() < t: # Cannot predict more tokens than remain masked
164
+ t = int(mask.sum().item())
165
+
166
+ if mask.sum() == 0: # Break if code is fully predicted
167
+ break
168
+
169
+ with torch.cuda.amp.autocast(): # half precision
170
+ if w != 0:
171
+ # Model Prediction
172
+ logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
173
+ torch.cat([labels, labels], dim=0),
174
+ torch.cat([~drop, drop], dim=0))
175
+ logit_c, logit_u = torch.chunk(logit, 2, dim=0)
176
+ _w = w * (indice / (len(scheduler)-1))
177
+ # Classifier Free Guidance
178
+ logit = (1 + _w) * logit_c - _w * logit_u
179
+ else:
180
+ logit = self.vit(code.clone(), labels, drop_label=~drop)
181
+
182
+ prob = torch.softmax(logit * sm_temp, -1)
183
+ # Sample the code from the softmax prediction
184
+ distri = torch.distributions.Categorical(probs=prob)
185
+ pred_code = distri.sample()
186
+
187
+ conf = torch.gather(prob, 2, pred_code.view(nb_sample, self.patch_size*self.patch_size, 1))
188
+
189
+ if randomize == "linear": # add gumbel noise decreasing over the sampling process
190
+ ratio = (indice / len(scheduler))
191
+ rand = r_temp * np.random.gumbel(size=(nb_sample, self.patch_size*self.patch_size)) * (1 - ratio)
192
+ conf = torch.log(conf.squeeze()) + torch.from_numpy(rand).to(self.args.device)
193
+ elif randomize == "warm_up": # chose random sample for the 2 first steps
194
+ conf = torch.rand_like(conf) if indice < 2 else conf
195
+ elif randomize == "random": # chose random prediction at each step
196
+ conf = torch.rand_like(conf)
197
+
198
+ # do not predict on already predicted tokens
199
+ conf[~mask.bool()] = -math.inf
200
+
201
+ # keep the t predicted tokens with the highest confidence
202
+ tresh_conf, indice_mask = torch.topk(conf.view(nb_sample, -1), k=t, dim=-1)
203
+ tresh_conf = tresh_conf[:, -1]
204
+
205
+ # replace the chosen tokens
206
+ conf = (conf >= tresh_conf.unsqueeze(-1)).view(nb_sample, self.patch_size, self.patch_size)
207
+ f_mask = (mask.view(nb_sample, self.patch_size, self.patch_size).float() * conf.view(nb_sample, self.patch_size, self.patch_size).float()).bool()
208
+ code[f_mask] = pred_code.view(nb_sample, self.patch_size, self.patch_size)[f_mask]
209
+
210
+ # update the mask
211
+ for i_mask, ind_mask in enumerate(indice_mask):
212
+ mask[i_mask, ind_mask] = 0
213
+ l_codes.append(pred_code.view(nb_sample, self.patch_size, self.patch_size).clone())
214
+ l_mask.append(mask.view(nb_sample, self.patch_size, self.patch_size).clone())
215
+
216
+ # decode the final prediction
217
+ _code = torch.clamp(code, 0, 1023) # the VQGAN codebook has only 1024 entries
218
+ x = self.ae.decode_code(_code)
219
+ x = (torch.clamp(x, -1, 1) + 1) / 2
220
+ self.vit.train()
221
+ return x, l_codes, l_mask
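End to end, generation reduces to one call to sample() followed by the same grid post-processing as in the Gradio app; a minimal sketch, assuming maskgit is an initialized MaskGIT instance and the class ids are illustrative:

import torch
import torchvision.utils as vutils
from torchvision import transforms

labels = torch.LongTensor([1, 7, 282, 604]).to(maskgit.args.device)  # goldfish, chicken, tiger cat, hourglass
images, inter_codes, inter_masks = maskgit.sample(
    nb_sample=labels.size(0), labels=labels,
    sm_temp=1.0, w=3, randomize="linear", r_temp=4.5,
    sched_mode="arccos", step=12,
)
grid = transforms.ToPILImage()(vutils.make_grid(images, nrow=2, padding=0, normalize=True))
grid.save("maskgit_grid.png")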