# Trainer for MaskGIT
import os
import random
import math

import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from Models.models.transformer import MaskTransformer
from Models.models.vqgan import VQModel


class MaskGIT(nn.Module):

    def __init__(self, args):
        """ Initialization of the model (VQGAN and Masked Transformer), optimizer, criterion, etc. """
        super().__init__()
        self.args = args                            # Main arguments, see main.py
        self.patch_size = self.args.img_size // 16  # Number of visual tokens per side (the VQGAN downsamples by 16)
        self.scaler = torch.cuda.amp.GradScaler()   # Init scaler for mixed-precision training
        self.vit = self.get_network("vit")          # Load Masked Bidirectional Transformer
        self.ae = self.get_network("autoencoder")   # Load VQGAN

    def get_network(self, archi):
        """ Return the network, load checkpoint if self.args.resume == True
            :param
                archi -> str: vit|autoencoder, the architecture to load
            :return
                model -> nn.Module: the network
        """
        if archi == "vit":
            if self.args.vit_size == "base":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=768, codebook_size=1024,
                    depth=24, heads=16, mlp_dim=3072, dropout=0.1                      # Base
                )
            elif self.args.vit_size == "big":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024,
                    depth=32, heads=16, mlp_dim=3072, dropout=0.1                      # Big
                )
            elif self.args.vit_size == "huge":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024,
                    depth=48, heads=16, mlp_dim=3072, dropout=0.1                      # Huge
                )

            if self.args.resume:
                ckpt = self.args.vit_folder
                # If a folder is given, load the latest checkpoint inside it
                ckpt = os.path.join(ckpt, "current.pth") if os.path.isdir(ckpt) else ckpt
                if self.args.is_master:
                    print("load ckpt from:", ckpt)
                # Read checkpoint file
                checkpoint = torch.load(ckpt, map_location='cpu')
                # Load network
                model.load_state_dict(checkpoint['model_state_dict'], strict=False)

            model = model.to(self.args.device)
            if self.args.is_multi_gpus:  # put model on multi GPUs if available
                model = DDP(model, device_ids=[self.args.device])

        elif archi == "autoencoder":
            # Load config
            config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml"))
            model = VQModel(**config.model.params)
            checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"]
            # Load network
            model.load_state_dict(checkpoint, strict=False)
            model = model.eval()
            model = model.to(self.args.device)

            if self.args.is_multi_gpus:  # put model on multi GPUs if available
                model = DDP(model, device_ids=[self.args.device])
                model = model.module
        else:
            raise ValueError(f"Unknown architecture: {archi}, expected 'vit' or 'autoencoder'")

        if self.args.is_master:
            print(f"Size of model {archi}: "
                  f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M")

        return model

    def adap_sche(self, step, mode="arccos", leave=False):
        """ Create a sampling scheduler
            :param
                step  -> int:  number of prediction steps during inference
                mode  -> str:  the rate at which masked tokens are unmasked
                leave -> bool: tqdm arg on whether to keep the bar or not
            :return
                scheduler -> torch.LongTensor(): the number of tokens to predict at each step
        """
        r = torch.linspace(1, 0, step)
        if mode == "root":              # root scheduler
            val_to_mask = 1 - (r ** .5)
        elif mode == "linear":          # linear scheduler
            val_to_mask = 1 - r
        elif mode == "square":          # square scheduler
            val_to_mask = 1 - (r ** 2)
        elif mode == "cosine":          # cosine scheduler
            val_to_mask = torch.cos(r * math.pi * 0.5)
        elif mode == "arccos":          # arc cosine scheduler
            val_to_mask = torch.arccos(r) / (math.pi * 0.5)
        else:
            return

        # fill the scheduler with the ratio of tokens to predict at each step
        sche = (val_to_mask / val_to_mask.sum()) * (self.patch_size * self.patch_size)
        sche = sche.round()
        sche[sche == 0] = 1                                           # predict at least 1 token per step
        sche[-1] += (self.patch_size * self.patch_size) - sche.sum()  # adjust the last step so the total matches the number of tokens
        return tqdm(sche.int(), leave=leave)
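    # Illustrative example (not part of the original file): for a 256x256 image
    # (16x16 = 256 tokens), adap_sche(step=8, mode="arccos") yields the per-step
    # token counts [1, 18, 26, 32, 38, 43, 48, 50], which sum to 256. Few tokens
    # are committed early, when predictions are least reliable, and progressively
    # more as the code fills in.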
    def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
               randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
        """ Generate samples with the MaskGIT model
            :param
                init_code  -> torch.LongTensor: nb_sample x 16 x 16, the starting initialization code
                nb_sample  -> int:              the number of images to generate
                labels     -> torch.LongTensor: the list of classes to generate
                sm_temp    -> float:            the temperature before softmax
                w          -> float:            scale for the classifier-free guidance
                randomize  -> str:              linear|warm_up|random|no, whether to add randomness
                r_temp     -> float:            temperature for the randomness
                sched_mode -> str:              root|linear|square|cosine|arccos, the shape of the scheduler
                step       -> int:              number of steps for the decoding
            :return
                x    -> torch.FloatTensor: nb_sample x 3 x 256 x 256, the generated images
                code -> torch.LongTensor:  nb_sample x step x 16 x 16, the codes corresponding to the generated images
        """
        self.vit.eval()
        l_codes = []  # Save the intermediate codes predicted
        l_mask = []   # Save the intermediate masks
        with torch.no_grad():
            if labels is None:  # Default classes generated (nb_sample must be a multiple of 10)
                # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random
                labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)] * (nb_sample // 10)
                labels = torch.LongTensor(labels).to(self.args.device)

            drop = torch.ones(nb_sample, dtype=torch.bool).to(self.args.device)

            if init_code is not None:  # Start with a pre-defined code
                code = init_code
                mask = (init_code == 1024).float().view(nb_sample, self.patch_size*self.patch_size)
            else:  # Initialize a code
                if self.args.mask_value < 0:  # Code initialized with random tokens
                    code = torch.randint(0, 1024, (nb_sample, self.patch_size, self.patch_size)).to(self.args.device)
                else:  # Code initialized with masked tokens
                    code = torch.full((nb_sample, self.patch_size, self.patch_size), self.args.mask_value).to(self.args.device)
                mask = torch.ones(nb_sample, self.patch_size*self.patch_size).to(self.args.device)

            # Instantiate scheduler
            if isinstance(sched_mode, str):  # Standard ones
                scheduler = self.adap_sche(step, mode=sched_mode)
            else:                            # Custom one
                scheduler = sched_mode

            # Beginning of sampling, t = number of tokens to predict at step "indice"
            for indice, t in enumerate(scheduler):
                if mask.sum() < t:  # Cannot predict more tokens than are still masked
                    t = int(mask.sum().item())

                if mask.sum() == 0:  # Break if the code is fully predicted
                    break

                with torch.cuda.amp.autocast():  # half precision
                    if w != 0:
                        # Model Prediction: conditional and unconditional passes batched together
                        logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
                                         torch.cat([labels, labels], dim=0),
                                         torch.cat([~drop, drop], dim=0))
                        logit_c, logit_u = torch.chunk(logit, 2, dim=0)
                        _w = w * (indice / (len(scheduler)-1))
                        # Classifier-Free Guidance, with guidance scale ramping up over the steps
                        logit = (1 + _w) * logit_c - _w * logit_u
                    else:
                        logit = self.vit(code.clone(), labels, drop_label=~drop)

                prob = torch.softmax(logit * sm_temp, -1)
                # Sample the code from the softmax prediction
                distri = torch.distributions.Categorical(probs=prob)
                pred_code = distri.sample()

                conf = torch.gather(prob, 2, pred_code.view(nb_sample, self.patch_size*self.patch_size, 1))
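                # Note (illustrative, not in the original file): the branches
                # below perturb the per-token confidence before the top-k
                # selection. "linear" turns confidence into log-probability plus
                # Gumbel noise, r_temp * g * (1 - indice/len(scheduler)), which
                # decays to zero over the steps: early picks are exploratory,
                # late picks greedy. "warm_up" replaces confidence with uniform
                # noise for the first 2 steps only; "random" does so every step.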
                if randomize == "linear":  # add Gumbel noise decreasing over the sampling process
                    ratio = (indice / len(scheduler))
                    rand = r_temp * np.random.gumbel(size=(nb_sample, self.patch_size*self.patch_size)) * (1 - ratio)
                    conf = torch.log(conf.squeeze()) + torch.from_numpy(rand).to(self.args.device)
                elif randomize == "warm_up":  # choose random samples for the first 2 steps
                    conf = torch.rand_like(conf) if indice < 2 else conf
                elif randomize == "random":   # choose a random prediction at each step
                    conf = torch.rand_like(conf)

                # do not predict on already predicted tokens
                conf[~mask.bool()] = -math.inf

                # choose the predicted tokens with the highest confidence
                tresh_conf, indice_mask = torch.topk(conf.view(nb_sample, -1), k=t, dim=-1)
                tresh_conf = tresh_conf[:, -1]

                # replace the chosen tokens
                conf = (conf >= tresh_conf.unsqueeze(-1)).view(nb_sample, self.patch_size, self.patch_size)
                f_mask = (mask.view(nb_sample, self.patch_size, self.patch_size).float() *
                          conf.view(nb_sample, self.patch_size, self.patch_size).float()).bool()
                code[f_mask] = pred_code.view(nb_sample, self.patch_size, self.patch_size)[f_mask]

                # update the mask
                for i_mask, ind_mask in enumerate(indice_mask):
                    mask[i_mask, ind_mask] = 0
                l_codes.append(pred_code.view(nb_sample, self.patch_size, self.patch_size).clone())
                l_mask.append(mask.view(nb_sample, self.patch_size, self.patch_size).clone())

            # decode the final prediction
            _code = torch.clamp(code, 0, 1023)  # the VQGAN codebook has only 1024 entries
            x = self.ae.decode_code(_code)
            x = (torch.clamp(x, -1, 1) + 1) / 2  # map the decoded images from [-1, 1] to [0, 1]

        self.vit.train()
        return x, l_codes, l_mask
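
# Minimal usage sketch (illustrative, not part of the original trainer). It
# assumes pretrained checkpoints already exist at the hypothetical paths below,
# and that `args` carries the fields this class reads (img_size, vit_size,
# vit_folder, vqgan_folder, resume, device, is_master, is_multi_gpus,
# mask_value); in the repo these normally come from main.py.
if __name__ == "__main__":
    from argparse import Namespace
    from torchvision.utils import save_image

    args = Namespace(
        img_size=256, vit_size="base",
        vit_folder="./pretrained/vit/",      # hypothetical checkpoint folder
        vqgan_folder="./pretrained/vqgan/",  # hypothetical checkpoint folder
        resume=True, device="cuda:0",
        is_master=True, is_multi_gpus=False,
        mask_value=1024,                     # id of the [MASK] token
    )
    maskgit = MaskGIT(args)
    # nb_sample should be a multiple of 10 when using the default labels
    x, l_codes, l_mask = maskgit.sample(nb_sample=10, sched_mode="arccos", step=12)
    save_image(x, "samples.png", nrow=5)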