# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from efficientvit.models.nn.act import build_act
from efficientvit.models.nn.norm import build_norm
from efficientvit.models.utils import (get_same_padding, list_sum, resize,
                                       val2list, val2tuple)

__all__ = [
    "ConvLayer",
    "UpSampleLayer",
    "LinearLayer",
    "IdentityLayer",
    "DSConv",
    "MBConv",
    "FusedMBConv",
    "ResBlock",
    "LiteMLA",
    "EfficientViTBlock",
    "ResidualBlock",
    "DAGBlock",
    "OpSequential",
]


#################################################################################
#                             Basic Layers                                      #
#################################################################################


class ConvLayer(nn.Module):
    """Square 2D convolution with optional dropout, normalization and activation.

    The forward pass applies, in order: ``Dropout2d`` (only when ``dropout > 0``),
    ``Conv2d``, the norm built by ``build_norm`` and the activation built by
    ``build_act``. The norm and activation steps are skipped when the
    corresponding attribute is falsy.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        use_bias=False,
        dropout=0,
        norm="bn2d",
        act_func="relu",
    ):
        super().__init__()

        # Padding comes from the project helper for the plain kernel and is
        # then scaled by the dilation rate.
        padding = get_same_padding(kernel_size)
        padding *= dilation

        self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=padding,
            dilation=(dilation, dilation),
            groups=groups,
            bias=use_bias,
        )
        self.norm = build_norm(norm, num_features=out_channels)
        self.act = build_act(act_func)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x


class UpSampleLayer(nn.Module):
    """Resize the input either to a fixed ``size`` or by a scale ``factor``.

    When ``size`` is given it takes precedence and ``factor`` is stored as
    ``None``. The actual resampling is delegated to the project ``resize``
    helper.
    """

    def __init__(
        self,
        mode="bicubic",
        size: Optional[Union[int, tuple[int, int], list[int]]] = None,
        factor=2,
        align_corners=False,
    ):
        super().__init__()
        self.mode = mode
        self.size = val2list(size, 2) if size is not None else None
        self.factor = None if self.size is not None else factor
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise both sides to lists before comparing: ``self.size`` comes
        # from ``val2list`` (presumably a list), and a tuple never compares
        # equal to a list, which would make this shortcut unreachable.
        already_sized = self.size is not None and list(x.shape[-2:]) == list(self.size)
        if already_sized or self.factor == 1:
            return x
        return resize(x, self.size, self.factor, self.mode, self.align_corners)


class LinearLayer(nn.Module):
    """Fully connected layer with optional dropout, normalization and activation.

    Inputs with more than two dimensions are flattened from dimension 1 onward
    before the linear projection.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        use_bias=True,
        dropout=0,
        norm=None,
        act_func=None,
    ):
        super().__init__()

        self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None
        self.linear = nn.Linear(in_features, out_features, use_bias)
        self.norm = build_norm(norm, num_features=out_features)
        self.act = build_act(act_func)

    def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten everything after the first dimension when ``x`` has more than 2 dims."""
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._try_squeeze(x)
        if self.dropout:
            x = self.dropout(x)
        x = self.linear(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x


class IdentityLayer(nn.Module):
    """Return the input unchanged."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


#################################################################################
#                             Basic Blocks                                      #
#################################################################################


class DSConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    ``use_bias``, ``norm`` and ``act_func`` are expanded with ``val2tuple`` to
    one entry per stage (depthwise, pointwise).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        # groups == in_channels makes this a depthwise convolution.
        self.depth_conv = ConvLayer(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            groups=in_channels,
            norm=norm[0],
            act_func=act_func[0],
            use_bias=use_bias[0],
        )
        self.point_conv = ConvLayer(
            in_channels,
            out_channels,
            1,
            norm=norm[1],
            act_func=act_func[1],
            use_bias=use_bias[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x


class MBConv(nn.Module):
    """Inverted bottleneck: 1x1 expand -> depthwise conv -> 1x1 project.

    The hidden width is ``mid_channels`` when given (and truthy), otherwise
    ``round(in_channels * expand_ratio)``. ``use_bias``, ``norm`` and
    ``act_func`` are expanded with ``val2tuple`` to one entry per stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm=("bn2d", "bn2d", "bn2d"),
        act_func=("relu6", "relu6", None),
    ):
        super().__init__()

        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act_func = val2tuple(act_func, 3)
        mid_channels = mid_channels or round(in_channels * expand_ratio)

        self.inverted_conv = ConvLayer(
            in_channels,
            mid_channels,
            1,
            stride=1,
            norm=norm[0],
            act_func=act_func[0],
            use_bias=use_bias[0],
        )
        # The stride is applied only in the depthwise stage.
        self.depth_conv = ConvLayer(
            mid_channels,
            mid_channels,
            kernel_size,
            stride=stride,
            groups=mid_channels,
            norm=norm[1],
            act_func=act_func[1],
            use_bias=use_bias[1],
        )
        self.point_conv = ConvLayer(
            mid_channels,
            out_channels,
            1,
            norm=norm[2],
            act_func=act_func[2],
            use_bias=use_bias[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.inverted_conv(x)
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x


class FusedMBConv(nn.Module):
    """Spatial convolution that expands channels, followed by a 1x1 projection.

    The hidden width is ``mid_channels`` when given (and truthy), otherwise
    ``round(in_channels * expand_ratio)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        groups=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)
        mid_channels = mid_channels or round(in_channels * expand_ratio)

        self.spatial_conv = ConvLayer(
            in_channels,
            mid_channels,
            kernel_size,
            stride,
            groups=groups,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.point_conv = ConvLayer(
            mid_channels,
            out_channels,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.spatial_conv(x)
        x = self.point_conv(x)
        return x


class ResBlock(nn.Module):
    """Two stacked ``ConvLayer`` stages with the same kernel size.

    Only the first stage uses ``stride``; the second always has stride 1.
    No skip connection is added here; wrap the block in ``ResidualBlock``
    for that.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)
        mid_channels = mid_channels or round(in_channels * expand_ratio)

        self.conv1 = ConvLayer(
            in_channels,
            mid_channels,
            kernel_size,
            stride,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.conv2 = ConvLayer(
            mid_channels,
            out_channels,
            kernel_size,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class LiteMLA(nn.Module):
    r"""Lightweight multi-scale linear attention"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: Optional[int] = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm=(None, "bn2d"),
        act_func=(None, None),
        kernel_func="relu",
        scales: tuple[int, ...] = (5,),
        eps=1.0e-15,
    ):
        super().__init__()
        self.eps = eps

        # Derive the head count from the channel budget unless given explicitly.
        heads = heads or int(in_channels // dim * heads_ratio)
        total_dim = heads * dim

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.dim = dim
        # A single 1x1 conv produces q, k and v for every head at once.
        self.qkv = ConvLayer(
            in_channels,
            3 * total_dim,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        # One extra branch per scale: a depthwise conv of that kernel size
        # followed by a grouped 1x1 conv, both applied to the qkv tensor.
        self.aggreg = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        scale,
                        padding=get_same_padding(scale),
                        groups=3 * total_dim,
                        bias=use_bias[0],
                    ),
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        1,
                        groups=3 * heads,
                        bias=use_bias[0],
                    ),
                )
                for scale in scales
            ]
        )
        self.kernel_func = build_act(kernel_func, inplace=False)

        # The projection consumes the attention output of the base branch plus
        # one branch per scale.
        self.proj = ConvLayer(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    @autocast(enabled=False)
    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
        """Apply linear attention to a packed ``(B, C, H, W)`` qkv tensor.

        Autocast is disabled for this method and float16 inputs are promoted
        to float32 before the matmuls and the division.
        """
        B, _, H, W = list(qkv.size())

        if qkv.dtype == torch.float16:
            qkv = qkv.float()

        # (B, C, H, W) -> (B, groups, 3 * dim, H * W) -> (B, groups, H * W, 3 * dim)
        qkv = torch.reshape(
            qkv,
            (
                B,
                -1,
                3 * self.dim,
                H * W,
            ),
        )
        qkv = torch.transpose(qkv, -1, -2)
        q, k, v = (
            qkv[..., 0 : self.dim],
            qkv[..., self.dim : 2 * self.dim],
            qkv[..., 2 * self.dim :],
        )

        # lightweight linear attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)

        # linear matmul
        trans_k = k.transpose(-1, -2)
        # Append a channel of ones to v so the same matmuls also produce the
        # normaliser in the last output channel.
        v = F.pad(v, (0, 1), mode="constant", value=1)
        kv = torch.matmul(trans_k, v)
        out = torch.matmul(q, kv)
        # Divide by the normaliser channel; eps guards against division by zero.
        out = out[..., :-1] / (out[..., -1:] + self.eps)

        out = torch.transpose(out, -1, -2)
        out = torch.reshape(out, (B, -1, H, W))
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # generate multi-scale q, k, v
        qkv = self.qkv(x)
        multi_scale_qkv = [qkv] + [op(qkv) for op in self.aggreg]
        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)

        out = self.relu_linear_att(multi_scale_qkv)
        out = self.proj(out)

        return out


class EfficientViTBlock(nn.Module):
    """A ``LiteMLA`` context module followed by an ``MBConv`` local module.

    Each module is wrapped in a ``ResidualBlock`` with an identity shortcut.
    """

    def __init__(
        self,
        in_channels: int,
        heads_ratio: float = 1.0,
        dim=32,
        expand_ratio: float = 4,
        scales=(5,),
        norm="bn2d",
        act_func="hswish",
    ):
        super().__init__()
        self.context_module = ResidualBlock(
            LiteMLA(
                in_channels=in_channels,
                out_channels=in_channels,
                heads_ratio=heads_ratio,
                dim=dim,
                norm=(None, norm),
                scales=scales,
            ),
            IdentityLayer(),
        )
        local_module = MBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False),
            norm=(None, None, norm),
            act_func=(act_func, act_func, None),
        )
        self.local_module = ResidualBlock(local_module, IdentityLayer())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.context_module(x)
        x = self.local_module(x)
        return x


#################################################################################
#                             Functional Blocks                                 #
#################################################################################


class ResidualBlock(nn.Module):
    """Combine a main branch with an optional shortcut branch.

    * ``main is None``: the input is passed through unchanged.
    * ``shortcut is None``: only the main branch is evaluated.
    * otherwise: the outputs of the main and shortcut branches are summed.

    ``pre_norm`` (if given) is applied to the input of the main branch only,
    and ``post_act`` (if any) is applied to the final result.
    """

    def __init__(
        self,
        main: Optional[nn.Module],
        shortcut: Optional[nn.Module],
        post_act=None,
        pre_norm: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.pre_norm = pre_norm
        self.main = main
        self.shortcut = shortcut
        self.post_act = build_act(post_act)

    def forward_main(self, x: torch.Tensor) -> torch.Tensor:
        """Run the main branch, applying ``pre_norm`` first when it is set."""
        if self.pre_norm is None:
            return self.main(x)
        else:
            return self.main(self.pre_norm(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.main is None:
            res = x
        elif self.shortcut is None:
            res = self.forward_main(x)
        else:
            res = self.forward_main(x) + self.shortcut(x)
            if self.post_act:
                res = self.post_act(res)
        return res


class DAGBlock(nn.Module):
    """Merge several named input features, process them, and emit named outputs.

    Each input op is applied to the feature stored under its key, and the
    results are merged by summation (``merge == "add"``) or by concatenation
    along dim 1 (``merge == "cat"``). The merged feature then goes through
    ``post_input`` (if set) and ``middle``, and every output op writes its
    result back into the feature dict under its key.
    """

    def __init__(
        self,
        inputs: dict[str, nn.Module],
        merge: str,
        post_input: Optional[nn.Module],
        middle: nn.Module,
        outputs: dict[str, nn.Module],
    ):
        super().__init__()

        # Keys and ops are stored as parallel sequences so the ops are
        # registered as submodules through nn.ModuleList.
        self.input_keys = list(inputs.keys())
        self.input_ops = nn.ModuleList(list(inputs.values()))
        self.merge = merge
        self.post_input = post_input

        self.middle = middle

        self.output_keys = list(outputs.keys())
        self.output_ops = nn.ModuleList(list(outputs.values()))

    def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Update ``feature_dict`` in place with the output features and return it.

        Raises:
            NotImplementedError: if ``merge`` is neither ``"add"`` nor ``"cat"``.
        """
        feat = [
            op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
        ]
        if self.merge == "add":
            feat = list_sum(feat)
        elif self.merge == "cat":
            feat = torch.cat(feat, dim=1)
        else:
            raise NotImplementedError(f"merge mode {self.merge!r} is not supported")
        if self.post_input is not None:
            feat = self.post_input(feat)
        feat = self.middle(feat)
        for key, op in zip(self.output_keys, self.output_ops):
            feature_dict[key] = op(feat)
        return feature_dict


class OpSequential(nn.Module):
    """Apply a list of modules in order, ignoring ``None`` entries."""

    def __init__(self, op_list: list[Optional[nn.Module]]):
        super().__init__()
        # Drop the None placeholders so only real modules are registered.
        self.op_list = nn.ModuleList([op for op in op_list if op is not None])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for op in self.op_list:
            x = op(x)
        return x