|
|
|
import os |
|
import random |
|
import math |
|
|
|
import numpy as np |
|
from tqdm import tqdm |
|
from omegaconf import OmegaConf |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
from Models.models.transformer import MaskTransformer |
|
from Models.models.vqgan import VQModel |
|
|
|
|
|
class MaskGIT(nn.Module):
    """MaskGIT wrapper: a frozen VQGAN tokenizer plus a masked bidirectional
    transformer, with the iterative-decoding sampler used at inference."""

    def __init__(self, args):
        """ Initialization of the model (VQGAN and Masked Transformer), optimizer, criterion, etc."""
        super().__init__()

        self.args = args
        # Side length of the token grid: the VQGAN compresses images by a factor of 16.
        self.patch_size = self.args.img_size // 16
        self.scaler = torch.cuda.amp.GradScaler()
        self.vit = self.get_network("vit")          # masked transformer (trainable)
        self.ae = self.get_network("autoencoder")   # VQGAN tokenizer/decoder (frozen, eval mode)

    def get_network(self, archi):
        """ Return the network, load checkpoint if self.args.resume == True
        :param
            archi -> str: vit|autoencoder, the architecture to load
        :return
            model -> nn.Module: the network
        :raises
            ValueError: if archi (or self.args.vit_size for the vit) is unknown
        """
        if archi == "vit":
            # Codebook size is fixed to 1024 to match the VQGAN vocabulary below.
            if self.args.vit_size == "base":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=768, codebook_size=1024, depth=24, heads=16, mlp_dim=3072, dropout=0.1
                )
            elif self.args.vit_size == "big":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=32, heads=16, mlp_dim=3072, dropout=0.1
                )
            elif self.args.vit_size == "huge":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=48, heads=16, mlp_dim=3072, dropout=0.1
                )
            else:
                # Fail fast instead of hitting a NameError on the undefined `model` below.
                raise ValueError(f"Unknown vit_size: {self.args.vit_size!r} (expected base|big|huge)")

            if self.args.resume:
                ckpt = self.args.vit_folder
                # A directory means "load its current.pth"; os.path.join inserts the
                # separator that the old string concatenation silently omitted.
                if os.path.isdir(ckpt):
                    ckpt = os.path.join(ckpt, "current.pth")
                if self.args.is_master:
                    print("load ckpt from:", ckpt)

                checkpoint = torch.load(ckpt, map_location='cpu')
                # strict=False tolerates checkpoints saved with a slightly different head.
                model.load_state_dict(checkpoint['model_state_dict'], strict=False)

            model = model.to(self.args.device)

            if self.args.is_multi_gpus:
                model = DDP(model, device_ids=[self.args.device])

        elif archi == "autoencoder":
            # Rebuild the VQGAN from its OmegaConf config, then load the Lightning weights.
            config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml"))
            model = VQModel(**config.model.params)
            checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"]

            model.load_state_dict(checkpoint, strict=False)
            model = model.eval()   # inference-only: never trained here
            model = model.to(self.args.device)

            if self.args.is_multi_gpus:
                model = DDP(model, device_ids=[self.args.device])
                # Unwrap immediately: no gradients flow through the autoencoder,
                # so DDP synchronization is unnecessary.
                model = model.module
        else:
            # Previously fell through with model = None and crashed later.
            raise ValueError(f"Unknown architecture: {archi!r} (expected vit|autoencoder)")

        if self.args.is_master:
            print(f"Size of model {archi}: "
                  f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M")

        return model

    def adap_sche(self, step, mode="arccos", leave=False):
        """ Create a sampling scheduler
        :param
            step  -> int: number of prediction during inference
            mode  -> str: root|linear|square|cosine|arccos, the rate of value to unmask
            leave -> bool: tqdm arg on either to keep the bar or not
        :return
            scheduler -> torch.LongTensor(): the list of token to predict at each step
        :raises
            ValueError: if mode is not one of the supported schedules
        """
        r = torch.linspace(1, 0, step)
        if mode == "root":
            val_to_mask = 1 - (r ** .5)
        elif mode == "linear":
            val_to_mask = 1 - r
        elif mode == "square":
            val_to_mask = 1 - (r ** 2)
        elif mode == "cosine":
            val_to_mask = torch.cos(r * math.pi * 0.5)
        elif mode == "arccos":
            val_to_mask = torch.arccos(r) / (math.pi * 0.5)
        else:
            # Previously returned None, which callers iterated over -> opaque TypeError.
            raise ValueError(f"Unknown scheduler mode: {mode!r} (expected root|linear|square|cosine|arccos)")

        # Normalize so the per-step counts sum to the full token grid, then round.
        sche = (val_to_mask / val_to_mask.sum()) * (self.patch_size * self.patch_size)
        sche = sche.round()
        sche[sche == 0] = 1                                            # predict at least one token per step
        sche[-1] += (self.patch_size * self.patch_size) - sche.sum()   # absorb rounding error in the last step
        return tqdm(sche.int(), leave=leave)

    def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
               randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
        """ Generate sample with the MaskGIT model
        :param
            init_code  -> torch.LongTensor: nb_sample x 16 x 16, the starting initialization code
            nb_sample  -> int: the number of image to generated
            labels     -> torch.LongTensor: the list of classes to generate
            sm_temp    -> float: the temperature before softmax
            w          -> float: scale for the classifier free guidance
            randomize  -> str: linear|warm_up|random|no, either or not to add randomness
            r_temp     -> float: temperature for the randomness
            sched_mode -> str: root|linear|square|cosine|arccos, the shape of the scheduler
            step:      -> int: number of step for the decoding
        :return
            x    -> torch.FloatTensor: nb_sample x 3 x 256 x 256, the generated images
            code -> torch.LongTensor: nb_sample x step x 16 x 16, the code corresponding to the generated images
        """
        self.vit.eval()
        l_codes = []   # save intermediate codes predicted at each step
        l_mask = []    # save the intermediate masks
        with torch.no_grad():
            if labels is None:
                # Default ImageNet class ids; pad/trim so len(labels) == nb_sample
                # (the old `* (nb_sample // 10)` broke for nb_sample not a multiple of 10).
                base = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)]
                labels = (base * (nb_sample // 10 + 1))[:nb_sample]
                labels = torch.LongTensor(labels).to(self.args.device)

            # drop=True marks the unconditional branch for classifier-free guidance.
            drop = torch.ones(nb_sample, dtype=torch.bool).to(self.args.device)
            if init_code is not None:
                code = init_code
                # Token id 1024 is the mask token.
                mask = (init_code == 1024).float().view(nb_sample, self.patch_size*self.patch_size)
            else:
                if self.args.mask_value < 0:
                    # Negative mask_value means "start from random codebook tokens".
                    code = torch.randint(0, 1024, (nb_sample, self.patch_size, self.patch_size)).to(self.args.device)
                else:
                    code = torch.full((nb_sample, self.patch_size, self.patch_size), self.args.mask_value).to(self.args.device)
                mask = torch.ones(nb_sample, self.patch_size*self.patch_size).to(self.args.device)

            # Either build the scheduler from a named mode, or accept a precomputed one.
            if isinstance(sched_mode, str):
                scheduler = self.adap_sche(step, mode=sched_mode)
            else:
                scheduler = sched_mode

            for indice, t in enumerate(scheduler):
                # Never try to reveal more tokens than remain masked.
                if mask.sum() < t:
                    t = int(mask.sum().item())

                if mask.sum() == 0:  # all tokens already predicted
                    break

                with torch.cuda.amp.autocast():
                    if w != 0:
                        # Classifier-free guidance: run conditional and unconditional
                        # branches in one batched forward pass.
                        logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
                                         torch.cat([labels, labels], dim=0),
                                         torch.cat([~drop, drop], dim=0))
                        logit_c, logit_u = torch.chunk(logit, 2, dim=0)
                        # Guidance strength is annealed linearly over the decoding steps.
                        _w = w * (indice / (len(scheduler)-1))
                        logit = (1 + _w) * logit_c - _w * logit_u
                    else:
                        logit = self.vit(code.clone(), labels, drop_label=~drop)

                prob = torch.softmax(logit * sm_temp, -1)

                # Sample one token per position from the predicted distribution.
                distri = torch.distributions.Categorical(probs=prob)
                pred_code = distri.sample()

                # Confidence = probability assigned to the sampled token.
                conf = torch.gather(prob, 2, pred_code.view(nb_sample, self.patch_size*self.patch_size, 1))

                if randomize == "linear":
                    # Gumbel noise, annealed to zero over the decoding steps.
                    ratio = (indice / len(scheduler))
                    rand = r_temp * np.random.gumbel(size=(nb_sample, self.patch_size*self.patch_size)) * (1 - ratio)
                    # squeeze(-1) (not squeeze()) so the batch dim survives when nb_sample == 1.
                    conf = torch.log(conf.squeeze(-1)) + torch.from_numpy(rand).to(self.args.device)
                elif randomize == "warm_up":
                    # Fully random confidence for the first two steps only.
                    conf = torch.rand_like(conf) if indice < 2 else conf
                elif randomize == "random":
                    conf = torch.rand_like(conf)

                # Already-predicted positions can never be selected again.
                conf[~mask.bool()] = -math.inf

                # Keep the t most confident positions; tresh_conf is the cutoff value.
                tresh_conf, indice_mask = torch.topk(conf.view(nb_sample, -1), k=t, dim=-1)
                tresh_conf = tresh_conf[:, -1]

                # Commit sampled tokens at positions that are both masked and above the cutoff.
                conf = (conf >= tresh_conf.unsqueeze(-1)).view(nb_sample, self.patch_size, self.patch_size)
                f_mask = (mask.view(nb_sample, self.patch_size, self.patch_size).float() * conf.view(nb_sample, self.patch_size, self.patch_size).float()).bool()
                code[f_mask] = pred_code.view(nb_sample, self.patch_size, self.patch_size)[f_mask]

                # Unmask the selected positions and record the intermediate state.
                for i_mask, ind_mask in enumerate(indice_mask):
                    mask[i_mask, ind_mask] = 0
                l_codes.append(pred_code.view(nb_sample, self.patch_size, self.patch_size).clone())
                l_mask.append(mask.view(nb_sample, self.patch_size, self.patch_size).clone())

            # Decode the final token grid to pixels; clamp guards any stray mask ids,
            # then map from [-1, 1] to [0, 1].
            _code = torch.clamp(code, 0, 1023)
            x = self.ae.decode_code(_code)
            x = (torch.clamp(x, -1, 1) + 1) / 2
        self.vit.train()
        return x, l_codes, l_mask
|
|