# Maskgit-pytorch / runner.py
# Trainer for MaskGIT
import os
import random
import math

import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from Models.models.transformer import MaskTransformer
from Models.models.vqgan import VQModel


class MaskGIT(nn.Module):
    def __init__(self, args):
        """ Initialize the models (VQGAN and masked transformer) and the AMP gradient scaler. """
        super().__init__()
        self.args = args                            # Main arguments (see main.py)
        self.patch_size = self.args.img_size // 16  # Number of visual tokens per side (a class token is added inside the transformer)
        self.scaler = torch.cuda.amp.GradScaler()   # Gradient scaler for mixed-precision (AMP) training
        self.vit = self.get_network("vit")          # Load the masked bidirectional transformer
        self.ae = self.get_network("autoencoder")   # Load the VQGAN
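
    # For intuition (illustrative numbers, not from the original code):
    # img_size=256 gives patch_size=16, i.e. a 16x16 grid of 256 visual tokens,
    # while img_size=512 would give 32x32 = 1024 tokens.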

    def get_network(self, archi):
        """ Return the network; load a checkpoint if self.args.resume == True.
        :param
            archi -> str: vit|autoencoder, the architecture to load
        :return
            model -> nn.Module: the network
        """
if archi == "vit":
if self.args.vit_size == "base":
model = MaskTransformer(
img_size=self.args.img_size, hidden_dim=768, codebook_size=1024, depth=24, heads=16, mlp_dim=3072, dropout=0.1 # Small
)
elif self.args.vit_size == "big":
model = MaskTransformer(
img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=32, heads=16, mlp_dim=3072, dropout=0.1 # Big
)
elif self.args.vit_size == "huge":
model = MaskTransformer(
img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=48, heads=16, mlp_dim=3072, dropout=0.1 # Huge
)
            if self.args.resume:
                ckpt = self.args.vit_folder
                # If a folder is given, load the "current.pth" checkpoint inside it
                ckpt = os.path.join(ckpt, "current.pth") if os.path.isdir(ckpt) else ckpt
                if self.args.is_master:
                    print("load ckpt from:", ckpt)
                # Read the checkpoint file
                checkpoint = torch.load(ckpt, map_location='cpu')
                # Load the network weights
                model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model = model.to(self.args.device)
if self.args.is_multi_gpus: # put model on multi GPUs if available
model = DDP(model, device_ids=[self.args.device])
        elif archi == "autoencoder":
            # Load the VQGAN config and weights
            config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml"))
            model = VQModel(**config.model.params)
            checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"]
            # Load the network weights
            model.load_state_dict(checkpoint, strict=False)
            model = model.eval()
            model = model.to(self.args.device)

            if self.args.is_multi_gpus:  # put the model on multiple GPUs if available
                model = DDP(model, device_ids=[self.args.device])
                model = model.module     # unwrap DDP so methods such as decode_code can be called directly
        else:
            raise NotImplementedError(f"Unknown architecture: {archi}, should be vit|autoencoder")
if self.args.is_master:
print(f"Size of model {archi}: "
f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M")
return model

    def adap_sche(self, step, mode="arccos", leave=False):
        """ Create a sampling scheduler.
        :param
            step  -> int: number of prediction steps during inference
            mode  -> str: schedule shaping the rate at which tokens are unmasked
            leave -> bool: tqdm arg, whether to keep the progress bar after completion
        :return
            scheduler -> torch.LongTensor(): the number of tokens to predict at each step
        """
        r = torch.linspace(1, 0, step)
        if mode == "root":            # root scheduler
            val_to_mask = 1 - (r ** .5)
        elif mode == "linear":        # linear scheduler
            val_to_mask = 1 - r
        elif mode == "square":        # square scheduler
            val_to_mask = 1 - (r ** 2)
        elif mode == "cosine":        # cosine scheduler
            val_to_mask = torch.cos(r * math.pi * 0.5)
        elif mode == "arccos":        # arc cosine scheduler
            val_to_mask = torch.arccos(r) / (math.pi * 0.5)
        else:
            raise ValueError(f"Unknown scheduler mode: {mode}, should be root|linear|square|cosine|arccos")

        # fill the scheduler with the ratio of tokens to predict at each step
        sche = (val_to_mask / val_to_mask.sum()) * (self.patch_size * self.patch_size)
        sche = sche.round()
        sche[sche == 0] = 1                                            # predict at least 1 token per step
        sche[-1] += (self.patch_size * self.patch_size) - sche.sum()   # the counts must sum to the total number of tokens
        return tqdm(sche.int(), leave=leave)
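
    # A quick sanity check of the scheduler (a sketch, not part of the original
    # inference path; `maskgit` stands for any MaskGIT instance): for "arccos"
    # the counts grow over the steps (few tokens early, many late) and, after
    # the rounding fixups above, sum exactly to the full token grid.
    #
    #   sche = maskgit.adap_sche(step=8, mode="arccos")
    #   assert sum(int(t) for t in sche) == maskgit.patch_size ** 2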

    def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
               randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
        """ Generate samples with the MaskGIT model.
        :param
            init_code  -> torch.LongTensor: nb_sample x 16 x 16, the starting initialization code
            nb_sample  -> int: the number of images to generate
            labels     -> torch.LongTensor: the list of classes to generate
            sm_temp    -> float: the temperature applied before the softmax
            w          -> float: scale for classifier-free guidance
            randomize  -> str: linear|warm_up|random|no, whether (and how) to add randomness
            r_temp     -> float: temperature for the randomness
            sched_mode -> str: root|linear|square|cosine|arccos, the shape of the scheduler
            step       -> int: number of steps for the decoding
        :return
            x -> torch.FloatTensor: nb_sample x 3 x 256 x 256, the generated images
            l_codes -> list of torch.LongTensor: the intermediate codes, one nb_sample x 16 x 16 tensor per step
            l_mask -> list of torch.FloatTensor: the intermediate masks, one per step
        """
        self.vit.eval()
        l_codes = []  # Save the intermediate codes predicted
        l_mask = []   # Save the intermediate masks
        with torch.no_grad():
            if labels is None:  # Default classes to generate
                # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random
                labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)] * (nb_sample // 10)
                labels = torch.LongTensor(labels).to(self.args.device)

            drop = torch.ones(nb_sample, dtype=torch.bool).to(self.args.device)
            if init_code is not None:  # Start from a pre-defined code
                code = init_code
                mask = (init_code == 1024).float().view(nb_sample, self.patch_size*self.patch_size)
            else:  # Initialize the code
                if self.args.mask_value < 0:  # Initialize the code with random tokens
                    code = torch.randint(0, 1024, (nb_sample, self.patch_size, self.patch_size)).to(self.args.device)
                else:  # Initialize the code with masked tokens
                    code = torch.full((nb_sample, self.patch_size, self.patch_size), self.args.mask_value).to(self.args.device)
                mask = torch.ones(nb_sample, self.patch_size*self.patch_size).to(self.args.device)

            # Instantiate the scheduler
            if isinstance(sched_mode, str):  # Standard schedules
                scheduler = self.adap_sche(step, mode=sched_mode)
            else:  # Custom schedule
                scheduler = sched_mode
            # Beginning of sampling; t = number of tokens to predict at step "indice"
            for indice, t in enumerate(scheduler):
                if mask.sum() < t:  # Cannot predict more tokens than remain masked
                    t = int(mask.sum().item())
                if mask.sum() == 0:  # Stop if the code is fully predicted
                    break

                with torch.cuda.amp.autocast():  # half precision
                    if w != 0:
                        # Model prediction on both a conditional and an unconditional (label-dropped) batch
                        logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
                                         torch.cat([labels, labels], dim=0),
                                         torch.cat([~drop, drop], dim=0))
                        logit_c, logit_u = torch.chunk(logit, 2, dim=0)
                        _w = w * (indice / (len(scheduler)-1))  # guidance scale ramps up linearly over the steps
                        # Classifier-free guidance
                        logit = (1 + _w) * logit_c - _w * logit_u
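                        # For intuition (illustrative numbers, not from the original code):
                        # with w=3 at the last step, _w=3 and logit = 4*logit_c - 3*logit_u,
                        # i.e. the conditional logits are extrapolated away from the
                        # unconditional ones, sharpening class conditioning late in sampling.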
else:
logit = self.vit(code.clone(), labels, drop_label=~drop)
prob = torch.softmax(logit * sm_temp, -1)
                # Sample a code from the softmax prediction
                distri = torch.distributions.Categorical(probs=prob)
                pred_code = distri.sample()
                conf = torch.gather(prob, 2, pred_code.view(nb_sample, self.patch_size*self.patch_size, 1))

                if randomize == "linear":  # add Gumbel noise that decreases over the sampling process
                    # Gumbel noise on the log-confidences makes the subsequent top-k
                    # selection stochastic; the noise is annealed to 0 as ratio -> 1
                    ratio = (indice / len(scheduler))
                    rand = r_temp * np.random.gumbel(size=(nb_sample, self.patch_size*self.patch_size)) * (1 - ratio)
                    conf = torch.log(conf.squeeze()) + torch.from_numpy(rand).to(self.args.device)
                elif randomize == "warm_up":  # choose random confidences for the first 2 steps
                    conf = torch.rand_like(conf) if indice < 2 else conf
                elif randomize == "random":  # choose random confidences at every step
                    conf = torch.rand_like(conf)
                # do not predict tokens that have already been predicted
                conf[~mask.bool()] = -math.inf

                # keep the t predicted tokens with the highest confidence
                tresh_conf, indice_mask = torch.topk(conf.view(nb_sample, -1), k=t, dim=-1)
                tresh_conf = tresh_conf[:, -1]

                # replace the chosen tokens
                conf = (conf >= tresh_conf.unsqueeze(-1)).view(nb_sample, self.patch_size, self.patch_size)
                f_mask = (mask.view(nb_sample, self.patch_size, self.patch_size).float() * conf.view(nb_sample, self.patch_size, self.patch_size).float()).bool()
                code[f_mask] = pred_code.view(nb_sample, self.patch_size, self.patch_size)[f_mask]

                # update the mask
                for i_mask, ind_mask in enumerate(indice_mask):
                    mask[i_mask, ind_mask] = 0
                l_codes.append(pred_code.view(nb_sample, self.patch_size, self.patch_size).clone())
                l_mask.append(mask.view(nb_sample, self.patch_size, self.patch_size).clone())
            # decode the final prediction
            _code = torch.clamp(code, 0, 1023)    # the VQGAN codebook has only 1024 entries
            x = self.ae.decode_code(_code)
            x = (torch.clamp(x, -1, 1) + 1) / 2   # map from [-1, 1] to [0, 1]

        self.vit.train()
        return x, l_codes, l_mask
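

# Example usage (a sketch; the fields of `args` below are inferred from the
# attributes this class accesses and are assumptions, not a documented interface):
#
#   from argparse import Namespace
#   args = Namespace(img_size=256, vit_size="base", resume=False,
#                    vit_folder="pretrained/",            # hypothetical paths
#                    vqgan_folder="pretrained_vqgan/",
#                    mask_value=1024, device="cuda",
#                    is_master=True, is_multi_gpus=False)
#   maskgit = MaskGIT(args)
#   images, inter_codes, inter_masks = maskgit.sample(nb_sample=10, labels=None,
#                                                     sm_temp=1, w=3, step=12)
#   # `images` is a (10, 3, 256, 256) tensor with values in [0, 1]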