Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial | |
from json import encoder | |
import torch | |
import torch.nn as nn | |
from timm.models.vision_transformer import Block | |
from qa_mdt.audioldm_train.modules.audiomae.util.pos_embed import ( | |
get_2d_sincos_pos_embed, | |
get_2d_sincos_pos_embed_flexible, | |
get_1d_sincos_pos_embed_from_grid, | |
) | |
from qa_mdt.audioldm_train.modules.audiomae.util.patch_embed import ( | |
PatchEmbed_new, | |
PatchEmbed_org, | |
) | |
class MaskedAutoencoderViT(nn.Module):
    """Masked Autoencoder with a VisionTransformer backbone.

    The encoder embeds the input into patches, drops a subset of them
    (random or structured time/frequency masking) and runs the remaining
    tokens through a stack of transformer blocks. The decoder re-inserts
    mask tokens at the dropped positions and predicts the pixel values of
    every patch; the reconstruction loss is computed on dropped patches only.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride=10,
        in_chans=3,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        norm_pix_loss=False,
        audio_exp=False,
        alpha=0.0,
        temperature=0.2,
        mode=0,
        contextual_depth=8,
        use_custom_patch=False,
        split_pos=False,
        pos_trainable=False,
        use_nce=False,
        beta=4.0,
        decoder_mode=0,
        mask_t_prob=0.6,
        mask_f_prob=0.5,
        mask_2d=False,
        epoch=0,
        no_shift=False,
    ):
        """Build the encoder, the decoder and initialise all weights.

        Args:
            img_size: Input size forwarded to the patch embedding.
            patch_size: Side length of a (square) patch.
            stride: Patch stride; only used with ``use_custom_patch``.
            in_chans: Number of input channels.
            embed_dim: Encoder token width.
            depth: Number of encoder transformer blocks.
            num_heads: Attention heads per encoder block.
            decoder_embed_dim: Decoder token width.
            decoder_depth: Number of decoder blocks when ``decoder_mode != 1``.
            decoder_num_heads: Attention heads per decoder block when
                ``decoder_mode != 1``.
            mlp_ratio: MLP expansion ratio of every block.
            norm_layer: Normalisation layer factory.
            norm_pix_loss: Normalise each target patch before the loss.
            audio_exp: Use the flexible (non-square) positional embedding and
                single-channel patchify path.
            alpha, temperature, mode, use_nce, beta, epoch: Stored on the
                instance; not otherwise used inside this class.
            contextual_depth: Encoder blocks with an index strictly greater
                than this value contribute to ``forward_encoder_no_mask``.
            use_custom_patch: Use overlapping patches (``PatchEmbed_new``).
            split_pos: Accepted for compatibility; unused.
            pos_trainable: Whether positional embeddings receive gradients.
            decoder_mode: ``1`` builds a 16-block ``SwinTransformerBlock``
                decoder; any other value builds a plain transformer decoder.
            mask_t_prob: Time masking probability for 2D masking.
            mask_f_prob: Frequency masking probability for 2D masking.
            mask_2d: Use structured time/frequency masking in ``forward``.
            no_shift: Disable window shifting in the Swin decoder.
        """
        super().__init__()
        self.audio_exp = audio_exp
        self.embed_dim = embed_dim
        self.decoder_embed_dim = decoder_embed_dim

        # --------------------------------------------------------------------
        # MAE encoder specifics
        if use_custom_patch:
            print(
                f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}"
            )
            self.patch_embed = PatchEmbed_new(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                stride=stride,
            )
        else:
            self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
        self.use_custom_patch = use_custom_patch
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # +1 slot for the cls token; values are overwritten with a sin-cos
        # embedding in initialize_weights().
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable
        )

        self.encoder_depth = depth
        self.contextual_depth = contextual_depth
        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim),
            requires_grad=pos_trainable,
        )  # overwritten with a sin-cos embedding in initialize_weights()

        self.no_shift = no_shift
        self.decoder_mode = decoder_mode
        if self.use_custom_patch:
            # overlapped patches as in AST. Similar performance yet compute heavy
            window_size = (6, 6)
            feat_size = (102, 12)
        else:
            window_size = (4, 4)
            feat_size = (64, 8)

        if self.decoder_mode == 1:
            # NOTE(review): SwinTransformerBlock is not among the imports
            # visible at the top of this file; this branch presumably raises
            # NameError unless the name is provided elsewhere - confirm.
            # The Swin decoder always uses 16 blocks with 16 heads; it ignores
            # decoder_depth and decoder_num_heads.
            decoder_modules = []
            for index in range(16):
                # Odd blocks shift the window along the first axis only,
                # unless shifting is disabled.
                if self.no_shift or (index % 2) == 0:
                    shift_size = (0, 0)
                else:
                    shift_size = (2, 0)
                decoder_modules.append(
                    SwinTransformerBlock(
                        dim=decoder_embed_dim,
                        num_heads=16,
                        feat_size=feat_size,
                        window_size=window_size,
                        shift_size=shift_size,
                        mlp_ratio=mlp_ratio,
                        drop=0.0,
                        drop_attn=0.0,
                        drop_path=0.0,
                        extra_norm=False,
                        sequential_attn=False,
                        norm_layer=norm_layer,
                    )
                )
            self.decoder_blocks = nn.ModuleList(decoder_modules)
        else:
            # Plain transformer decoder
            self.decoder_blocks = nn.ModuleList(
                [
                    Block(
                        decoder_embed_dim,
                        decoder_num_heads,
                        mlp_ratio,
                        qkv_bias=True,
                        norm_layer=norm_layer,
                    )
                    for _ in range(decoder_depth)
                ]
            )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, patch_size**2 * in_chans, bias=True
        )  # decoder to patch

        # --------------------------------------------------------------------
        self.norm_pix_loss = norm_pix_loss
        self.patch_size = patch_size
        self.stride = stride

        # audio exps
        self.alpha = alpha
        self.T = temperature
        self.mode = mode
        self.use_nce = use_nce
        self.beta = beta
        self.log_softmax = nn.LogSoftmax(dim=-1)

        self.mask_t_prob = mask_t_prob
        self.mask_f_prob = mask_f_prob
        self.mask_2d = mask_2d
        self.epoch = epoch

        self.initialize_weights()

    def initialize_weights(self):
        """Fill positional embeddings and initialise all learnable weights."""
        # Encoder and decoder positional embeddings are sin-cos tables; the
        # audio path uses the patch grid reported by the patch embedding,
        # the image path assumes a square grid.
        if self.audio_exp:
            pos_embed = get_2d_sincos_pos_embed_flexible(
                self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True
            )
        else:
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1],
                int(self.patch_embed.num_patches**0.5),
                cls_token=True,
            )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if self.audio_exp:
            decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
                self.decoder_pos_embed.shape[-1],
                self.patch_embed.patch_hw,
                cls_token=True,
            )
        else:
            decoder_pos_embed = get_2d_sincos_pos_embed(
                self.decoder_pos_embed.shape[-1],
                int(self.patch_embed.num_patches**0.5),
                cls_token=True,
            )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
        )

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as
        # the cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        torch.nn.init.normal_(self.mask_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser applied through ``nn.Module.apply``."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """Cut the input into flattened patches.

        imgs: (N, C, H, W) with C == 1 when ``audio_exp`` is set, else C == 3
        returns: (N, L, patch_size**2 * C)
        """
        p = self.patch_embed.patch_size[0]

        if self.audio_exp:
            if self.use_custom_patch:  # overlapped patch
                h, w = self.patch_embed.patch_hw
                # todo: fixed h/w patch size and stride size. Make hw custom
                # in the future
                x = imgs.unfold(2, self.patch_size, self.stride).unfold(
                    3, self.patch_size, self.stride
                )  # n,1,H,W -> n,1,h,w,p,p
                x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
            else:
                h = imgs.shape[2] // p
                w = imgs.shape[3] // p
                x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
                x = torch.einsum("nchpwq->nhwpqc", x)
                x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
        else:
            # image path: square grid derived from the height only
            h = w = imgs.shape[2] // p
            x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
            x = torch.einsum("nchpwq->nhwpqc", x)
            x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))

        return x

    def unpatchify(self, x):
        """Reassemble flattened single-channel patches into a spectrogram.

        x: (N, L, patch_size**2)
        returns specs: (N, 1, H, W)

        The output size is hard-coded to H=1024, W=128, so L must equal
        (1024 // p) * (128 // p).
        """
        p = self.patch_embed.patch_size[0]
        h = 1024 // p
        w = 128 // p
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
        x = torch.einsum("nhwpqc->nchpwq", x)
        specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
        return specs

    def random_masking(self, x, mask_ratio):
        """Per-sample random masking via argsort of uniform noise.

        x: [N, L, D], sequence
        returns:
            x_masked: [N, int(L * (1 - mask_ratio)), D] kept tokens
            mask: [N, L], 0 is keep, 1 is remove
            ids_restore: [N, L] indices that undo the shuffle
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # binary mask in shuffled order, then unshuffled to token order
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
        """Structured masking of whole time steps and whole frequency bins.

        A token is kept only when both its time step and its frequency bin
        are kept. The token grid is hard-coded: 101 x 12 with custom
        (overlapped) patches, 64 x 8 otherwise.

        x: [N, L, D], sequence with L == T * F
        returns:
            x_masked: [N, len_keep_t * len_keep_f, D] kept tokens, in
                ascending token order
            mask: [N, L], 0 is keep, 1 is remove
            ids_restore: [N, L] indices that undo the kept-first ordering
        """
        N, _, D = x.shape  # batch, length, dim
        if self.use_custom_patch:  # overlapped patch
            T, F = 101, 12
        else:
            T, F = 64, 8

        len_keep_t = int(T * (1 - mask_t_prob))
        len_keep_f = int(F * (1 - mask_f_prob))

        # noise for mask in time (ascend: small is keep, large is remove)
        noise_t = torch.rand(N, T, device=x.device)  # noise in [0, 1]
        ids_shuffle_t = torch.argsort(noise_t, dim=1)
        ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)

        # noise for mask in freq
        noise_f = torch.rand(N, F, device=x.device)  # noise in [0, 1]
        ids_shuffle_f = torch.argsort(noise_f, dim=1)
        ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)

        # binary masks: 0 is keep, 1 is remove
        mask_f = torch.ones(N, F, device=x.device)
        mask_f[:, :len_keep_f] = 0
        mask_f = (
            torch.gather(mask_f, dim=1, index=ids_restore_f)
            .unsqueeze(1)
            .repeat(1, T, 1)
        )  # N,T,F

        mask_t = torch.ones(N, T, device=x.device)
        mask_t[:, :len_keep_t] = 0
        mask_t = (
            torch.gather(mask_t, dim=1, index=ids_restore_t)
            .unsqueeze(1)
            .repeat(1, F, 1)
            .permute(0, 2, 1)
        )  # N,T,F

        # removed if removed along either axis
        mask = 1 - (1 - mask_t) * (1 - mask_f)  # N,T,F
        mask = mask.flatten(start_dim=1)  # N,T*F

        # Sort tokens so that kept ones come first (ascending position) and
        # masked ones follow. The offset must be at least the number of
        # tokens per sample; a smaller constant lets masked tokens with small
        # positions sort ahead of kept tokens with large positions.
        num_tokens = T * F
        positions = torch.arange(num_tokens, device=x.device).unsqueeze(0)
        sort_key = positions + num_tokens * mask.long()
        ids_sorted = torch.argsort(sort_key, dim=1)

        ids_keep = ids_sorted[:, : len_keep_f * len_keep_t]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        ids_restore = torch.argsort(ids_sorted, dim=1)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio, mask_2d=False):
        """Encode the visible patches only.

        Returns ``(latent, mask, ids_restore, None)``; the trailing ``None``
        keeps the 4-tuple shape callers unpack.
        """
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * (1 - mask_ratio)
        if mask_2d:
            x, mask, ids_restore = self.random_masking_2d(
                x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob
            )
        else:
            x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # prepend cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore, None

    def forward_encoder_no_random_mask_no_average(self, x):
        """Encode every patch (no masking) and return the final normed tokens."""
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # prepend cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward_encoder_no_mask(self, x):
        """Encode every patch and average the normed outputs of late blocks.

        Only blocks whose index is strictly greater than
        ``self.contextual_depth`` contribute to the average.
        """
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # prepend cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks, collecting the late-layer outputs
        contextual_embs = []
        for n, blk in enumerate(self.blocks):
            x = blk(x)
            if n > self.contextual_depth:
                contextual_embs.append(self.norm(x))
        contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)

        return contextual_emb

    def forward_decoder(self, x, ids_restore):
        """Reconstruct every patch from the encoded visible tokens.

        Returns ``(pred, None, None)``; the two ``None`` values keep the
        3-tuple shape callers unpack.
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence (+1 accounts for the cls token in x)
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # prepend cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        if self.decoder_mode != 0:
            # non-default decoders run without the cls token
            B, L, D = x.shape
            x = x[:, 1:, :]
            if self.use_custom_patch:
                # hack: duplicate the last time row (101 -> 102 rows) to match
                # the decoder's feat_size of (102, 12)
                x = x.reshape(B, 101, 12, D)
                x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1)
                x = x.reshape(B, 1224, D)

        if self.decoder_mode > 3:  # mvit
            # NOTE(review): decoder_blocks is an nn.ModuleList in every branch
            # of __init__; calling it directly presumably fails - confirm.
            x = self.decoder_blocks(x)
        else:
            # apply Transformer blocks
            for blk in self.decoder_blocks:
                x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        pred = self.decoder_pred(x)

        if self.decoder_mode != 0:
            if self.use_custom_patch:
                # drop the duplicated row added above (102 -> 101 rows)
                pred = pred.reshape(B, 102, 12, 256)
                pred = pred[:, :101, :, :]
                pred = pred.reshape(B, 1212, 256)
        else:
            # remove cls token
            pred = pred[:, 1:, :]

        return pred, None, None  # emb, emb_pixel

    def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
        """Mean squared error averaged over removed patches.

        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove
        """
        target = self.patchify(imgs)
        if norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.8):
        """Run encoder, decoder and loss.

        Returns ``(loss_recon, pred, mask, loss_contrastive)`` where
        ``loss_contrastive`` is a constant zero tensor of shape (1,).
        """
        emb_enc, mask, ids_restore, _ = self.forward_encoder(
            imgs, mask_ratio, mask_2d=self.mask_2d
        )
        pred, _, _ = self.forward_decoder(emb_enc, ids_restore)  # [N, L, p*p*C]
        loss_recon = self.forward_loss(
            imgs, pred, mask, norm_pix_loss=self.norm_pix_loss
        )
        # Allocate on the input's device rather than forcing CUDA, so the
        # model also runs on CPU-only hosts.
        loss_contrastive = torch.zeros(1, dtype=torch.float32, device=imgs.device)
        return loss_recon, pred, mask, loss_contrastive
def mae_vit_small_patch16_dec512d8b(**kwargs):
    """Build the small variant: 384-dim, 12-block, 6-head encoder.

    Patch size is 16 and the decoder is 512-dim with 16 heads. Any extra
    keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build the base variant: 768-dim, 12-block, 12-head encoder.

    Patch size is 16 and the decoder is 512-dim with 16 heads. Any extra
    keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build the large variant: 1024-dim, 24-block, 16-head encoder.

    Patch size is 16 and the decoder is 512-dim with 16 heads. Any extra
    keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build the huge variant: 1280-dim, 32-block, 16-head encoder.

    Patch size is 14 and the decoder is 512-dim with 16 heads. Any extra
    keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
# Short aliases for the recommended architectures. Each points at the
# factory configured with a 512-dim decoder.
mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b