import math
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.vqvae import VQVAE, VectorQuantizer2
class SharedAdaLin(nn.Linear):
    def forward(self, cond_BD):
        C = self.weight.shape[0] // 6
        return super().forward(cond_BD).view(-1, 1, 6, C)   # B16C
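
# Note on SharedAdaLin (a descriptive comment, not part of the original file): the linear layer
# emits 6*C values per sample, reshaped to (B, 1, 6, C) — six C-dimensional AdaLN parameters
# (in DiT-style AdaLN these are shift/scale/gate for the attention and the MLP branch). When
# shared_aln=True, this single projection is applied to cond_BD once and reused by every block.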
class VAR(nn.Module):
    def __init__(
        self, vae_local: VQVAE,
        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
        attn_l2_norm=False,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
        flash_if_available=True, fused_if_available=True,
    ):
        super().__init__()
        # 0. hyperparameters
        assert embed_dim % num_heads == 0
        self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
        self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads
        self.cond_drop_rate = cond_drop_rate
        self.prog_si = -1   # progressive training

        self.patch_nums: Tuple[int] = patch_nums
        self.L = sum(pn ** 2 for pn in self.patch_nums)
        self.first_l = self.patch_nums[0] ** 2
        self.begin_ends = []
        cur = 0
        for i, pn in enumerate(self.patch_nums):
            self.begin_ends.append((cur, cur + pn ** 2))
            cur += pn ** 2
        self.num_stages_minus_1 = len(self.patch_nums) - 1
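        # Worked example with the default patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16):
        # self.L = 1 + 4 + 9 + 16 + 25 + 36 + 64 + 100 + 169 + 256 = 680 tokens in total,
        # self.first_l = 1, and self.begin_ends = [(0, 1), (1, 5), (5, 14), (14, 30), ..., (424, 680)],
        # i.e. the (start, end) token range of each scale within the flattened sequence.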
        self.rng = torch.Generator(device=dist.get_device())   # keep the generator on the sampling device; a CPU generator breaks seeded sampling on CUDA tensors
        # 1. input (word) embedding
        quant: VectorQuantizer2 = vae_local.quantize
        self.vae_proxy: Tuple[VQVAE] = (vae_local,)
        self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
        self.word_embed = nn.Linear(self.Cvae, self.C)

        # 2. class embedding
        init_std = math.sqrt(1 / self.C / 3)
        self.num_classes = num_classes
        self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32,
                                       device=dist.get_device())
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)

        # 3. absolute position embedding
        pos_1LC = []
        for i, pn in enumerate(self.patch_nums):
            pe = torch.empty(1, pn * pn, self.C)
            nn.init.trunc_normal_(pe, mean=0, std=init_std)
            pos_1LC.append(pe)
        pos_1LC = torch.cat(pos_1LC, dim=1)   # 1, L, C
        assert tuple(pos_1LC.shape) == (1, self.L, self.C)
        self.pos_1LC = nn.Parameter(pos_1LC)
        # level embedding (similar to GPT's segment embedding, used to distinguish different levels of the token pyramid)
        self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

        # 4. backbone blocks
        self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False),
                                            SharedAdaLin(self.D, 6 * self.C)) if shared_aln else nn.Identity()

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        self.drop_path_rate = drop_path_rate
        dpr = [x.item() for x in
               torch.linspace(0, drop_path_rate, depth)]   # stochastic depth decay rule (linearly increasing)
        self.blocks = nn.ModuleList([
            AdaLNSelfAttn(
                cond_dim=self.D, shared_aln=shared_aln,
                block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx],
                last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
                attn_l2_norm=attn_l2_norm,
                flash_if_available=flash_if_available, fused_if_available=fused_if_available,
            )
            for block_idx in range(depth)
        ])

        fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
        self.using_fused_add_norm_fn = any(fused_add_norm_fns)
        print(
            f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
            f'    [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
            f'    [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
            end='\n\n', flush=True
        )

        # 5. attention mask used in training (for masking out the future)
        #    it won't be used in inference, since kv cache is enabled
        d: torch.Tensor = torch.cat([torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
        dT = d.transpose(1, 2)   # dT: 11L
        lvl_1L = dT[:, 0].contiguous()
        self.register_buffer('lvl_1L', lvl_1L)
        attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
        self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())
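        #    Mask structure (descriptive note): d[l] is the scale index of token l, so `d >= dT`
        #    lets a token attend to every token whose scale index is <= its own, i.e. block-wise
        #    causal across scales with full attention inside a scale. For patch_nums = (1, 2) the
        #    5x5 bias is 0 on and below the scale block-diagonal and -inf above it: the single
        #    scale-0 token sees only itself, the four scale-1 tokens see all five tokens.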
        # 6. classifier head
        self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
        self.head = nn.Linear(self.C, self.V)

    def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                   cond_BD: Optional[torch.Tensor]):
        if not isinstance(h_or_h_and_residual, torch.Tensor):
            h, resi = h_or_h_and_residual   # fused_add_norm must be used
            h = resi + self.blocks[-1].drop_path(h)
        else:                               # fused_add_norm is not used
            h = h_or_h_and_residual
        return self.head(self.head_nm(h.float(), cond_BD).float()).float()
    def autoregressive_infer_cfg(
        self, B: int, label_B: Optional[Union[int, torch.LongTensor]],
        delta_condition: torch.Tensor, alpha: float, beta: float,
        g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,
        more_smooth=False,
    ) -> torch.Tensor:   # returns reconstructed image (B, 3, H, W) in [0, 1]
        """
        Generate images using autoregressive inference with classifier-free guidance.

        :param B: batch size
        :param label_B: class labels; if None, randomly sampled
        :param delta_condition: tensor of shape (B, D)
        :param alpha: scalar weight for the class embedding
        :param beta: scalar weight for delta_condition
        :param g_seed: random seed
        :param cfg: classifier-free guidance ratio
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smooth the prediction with gumbel softmax; only used for visualization, not for FID/IS benchmarking
        :return: reconstructed images (B, 3, H, W) in [0, 1]
        """
        if g_seed is None:
            rng = None
        else:
            self.rng.manual_seed(g_seed)
            rng = self.rng
        device = self.lvl_1L.device

        if label_B is None:
            label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
        elif isinstance(label_B, int):
            label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=device)

        # Prepare labels for the conditional and unconditional branches
        label_B_cond = label_B
        label_B_uncond = torch.full_like(label_B, fill_value=self.num_classes)
        label_B = torch.cat((label_B_cond, label_B_uncond), dim=0)   # shape (2B,)

        # Prepare delta_condition for the conditional and unconditional branches
        delta_condition_uncond = torch.zeros_like(delta_condition)
        delta_condition = torch.cat((delta_condition, delta_condition_uncond), dim=0)   # shape (2B, D)

        class_emb = self.class_emb(label_B)                     # shape (2B, D)
        cond_BD = alpha * class_emb + beta * delta_condition    # shape (2B, D)

        sos = cond_BD.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1)

        lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
        next_token_map = sos + lvl_pos[:, :self.first_l]

        cur_L = 0
        f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
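        # Classifier-free-guidance layout: every tensor built above stacks the conditional batch
        # first and the unconditional batch second, so rows [:B] hold the conditional branch and
        # rows [B:] the unconditional one; per scale they are recombined below as
        # (1 + t) * logits[:B] - t * logits[B:] with t = cfg * ratio.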
        for b in self.blocks:
            b.attn.kv_caching(True)
        for si, pn in enumerate(self.patch_nums):   # si: i-th segment
            ratio = si / self.num_stages_minus_1
            cur_L += pn * pn
            cond_BD_or_gss = self.shared_ada_lin(cond_BD)
            x = next_token_map
            for b in self.blocks:
                x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
            logits_BlV = self.get_logits(x, cond_BD)

            t = cfg * ratio
            logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

            idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
            if not more_smooth:   # this is the default case
                h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl)   # B, l, Cvae
            else:                 # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)        # refer to mask-git
                h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
            f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw)

            if si != self.num_stages_minus_1:   # prepare for the next stage
                next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
                next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si + 1] ** 2]
                next_token_map = next_token_map.repeat(2, 1, 1)   # double the batch size for CFG

        for b in self.blocks:
            b.attn.kv_caching(False)
        return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)   # de-normalize, from [-1, 1] to [0, 1]
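
    # Hedged usage sketch for autoregressive_infer_cfg (the tensor values and output resolution are
    # illustrative; they depend on the trained VQVAE and are not defined in this file):
    #
    #   delta = torch.zeros(4, model.C, device=next(model.parameters()).device)   # (B, D) condition
    #   with torch.no_grad():
    #       imgs = model.autoregressive_infer_cfg(
    #           B=4, label_B=207, delta_condition=delta,
    #           alpha=1.0, beta=0.5, g_seed=0, cfg=1.5, top_k=900, top_p=0.95,
    #       )   # imgs: (4, 3, H, W), values in [0, 1]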
    def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor, delta_condition: torch.Tensor,
                alpha: float, beta: float) -> torch.Tensor:
        """
        :param label_B: class labels of shape (B,)
        :param x_BLCv_wo_first_l: teacher-forcing input of shape (B, self.L - self.first_l, self.Cvae)
        :param delta_condition: tensor of shape (B, D)
        :param alpha: scalar weight for the class embedding
        :param beta: scalar weight for delta_condition
        :return: logits of shape (B, L, V), where V is the vocabulary size
        """
        bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
        B = x_BLCv_wo_first_l.shape[0]
        with torch.cuda.amp.autocast(enabled=False):
            # Conditional dropout: with probability cond_drop_rate, replace the class label with the
            # "unconditional" class and zero out delta_condition (required for classifier-free guidance)
            drop_mask = torch.rand(B, device=label_B.device) < self.cond_drop_rate
            label_B_dropped = torch.where(drop_mask, self.num_classes, label_B)
            delta_condition_dropped = delta_condition.clone()
            delta_condition_dropped[drop_mask] = 0.0   # drop delta_condition

            class_emb = self.class_emb(label_B_dropped)
            cond_BD = alpha * class_emb + beta * delta_condition_dropped
            sos = cond_BD.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)

            if self.prog_si == 0:
                x_BLC = sos
            else:
                x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)
            x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed]   # lvl: BLC; pos: 1LC

        attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
        cond_BD_or_gss = self.shared_ada_lin(cond_BD)

        # hack: get the main dtype if mixed precision is used
        temp = x_BLC.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype
        x_BLC = x_BLC.to(dtype=main_type)
        cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
        attn_bias = attn_bias.to(dtype=main_type)

        AdaLNSelfAttn.forward   # no-op attribute access; has no runtime effect
        for i, b in enumerate(self.blocks):
            x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
        x_BLC = self.get_logits(x_BLC.float(), cond_BD)

        if self.prog_si == 0:
            # progressive-training hack: touch word_embed's parameters with zero-valued terms so they
            # stay in the autograd graph even when only the first (SOS) scale is being trained
            if isinstance(self.word_embed, nn.Linear):
                x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
            else:
                s = 0
                for p in self.word_embed.parameters():
                    if p.requires_grad:
                        s += p.view(-1)[0] * 0
                x_BLC[0, 0, 0] += s
        return x_BLC   # logits BLV, V is vocab_size
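
    # Hedged training sketch (the optimizer, data pipeline, and ground-truth token indices
    # `gt_idx_BL` come from the surrounding training code and are assumed here; needs
    # `import torch.nn.functional as F`):
    #
    #   logits_BLV = model(label_B, x_BLCv_wo_first_l, delta_condition, alpha=1.0, beta=0.5)
    #   loss = F.cross_entropy(logits_BLV.view(-1, model.V), gt_idx_BL.view(-1))
    #   loss.backward()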
    def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
        if init_std < 0: init_std = (1 / self.C / 3) ** 0.5   # init_std < 0: automated
        print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
        for m in self.modules():
            with_weight = hasattr(m, 'weight') and m.weight is not None
            with_bias = hasattr(m, 'bias') and m.bias is not None
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if with_bias: m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
            elif isinstance(m, (
                    nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm,
                    nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
                if with_weight: m.weight.data.fill_(1.)
                if with_bias: m.bias.data.zero_()
            # conv: VAR has no conv layers; only the VQVAE does
            elif isinstance(m, (
                    nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
                if conv_std_or_gain > 0:
                    nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
                else:
                    nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
                if with_bias: m.bias.data.zero_()

        if init_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(init_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                self.head[-1].weight.data.mul_(init_head)
                self.head[-1].bias.data.zero_()

        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
            if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
                self.head_nm.ada_lin[-1].bias.data.zero_()

        depth = len(self.blocks)
        for block_idx, sab in enumerate(self.blocks):
            sab: AdaLNSelfAttn
            sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
            sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
            if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
                nn.init.ones_(sab.ffn.fcg.bias)
                nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
            if hasattr(sab, 'ada_lin'):
                sab.ada_lin[-1].weight.data[2 * self.C:].mul_(init_adaln)
                sab.ada_lin[-1].weight.data[:2 * self.C].mul_(init_adaln_gamma)
                if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
                    sab.ada_lin[-1].bias.data.zero_()
            elif hasattr(sab, 'ada_gss'):
                sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
                sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)

    def extra_repr(self):
        return f'drop_path_rate={self.drop_path_rate:g}'

class VARHF(VAR, PyTorchModelHubMixin):
    def __init__(
        self,
        vae_kwargs,
        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
        attn_l2_norm=False,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
        flash_if_available=True, fused_if_available=True,
    ):
        vae_local = VQVAE(**vae_kwargs)
        super().__init__(
            vae_local=vae_local,
            num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
            drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
            norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
            attn_l2_norm=attn_l2_norm,
            patch_nums=patch_nums,
            flash_if_available=flash_if_available, fused_if_available=fused_if_available,
        )