""" Multi-Scale Vision Transformer v2 | |
@inproceedings{li2021improved, | |
title={MViTv2: Improved multiscale vision transformers for classification and detection}, | |
author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph}, | |
booktitle={CVPR}, | |
year={2022} | |
} | |
Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit | |
Original copyright below. | |
Modifications and timm support by / Copyright 2022, Ross Wightman | |
""" | |
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. All Rights Reserved. | |
import operator
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial, reduce
from typing import Union, List, Tuple, Optional

import torch
import torch.utils.checkpoint as checkpoint
from torch import nn

from custom_timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg
from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
from .registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
        'fixed_input_size': True,
        **kwargs
    }


default_cfgs = dict(
    mvitv2_tiny=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'),
    mvitv2_small=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'),
    mvitv2_base=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'),
    mvitv2_large=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'),
    mvitv2_base_in21k=_cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
        num_classes=19168),
    mvitv2_large_in21k=_cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
        num_classes=19168),
    mvitv2_huge_in21k=_cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
        num_classes=19168),
    mvitv2_small_cls=_cfg(url=''),
)


@dataclass
class MultiScaleVitCfg:
    depths: Tuple[int, ...] = (2, 3, 16, 3)
    embed_dim: Union[int, Tuple[int, ...]] = 96
    num_heads: Union[int, Tuple[int, ...]] = 1
    mlp_ratio: float = 4.
    pool_first: bool = False
    expand_attn: bool = True
    qkv_bias: bool = True
    use_cls_token: bool = False
    use_abs_pos: bool = False
    residual_pooling: bool = True
    mode: str = 'conv'
    kernel_qkv: Tuple[int, int] = (3, 3)
    stride_q: Optional[Tuple[Tuple[int, int]]] = ((1, 1), (2, 2), (2, 2), (2, 2))
    stride_kv: Optional[Tuple[Tuple[int, int]]] = None
    stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4)
    patch_kernel: Tuple[int, int] = (7, 7)
    patch_stride: Tuple[int, int] = (4, 4)
    patch_padding: Tuple[int, int] = (3, 3)
    pool_type: str = 'max'
    rel_pos_type: str = 'spatial'
    act_layer: Union[str, Tuple[str, str]] = 'gelu'
    norm_layer: Union[str, Tuple[str, str]] = 'layernorm'
    norm_eps: float = 1e-6

    def __post_init__(self):
        num_stages = len(self.depths)
        if not isinstance(self.embed_dim, (tuple, list)):
            self.embed_dim = tuple(self.embed_dim * 2 ** i for i in range(num_stages))
        assert len(self.embed_dim) == num_stages

        if not isinstance(self.num_heads, (tuple, list)):
            self.num_heads = tuple(self.num_heads * 2 ** i for i in range(num_stages))
        assert len(self.num_heads) == num_stages

        if self.stride_kv_adaptive is not None and self.stride_kv is None:
            _stride_kv = self.stride_kv_adaptive
            pool_kv_stride = []
            for i in range(num_stages):
                if min(self.stride_q[i]) > 1:
                    _stride_kv = [
                        max(_stride_kv[d] // self.stride_q[i][d], 1)
                        for d in range(len(_stride_kv))
                    ]
                pool_kv_stride.append(tuple(_stride_kv))
            self.stride_kv = tuple(pool_kv_stride)
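
# Worked example of the adaptive KV-stride rule above (derived from the
# defaults, not from upstream docs): with stride_q=((1, 1), (2, 2), (2, 2), (2, 2))
# and stride_kv_adaptive=(4, 4), the loop yields
# stride_kv=((4, 4), (2, 2), (1, 1), (1, 1)) -- the running KV stride is divided
# by each stage's Q stride (floored at 1), keeping K/V at a roughly constant
# resolution relative to Q across stages.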


model_cfgs = dict(
    mvitv2_tiny=MultiScaleVitCfg(
        depths=(1, 2, 5, 2),
    ),
    mvitv2_small=MultiScaleVitCfg(
        depths=(1, 2, 11, 2),
    ),
    mvitv2_base=MultiScaleVitCfg(
        depths=(2, 3, 16, 3),
    ),
    mvitv2_large=MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        expand_attn=False,
    ),
    mvitv2_base_in21k=MultiScaleVitCfg(
        depths=(2, 3, 16, 3),
    ),
    mvitv2_large_in21k=MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        expand_attn=False,
    ),
    mvitv2_small_cls=MultiScaleVitCfg(
        depths=(1, 2, 11, 2),
        use_cls_token=True,
    ),
)
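
# Note: the *_in21k entries duplicate the in1k architecture configs; the
# ImageNet-21k variants differ only through default_cfgs above (pretrained
# URL and num_classes=19168).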


def prod(iterable):
    return reduce(operator.mul, iterable, 1)
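
# e.g. prod((2, 2)) == 4 and prod(()) == 1 (the empty product), matching math.prod.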


class PatchEmbed(nn.Module):
    """ Image-to-patch embedding via a single strided convolution.
    """

    def __init__(
            self,
            dim_in=3,
            dim_out=768,
            kernel=(7, 7),
            stride=(4, 4),
            padding=(3, 3),
    ):
        super().__init__()
        self.proj = nn.Conv2d(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        x = self.proj(x)
        # B C H W -> B HW C
        return x.flatten(2).transpose(1, 2), x.shape[-2:]
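
# Shape sketch for the defaults: a (B, 3, 224, 224) image with kernel (7, 7),
# stride (4, 4), padding (3, 3) becomes a (B, 56 * 56, dim_out) token sequence,
# returned together with the (56, 56) feature size.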


def reshape_pre_pool(
        x,
        feat_size: List[int],
        has_cls_token: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    H, W = feat_size
    if has_cls_token:
        cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
    else:
        cls_tok = None
    x = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous()
    return x, cls_tok


def reshape_post_pool(
        x,
        num_heads: int,
        cls_tok: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[int]]:
    feat_size = [x.shape[2], x.shape[3]]
    L_pooled = x.shape[2] * x.shape[3]
    x = x.reshape(-1, num_heads, x.shape[1], L_pooled).transpose(2, 3)
    if cls_tok is not None:
        x = torch.cat((cls_tok, x), dim=2)
    return x, feat_size
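
# Round-trip sketch: reshape_pre_pool flattens (B, heads, N, C) tokens into a
# (B * heads, C, H, W) map suitable for Conv2d/MaxPool2d, splitting off any
# class token; reshape_post_pool inverts this to (B, heads, H' * W' (+1), C)
# and reports the pooled (H', W') feature size.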


@register_notrace_function
def cal_rel_pos_type(
        attn: torch.Tensor,
        q: torch.Tensor,
        has_cls_token: bool,
        q_size: List[int],
        k_size: List[int],
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
):
    """
    Spatial Relative Positional Embeddings.
    """
    sp_idx = 1 if has_cls_token else 0
    q_h, q_w = q_size
    k_h, k_w = k_size

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio
    dist_h += (k_h - 1) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio
    dist_w += (k_w - 1) * k_w_ratio

    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]

    B, n_head, q_N, dim = q.shape

    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
    rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh)
    rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw)

    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
        + rel_h[:, :, :, :, :, None]
        + rel_w[:, :, :, :, None, :]
    ).view(B, -1, q_h * q_w, k_h * k_w)
    return attn
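
# In einsum terms, the decomposed bias added above is
#   attn[b, y, (qh, qw), (kh, kw)] += sum_c r_q[b, y, qh, qw, c]
#       * (Rh[qh, kh, c] + Rw[qw, kw, c])
# i.e. the full 2D relative-position table is factored into separate height
# and width tables indexed by scaled relative distance.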


class MultiScaleAttentionPoolFirst(nn.Module):
    def __init__(
            self,
            dim,
            dim_out,
            feat_size,
            num_heads=8,
            qkv_bias=True,
            mode="conv",
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            has_cls_token=True,
            rel_pos_type='spatial',
            residual_pooling=True,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        padding_q = tuple([int(q // 2) for q in kernel_q])
        padding_kv = tuple([int(kv // 2) for kv in kernel_kv])

        self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.proj = nn.Linear(dim_out, dim_out)

        # Skip pooling with kernel and stride size of (1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        if mode in ("avg", "max"):
            pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode == "conv" or mode == "conv_unshared":
            dim_conv = dim // num_heads if mode == "conv" else dim
            if kernel_q:
                self.pool_q = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_q = norm_layer(dim_conv)
            if kernel_kv:
                self.pool_k = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_k = norm_layer(dim_conv)
                self.pool_v = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_v = norm_layer(dim_conv)
        else:
            raise NotImplementedError(f"Unsupported mode {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            assert feat_size[0] == feat_size[1]
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        B, N, _ = x.shape

        fold_dim = 1 if self.unshared else self.num_heads
        x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
        q = k = v = x

        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            v, v_size = reshape_post_pool(v, self.num_heads, v_tok)
        else:
            v_size = feat_size
        if self.norm_v is not None:
            v = self.norm_v(v)

        q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
        q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
        q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3)

        k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
        k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
        k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3)

        v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
        v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
        v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        return x, q_size
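
# Note the ordering that gives the class its name: pooling above runs on the
# raw (un-projected) features folded into num_heads groups, and the separate
# q/k/v Linear projections are applied only afterwards. MultiScaleAttention
# below is the default path: it projects to q, k, v first, then pools the
# projected tensors.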


class MultiScaleAttention(nn.Module):
    def __init__(
            self,
            dim,
            dim_out,
            feat_size,
            num_heads=8,
            qkv_bias=True,
            mode="conv",
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            has_cls_token=True,
            rel_pos_type='spatial',
            residual_pooling=True,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        padding_q = tuple([int(q // 2) for q in kernel_q])
        padding_kv = tuple([int(kv // 2) for kv in kernel_kv])

        self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim_out, dim_out)

        # Skip pooling with kernel and stride size of (1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        if mode in ("avg", "max"):
            pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode == "conv" or mode == "conv_unshared":
            dim_conv = dim_out // num_heads if mode == "conv" else dim_out
            if kernel_q:
                self.pool_q = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_q = norm_layer(dim_conv)
            if kernel_kv:
                self.pool_k = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_k = norm_layer(dim_conv)
                self.pool_v = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_v = norm_layer(dim_conv)
        else:
            raise NotImplementedError(f"Unsupported mode {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            assert feat_size[0] == feat_size[1]
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(dim=0)

        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            v, _ = reshape_post_pool(v, self.num_heads, v_tok)
        if self.norm_v is not None:
            v = self.norm_v(v)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        return x, q_size


class MultiScaleBlock(nn.Module):
    def __init__(
            self,
            dim,
            dim_out,
            num_heads,
            feat_size,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_path=0.0,
            norm_layer=nn.LayerNorm,
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            mode="conv",
            has_cls_token=True,
            expand_attn=False,
            pool_first=False,
            rel_pos_type='spatial',
            residual_pooling=True,
    ):
        super().__init__()
        proj_needed = dim != dim_out
        self.dim = dim
        self.dim_out = dim_out
        self.has_cls_token = has_cls_token

        self.norm1 = norm_layer(dim)

        self.shortcut_proj_attn = nn.Linear(dim, dim_out) if proj_needed and expand_attn else None
        if stride_q and prod(stride_q) > 1:
            kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
            stride_skip = stride_q
            padding_skip = [int(skip // 2) for skip in kernel_skip]
            self.shortcut_pool_attn = nn.MaxPool2d(kernel_skip, stride_skip, padding_skip)
        else:
            self.shortcut_pool_attn = None

        att_dim = dim_out if expand_attn else dim
        attn_layer = MultiScaleAttentionPoolFirst if pool_first else MultiScaleAttention
        self.attn = attn_layer(
            dim,
            att_dim,
            num_heads=num_heads,
            feat_size=feat_size,
            qkv_bias=qkv_bias,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            has_cls_token=has_cls_token,
            mode=mode,
            rel_pos_type=rel_pos_type,
            residual_pooling=residual_pooling,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(att_dim)
        mlp_dim_out = dim_out
        self.shortcut_proj_mlp = nn.Linear(dim, dim_out) if proj_needed and not expand_attn else None
        self.mlp = Mlp(
            in_features=att_dim,
            hidden_features=int(att_dim * mlp_ratio),
            out_features=mlp_dim_out,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def _shortcut_pool(self, x, feat_size: List[int]):
        if self.shortcut_pool_attn is None:
            return x
        if self.has_cls_token:
            cls_tok, x = x[:, :1, :], x[:, 1:, :]
        else:
            cls_tok = None
        B, L, C = x.shape
        H, W = feat_size
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        x = self.shortcut_pool_attn(x)
        x = x.reshape(B, C, -1).transpose(1, 2)
        if cls_tok is not None:
            x = torch.cat((cls_tok, x), dim=1)
        return x

    def forward(self, x, feat_size: List[int]):
        x_norm = self.norm1(x)
        # NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj
        x_shortcut = x if self.shortcut_proj_attn is None else self.shortcut_proj_attn(x_norm)
        x_shortcut = self._shortcut_pool(x_shortcut, feat_size)
        x, feat_size_new = self.attn(x_norm, feat_size)
        x = x_shortcut + self.drop_path1(x)

        x_norm = self.norm2(x)
        x_shortcut = x if self.shortcut_proj_mlp is None else self.shortcut_proj_mlp(x_norm)
        x = x_shortcut + self.drop_path2(self.mlp(x_norm))
        return x, feat_size_new


class MultiScaleVitStage(nn.Module):

    def __init__(
            self,
            dim,
            dim_out,
            depth,
            num_heads,
            feat_size,
            mlp_ratio=4.0,
            qkv_bias=True,
            mode="conv",
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            has_cls_token=True,
            expand_attn=False,
            pool_first=False,
            rel_pos_type='spatial',
            residual_pooling=True,
            norm_layer=nn.LayerNorm,
            drop_path=0.0,
    ):
        super().__init__()
        self.grad_checkpointing = False

        self.blocks = nn.ModuleList()
        if expand_attn:
            out_dims = (dim_out,) * depth
        else:
            out_dims = (dim,) * (depth - 1) + (dim_out,)

        for i in range(depth):
            attention_block = MultiScaleBlock(
                dim=dim,
                dim_out=out_dims[i],
                num_heads=num_heads,
                feat_size=feat_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                kernel_q=kernel_q,
                kernel_kv=kernel_kv,
                stride_q=stride_q if i == 0 else (1, 1),
                stride_kv=stride_kv,
                mode=mode,
                has_cls_token=has_cls_token,
                pool_first=pool_first,
                rel_pos_type=rel_pos_type,
                residual_pooling=residual_pooling,
                expand_attn=expand_attn,
                norm_layer=norm_layer,
                drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
            )
            dim = out_dims[i]
            self.blocks.append(attention_block)
            if i == 0:
                feat_size = tuple([size // stride for size, stride in zip(feat_size, stride_q)])

        self.feat_size = feat_size

    def forward(self, x, feat_size: List[int]):
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x, feat_size = checkpoint.checkpoint(blk, x, feat_size)
            else:
                x, feat_size = blk(x, feat_size)
        return x, feat_size
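
# Only the first block in a stage applies stride_q (see the i == 0 handling
# above), so each stage downsamples Q at most once and its remaining blocks
# run at the pooled resolution. With the default strides, the four stage
# outputs are 56, 28, 14 and 7 for a 224x224 input (stage 0 keeps the 56x56
# patch grid, since its stride_q is (1, 1)).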


class MultiScaleVit(nn.Module):
    """
    Improved Multiscale Vision Transformers for Classification and Detection
    Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
    Christoph Feichtenhofer*
    https://arxiv.org/abs/2112.01526

    Multiscale Vision Transformers
    Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
    Christoph Feichtenhofer*
    https://arxiv.org/abs/2104.11227
    """

    def __init__(
            self,
            cfg: MultiScaleVitCfg,
            img_size: Tuple[int, int] = (224, 224),
            in_chans: int = 3,
            global_pool: str = 'avg',
            num_classes: int = 1000,
            drop_path_rate: float = 0.,
            drop_rate: float = 0.,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.depths = tuple(cfg.depths)
        self.expand_attn = cfg.expand_attn

        embed_dim = cfg.embed_dim[0]
        self.patch_embed = PatchEmbed(
            dim_in=in_chans,
            dim_out=embed_dim,
            kernel=cfg.patch_kernel,
            stride=cfg.patch_stride,
            padding=cfg.patch_padding,
        )
        patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1])
        num_patches = prod(patch_dims)

        if cfg.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.num_prefix_tokens = 1
            pos_embed_dim = num_patches + 1
        else:
            self.num_prefix_tokens = 0
            self.cls_token = None
            pos_embed_dim = num_patches

        if cfg.use_abs_pos:
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim))
        else:
            self.pos_embed = None

        num_stages = len(cfg.embed_dim)
        feat_size = patch_dims
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
        self.stages = nn.ModuleList()
        for i in range(num_stages):
            if cfg.expand_attn:
                dim_out = cfg.embed_dim[i]
            else:
                dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)]
            stage = MultiScaleVitStage(
                dim=embed_dim,
                dim_out=dim_out,
                depth=cfg.depths[i],
                num_heads=cfg.num_heads[i],
                feat_size=feat_size,
                mlp_ratio=cfg.mlp_ratio,
                qkv_bias=cfg.qkv_bias,
                mode=cfg.mode,
                pool_first=cfg.pool_first,
                expand_attn=cfg.expand_attn,
                kernel_q=cfg.kernel_qkv,
                kernel_kv=cfg.kernel_qkv,
                stride_q=cfg.stride_q[i],
                stride_kv=cfg.stride_kv[i],
                has_cls_token=cfg.use_cls_token,
                rel_pos_type=cfg.rel_pos_type,
                residual_pooling=cfg.residual_pooling,
                norm_layer=norm_layer,
                drop_path=dpr[i],
            )
            embed_dim = dim_out
            feat_size = stage.feat_size
            self.stages.append(stage)

        self.num_features = embed_dim
        self.norm = norm_layer(embed_dim)
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
        ]))

        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            trunc_normal_tf_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_tf_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {k for k, _ in self.named_parameters()
                if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^patch_embed',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
        ]))

    def forward_features(self, x):
        x, feat_size = self.patch_embed(x)
        B, N, C = x.shape

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed

        for stage in self.stages:
            x, feat_size = stage(x, feat_size)

        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            if self.global_pool == 'avg':
                x = x[:, self.num_prefix_tokens:].mean(1)
            else:
                x = x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def checkpoint_filter_fn(state_dict, model):
    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        # already in timm format, no remapping needed
        return state_dict

    import re
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']

    depths = getattr(model, 'depths', None)
    expand_attn = getattr(model, 'expand_attn', True)
    assert depths is not None, 'model requires depth attribute to remap checkpoints'
    depth_map = {}
    block_idx = 0
    for stage_idx, d in enumerate(depths):
        depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
        block_idx += d

    out_dict = {}
    for k, v in state_dict.items():
        k = re.sub(
            r'blocks\.(\d+)',
            lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}',
            k)

        if expand_attn:
            k = re.sub(r'stages\.(\d+)\.blocks\.(\d+)\.proj', r'stages.\1.blocks.\2.shortcut_proj_attn', k)
        else:
            k = re.sub(r'stages\.(\d+)\.blocks\.(\d+)\.proj', r'stages.\1.blocks.\2.shortcut_proj_mlp', k)
        if 'head' in k:
            k = k.replace('head.projection', 'head.fc')
        out_dict[k] = v

    # for k, v in state_dict.items():
    #     if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
    #         # To resize pos embedding when using model at different size from pretrained weights
    #         v = resize_pos_embed(
    #             v,
    #             model.pos_embed,
    #             0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
    #             model.patch_embed.grid_size
    #         )

    return out_dict
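
# Remap sketch (hypothetical key, depths=(2, 3, 16, 3)): flat checkpoint key
# 'blocks.3.attn.qkv.weight' falls in stage 1 (which owns flat blocks 2-4) at
# local index 1, so it is rewritten to 'stages.1.blocks.1.attn.qkv.weight'.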


def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
    return build_model_with_cfg(
        MultiScaleVit, variant, pretrained,
        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs)


@register_model
def mvitv2_tiny(pretrained=False, **kwargs):
    return _create_mvitv2('mvitv2_tiny', pretrained=pretrained, **kwargs)


@register_model
def mvitv2_small(pretrained=False, **kwargs):
    return _create_mvitv2('mvitv2_small', pretrained=pretrained, **kwargs)


@register_model
def mvitv2_base(pretrained=False, **kwargs):
    return _create_mvitv2('mvitv2_base', pretrained=pretrained, **kwargs)


@register_model
def mvitv2_large(pretrained=False, **kwargs):
    return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs)


# @register_model
# def mvitv2_base_in21k(pretrained=False, **kwargs):
#     return _create_mvitv2('mvitv2_base_in21k', pretrained=pretrained, **kwargs)
#
#
# @register_model
# def mvitv2_large_in21k(pretrained=False, **kwargs):
#     return _create_mvitv2('mvitv2_large_in21k', pretrained=pretrained, **kwargs)
#
#
# @register_model
# def mvitv2_huge_in21k(pretrained=False, **kwargs):
#     return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)


@register_model
def mvitv2_small_cls(pretrained=False, **kwargs):
    return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
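

# Minimal smoke test (a sketch; the module path is an assumption). This file
# uses relative imports, so run it as a module, e.g.
# `python -m custom_timm.models.mvitv2`, not as a plain script.
if __name__ == '__main__':
    _model = mvitv2_tiny()  # random init; pretrained=True would download weights
    _logits = _model(torch.randn(2, 3, 224, 224))
    assert _logits.shape == (2, 1000)
    print('ok:', tuple(_logits.shape))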