Spaces:
Configuration error
Configuration error
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn

from efficientvit.apps.trainer.run_config import Scheduler
from efficientvit.models.nn.ops import IdentityLayer, ResidualBlock
from efficientvit.models.utils import build_kwargs_from_config
# Names exported by ``from <this module> import *``: only the config-driven
# entry point is listed.
__all__ = ["apply_drop_func"]
def apply_drop_func(network: nn.Module, drop_config: Optional[dict[str, Any]]) -> None:
    """Apply the drop strategy selected by ``drop_config`` to ``network``.

    Args:
        network: Model handed to the selected drop function.
        drop_config: Mapping whose ``"name"`` entry selects the drop function
            (only ``"droppath"`` is registered here). The whole mapping is also
            passed to ``build_kwargs_from_config`` to derive the keyword
            arguments for that function. ``None`` means "no dropping" and
            returns immediately without touching ``network``.

    Raises:
        KeyError: If ``drop_config`` has no ``"name"`` entry, or if the name is
            not a registered drop function.
    """
    if drop_config is None:
        return

    drop_lookup_table = {
        "droppath": apply_droppath,
    }

    name = drop_config["name"]
    if name not in drop_lookup_table:
        # Same exception type as a plain dict lookup, but with an actionable
        # message instead of just the offending key.
        raise KeyError(
            f"unknown drop function {name!r}; "
            f"expected one of {sorted(drop_lookup_table)}"
        )

    drop_func = drop_lookup_table[name]
    drop_kwargs = build_kwargs_from_config(drop_config, drop_func)
    drop_func(network, **drop_kwargs)
def apply_droppath(
    network: nn.Module,
    drop_prob: float,
    linear_decay=True,
    scheduled=True,
    skip=0,
) -> None:
    """Swap eligible residual blocks of ``network`` for drop-path variants, in place.

    A child module is eligible when it is a ``ResidualBlock`` whose ``shortcut``
    is an ``IdentityLayer``. Eligible blocks are gathered in
    ``network.modules()`` traversal order, and the first ``skip`` of them are
    left untouched.

    Args:
        network: Model whose children are inspected and replaced.
        drop_prob: Drop probability given to the replaced blocks. With
            ``linear_decay`` this is the value assigned to the last block.
        linear_decay: When true, the k-th of n replaced blocks (1-based) gets
            ``drop_prob * k / n``; otherwise every block gets ``drop_prob``.
        scheduled: Forwarded unchanged to each ``DropPathResidualBlock``.
        skip: Number of leading eligible blocks to leave as they are.
    """
    # Collect first, replace afterwards, so the module tree is not mutated
    # while it is being traversed.
    candidates = [
        (parent, child_name, child)
        for parent in network.modules()
        for child_name, child in parent.named_children()
        if isinstance(child, ResidualBlock)
        and isinstance(child.shortcut, IdentityLayer)
    ]
    targets = candidates[skip:]
    total = len(targets)

    for rank, (parent, child_name, block) in enumerate(targets, start=1):
        if linear_decay:
            block_prob = drop_prob * rank / total
        else:
            block_prob = drop_prob

        # Reuse the existing sub-modules; only the wrapper changes.
        parent._modules[child_name] = DropPathResidualBlock(
            block.main,
            block.shortcut,
            block.post_act,
            block.pre_norm,
            block_prob,
            scheduled,
        )
class DropPathResidualBlock(ResidualBlock):
    """Residual block that randomly drops its main branch during training.

    In training mode, with a non-zero ``drop_prob`` and an ``IdentityLayer``
    shortcut, the main-branch output is multiplied by a 0/1 mask that holds one
    value per index along dim 0 of the input (presumably the batch dimension --
    confirm against callers). Kept entries are rescaled by ``1 / keep_prob``.
    In every other case the block defers to ``ResidualBlock.forward``.
    """

    def __init__(
        self,
        main: nn.Module,
        shortcut: Optional[nn.Module],
        post_act=None,
        pre_norm: Optional[nn.Module] = None,
        ######################################
        drop_prob: float = 0,
        scheduled=True,
    ):
        """Build the block.

        Args:
            main: Forwarded to ``ResidualBlock.__init__``.
            shortcut: Forwarded to ``ResidualBlock.__init__``.
            post_act: Forwarded to ``ResidualBlock.__init__``.
            pre_norm: Forwarded to ``ResidualBlock.__init__``.
            drop_prob: Probability of dropping the main branch, in ``[0, 1]``.
            scheduled: When true, ``drop_prob`` is multiplied at forward time
                by ``Scheduler.PROGRESS`` clipped to ``[0, 1]``.

        Raises:
            ValueError: If ``drop_prob`` is outside ``[0, 1]``.
        """
        super().__init__(main, shortcut, post_act, pre_norm)
        # A probability outside [0, 1] would produce masks that are not 0/1.
        if not 0 <= drop_prob <= 1:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob!r}")
        self.drop_prob = drop_prob
        self.scheduled = scheduled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block, applying drop-path only where it is active."""
        if (
            not self.training
            or self.drop_prob == 0
            or not isinstance(self.shortcut, IdentityLayer)
        ):
            return ResidualBlock.forward(self, x)

        drop_prob = self.drop_prob
        if self.scheduled:
            # NOTE(review): Scheduler.PROGRESS looks like a global training
            # progress value; only its clipping to [0, 1] is visible here.
            drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1)
        keep_prob = 1 - drop_prob

        # One mask value per index of dim 0, broadcast over the other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=x.dtype, device=x.device
        )
        random_tensor.floor_()  # binarize

        main_out = self.forward_main(x)
        if keep_prob > 0:
            # Rescale kept entries by 1 / keep_prob.
            main_out = main_out / keep_prob
        # When keep_prob == 0 the mask is all zeros; dividing by zero would
        # give inf * 0 = NaN, so the rescale is skipped and only the shortcut
        # contributes. The main branch is still evaluated so it stays in the
        # autograd graph.
        res = main_out * random_tensor + self.shortcut(x)
        if self.post_act:
            res = self.post_act(res)
        return res