# Trainer for MaskGIT
import os
import random
import math
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from Models.models.transformer import MaskTransformer
from Models.models.vqgan import VQModel
class MaskGIT(nn.Module):
def __init__(self, args):
""" Initialization of the model (VQGAN and Masked Transformer), optimizer, criterion, etc."""
super().__init__()
        self.args = args                                    # Main arguments, see main.py
        self.patch_size = self.args.img_size // 16          # Number of visual tokens per side (+1 class token is added by the transformer)
        self.scaler = torch.cuda.amp.GradScaler()           # Scaler for mixed-precision training
self.vit = self.get_network("vit") # Load Masked Bidirectional Transformer
self.ae = self.get_network("autoencoder") # Load VQGAN
def get_network(self, archi):
""" return the network, load checkpoint if self.args.resume == True
:param
archi -> str: vit|autoencoder, the architecture to load
:return
model -> nn.Module: the network
"""
if archi == "vit":
if self.args.vit_size == "base":
model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=768, codebook_size=1024, depth=24, heads=16, mlp_dim=3072, dropout=0.1  # Base
)
elif self.args.vit_size == "big":
model = MaskTransformer(
img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=32, heads=16, mlp_dim=3072, dropout=0.1 # Big
)
elif self.args.vit_size == "huge":
model = MaskTransformer(
img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=48, heads=16, mlp_dim=3072, dropout=0.1 # Huge
                )
            else:
                raise ValueError(f"Unknown vit_size: {self.args.vit_size}")
if self.args.resume:
                ckpt = self.args.vit_folder
                if os.path.isdir(ckpt):  # a folder was given: load the latest checkpoint inside it
                    ckpt = os.path.join(ckpt, "current.pth")
if self.args.is_master:
print("load ckpt from:", ckpt)
# Read checkpoint file
checkpoint = torch.load(ckpt, map_location='cpu')
# Load network
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model = model.to(self.args.device)
if self.args.is_multi_gpus: # put model on multi GPUs if available
model = DDP(model, device_ids=[self.args.device])
elif archi == "autoencoder":
# Load config
config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml"))
model = VQModel(**config.model.params)
checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"]
# Load network
model.load_state_dict(checkpoint, strict=False)
model = model.eval()
model = model.to(self.args.device)
if self.args.is_multi_gpus: # put model on multi GPUs if available
model = DDP(model, device_ids=[self.args.device])
model = model.module
else:
model = None
if self.args.is_master:
print(f"Size of model {archi}: "
f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M")
return model
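
    # Usage sketch for get_network via the constructor (hypothetical `args`
    # fields, inferred from the attribute accesses above; not a definitive
    # configuration):
    #   args = Namespace(img_size=256, vit_size="base", resume=False,
    #                    is_master=True, is_multi_gpus=False, device="cuda",
    #                    vit_folder="./ckpt/", vqgan_folder="./vqgan/",
    #                    mask_value=1024)
    #   maskgit = MaskGIT(args)  # builds both the transformer and the VQGAN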
def adap_sche(self, step, mode="arccos", leave=False):
""" Create a sampling scheduler
:param
step -> int: number of prediction during inference
mode -> str: the rate of value to unmask
leave -> bool: tqdm arg on either to keep the bar or not
:return
scheduler -> torch.LongTensor(): the list of token to predict at each step
"""
r = torch.linspace(1, 0, step)
if mode == "root": # root scheduler
val_to_mask = 1 - (r ** .5)
elif mode == "linear": # linear scheduler
val_to_mask = 1 - r
elif mode == "square": # square scheduler
val_to_mask = 1 - (r ** 2)
elif mode == "cosine": # cosine scheduler
val_to_mask = torch.cos(r * math.pi * 0.5)
elif mode == "arccos": # arc cosine scheduler
val_to_mask = torch.arccos(r) / (math.pi * 0.5)
        else:
            raise ValueError(f"Unknown scheduler mode: {mode}")
        # fill the scheduler with the ratio of tokens to predict at each step
        sche = (val_to_mask / val_to_mask.sum()) * (self.patch_size * self.patch_size)
        sche = sche.round()
        sche[sche == 0] = 1                                           # predict at least 1 token per step
        sche[-1] += (self.patch_size * self.patch_size) - sche.sum()  # adjust the last step so the counts sum to the total number of tokens
return tqdm(sche.int(), leave=leave)
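
    # Example: with a 16x16 grid (256 tokens) and step=8, the "arccos" mode
    # yields per-step counts of roughly [1, 18, 26, 32, 38, 43, 48, 50]
    # (exact values depend on rounding); the counts always sum to 256, so
    # every token gets predicted exactly once over the sampling process.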
def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
""" Generate sample with the MaskGIT model
:param
init_code -> torch.LongTensor: nb_sample x 16 x 16, the starting initialization code
nb_sample -> int: the number of image to generated
labels -> torch.LongTensor: the list of classes to generate
sm_temp -> float: the temperature before softmax
w -> float: scale for the classifier free guidance
randomize -> str: linear|warm_up|random|no, either or not to add randomness
r_temp -> float: temperature for the randomness
sched_mode -> str: root|linear|square|cosine|arccos, the shape of the scheduler
step: -> int: number of step for the decoding
:return
x -> torch.FloatTensor: nb_sample x 3 x 256 x 256, the generated images
code -> torch.LongTensor: nb_sample x step x 16 x 16, the code corresponding to the generated images
"""
self.vit.eval()
l_codes = [] # Save the intermediate codes predicted
l_mask = [] # Save the intermediate masks
with torch.no_grad():
            if labels is None:  # Default classes generated
                # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random
                labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)]
                labels = (labels * math.ceil(nb_sample / 10))[:nb_sample]  # repeat/truncate to one label per sample
labels = torch.LongTensor(labels).to(self.args.device)
drop = torch.ones(nb_sample, dtype=torch.bool).to(self.args.device)
            if init_code is not None:  # Start with a pre-defined code
                code = init_code
                mask = (init_code == 1024).float().view(nb_sample, self.patch_size*self.patch_size)  # 1024 is the mask token id
            else:  # Initialize a code
                if self.args.mask_value < 0:  # Code initialized with random tokens
                    code = torch.randint(0, 1024, (nb_sample, self.patch_size, self.patch_size)).to(self.args.device)
                else:  # Code initialized with masked tokens
                    code = torch.full((nb_sample, self.patch_size, self.patch_size), self.args.mask_value).to(self.args.device)
                mask = torch.ones(nb_sample, self.patch_size*self.patch_size).to(self.args.device)
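            # Convention: mask[i, j] == 1 means token j of sample i is still
            # masked and remains to be predicted; sampling ends once mask is all zeros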
# Instantiate scheduler
if isinstance(sched_mode, str): # Standard ones
scheduler = self.adap_sche(step, mode=sched_mode)
else: # Custom one
scheduler = sched_mode
            # Beginning of sampling: t = number of tokens to predict at step `indice`
for indice, t in enumerate(scheduler):
                if mask.sum() < t:  # Cannot predict more tokens than remain masked
t = int(mask.sum().item())
if mask.sum() == 0: # Break if code is fully predicted
break
with torch.cuda.amp.autocast(): # half precision
if w != 0:
# Model Prediction
logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
torch.cat([labels, labels], dim=0),
torch.cat([~drop, drop], dim=0))
logit_c, logit_u = torch.chunk(logit, 2, dim=0)
_w = w * (indice / (len(scheduler)-1))
# Classifier Free Guidance
logit = (1 + _w) * logit_c - _w * logit_u
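                        # equivalently logit_c + _w * (logit_c - logit_u): push the
                        # conditional logits away from the unconditional ones, with
                        # the guidance scale _w ramped up linearly over the steps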
else:
logit = self.vit(code.clone(), labels, drop_label=~drop)
prob = torch.softmax(logit * sm_temp, -1)
# Sample the code from the softmax prediction
distri = torch.distributions.Categorical(probs=prob)
pred_code = distri.sample()
conf = torch.gather(prob, 2, pred_code.view(nb_sample, self.patch_size*self.patch_size, 1))
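                # conf: probability of each sampled token, shape (nb_sample, patch_size*patch_size, 1)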
if randomize == "linear": # add gumbel noise decreasing over the sampling process
ratio = (indice / len(scheduler))
rand = r_temp * np.random.gumbel(size=(nb_sample, self.patch_size*self.patch_size)) * (1 - ratio)
conf = torch.log(conf.squeeze()) + torch.from_numpy(rand).to(self.args.device)
elif randomize == "warm_up": # chose random sample for the 2 first steps
conf = torch.rand_like(conf) if indice < 2 else conf
elif randomize == "random": # chose random prediction at each step
conf = torch.rand_like(conf)
# do not predict on already predicted tokens
conf[~mask.bool()] = -math.inf
# chose the predicted token with the highest confidence
tresh_conf, indice_mask = torch.topk(conf.view(nb_sample, -1), k=t, dim=-1)
tresh_conf = tresh_conf[:, -1]
# replace the chosen tokens
conf = (conf >= tresh_conf.unsqueeze(-1)).view(nb_sample, self.patch_size, self.patch_size)
f_mask = (mask.view(nb_sample, self.patch_size, self.patch_size).float() * conf.view(nb_sample, self.patch_size, self.patch_size).float()).bool()
code[f_mask] = pred_code.view(nb_sample, self.patch_size, self.patch_size)[f_mask]
# update the mask
for i_mask, ind_mask in enumerate(indice_mask):
mask[i_mask, ind_mask] = 0
l_codes.append(pred_code.view(nb_sample, self.patch_size, self.patch_size).clone())
l_mask.append(mask.view(nb_sample, self.patch_size, self.patch_size).clone())
# decode the final prediction
            _code = torch.clamp(code, 0, 1023)  # the VQGAN codebook has only 1024 entries
x = self.ae.decode_code(_code)
x = (torch.clamp(x, -1, 1) + 1) / 2
self.vit.train()
return x, l_codes, l_mask
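

# Minimal usage sketch (not part of the original trainer). The argument names
# below are inferred from the attribute accesses in this file, and the two
# checkpoint folders are hypothetical placeholders that must point to real
# pre-trained weights for this to run.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        img_size=256, vit_size="base", resume=True, mask_value=1024,
        vit_folder="./pretrained_maskgit/",  # hypothetical checkpoint folder
        vqgan_folder="./pretrained_vqgan/",  # hypothetical checkpoint folder
        device="cuda" if torch.cuda.is_available() else "cpu",
        is_master=True, is_multi_gpus=False,
    )
    maskgit = MaskGIT(args)
    images, codes, masks = maskgit.sample(nb_sample=10, sm_temp=1, w=3,
                                          sched_mode="arccos", step=12)
    print(images.shape)  # expected: torch.Size([10, 3, 256, 256])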