# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from efficientvit.models.utils import build_kwargs_from_config

__all__ = ["LayerNorm2d", "build_norm", "reset_bn", "set_norm_eps"]

class LayerNorm2d(nn.LayerNorm):
    """LayerNorm applied over the channel dimension of an NCHW tensor."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x - torch.mean(x, dim=1, keepdim=True)
        out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
        if self.elementwise_affine:
            out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
        return out
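
# Illustrative usage sketch (not part of the original file): LayerNorm2d normalizes
# each pixel's channel vector, so it is constructed like a standard nn.LayerNorm with
# normalized_shape equal to the channel count and applied directly to NCHW maps.
def _example_layer_norm_2d() -> torch.Tensor:
    norm = LayerNorm2d(64, eps=1e-6)   # weight/bias of shape (64,)
    x = torch.randn(2, 64, 32, 32)     # (N, C, H, W)
    return norm(x)                     # same shape, normalized over dim=1
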
# register normalization function here
REGISTERED_NORM_DICT: dict[str, type] = {
    "bn2d": nn.BatchNorm2d,
    "ln": nn.LayerNorm,
    "ln2d": LayerNorm2d,
}
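# Illustrative note (an assumption, not part of the original file): additional norm
# types can be exposed through build_norm below by adding entries to this registry,
# e.g. REGISTERED_NORM_DICT["in2d"] = nn.InstanceNorm2d, whose constructor also
# accepts num_features.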
def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]:
    # "ln"/"ln2d" take normalized_shape; the other registered norms take num_features
    if name in ["ln", "ln2d"]:
        kwargs["normalized_shape"] = num_features
    else:
        kwargs["num_features"] = num_features
    if name in REGISTERED_NORM_DICT:
        norm_cls = REGISTERED_NORM_DICT[name]
        args = build_kwargs_from_config(kwargs, norm_cls)
        return norm_cls(**args)
    else:
        return None
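
# Illustrative usage sketch (not part of the original file): build_norm resolves a
# registry name to a configured layer and returns None for unknown names.
def _example_build_norm() -> None:
    bn = build_norm("bn2d", num_features=128)  # nn.BatchNorm2d(128, ...)
    ln = build_norm("ln2d", num_features=128)  # LayerNorm2d(128, ...)
    assert isinstance(bn, nn.BatchNorm2d) and isinstance(ln, LayerNorm2d)
    assert build_norm("does_not_exist", num_features=128) is None
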
def reset_bn(
    model: nn.Module,
    data_loader: list,
    sync=True,
    progress_bar=False,
) -> None:
    """Re-estimate the running statistics of every BatchNorm layer from data_loader."""
    import copy

    import torch.nn.functional as F
    from tqdm import tqdm

    from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
    from efficientvit.models.utils import get_device, list_join
    bn_mean = {}
    bn_var = {}

    # collect statistics on a throwaway copy so the original forward() stays intact
    tmp_model = copy.deepcopy(model)
    for name, m in tmp_model.named_modules():
        if isinstance(m, _BatchNorm):
            bn_mean[name] = AverageMeter(is_distributed=False)
            bn_var[name] = AverageMeter(is_distributed=False)
            def new_forward(bn, mean_est, var_est):
                # wrap the BN forward so every call records batch statistics while
                # still normalizing the activations it returns
                def lambda_forward(x):
                    x = x.contiguous()
                    if sync:
                        batch_mean = (
                            x.mean(0, keepdim=True)
                            .mean(2, keepdim=True)
                            .mean(3, keepdim=True)
                        )  # 1, C, 1, 1
                        # average the per-process means across all processes
                        batch_mean = sync_tensor(batch_mean, reduce="cat")
                        batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)

                        batch_var = (x - batch_mean) * (x - batch_mean)
                        batch_var = (
                            batch_var.mean(0, keepdim=True)
                            .mean(2, keepdim=True)
                            .mean(3, keepdim=True)
                        )
                        batch_var = sync_tensor(batch_var, reduce="cat")
                        batch_var = torch.mean(batch_var, dim=0, keepdim=True)
                    else:
                        batch_mean = (
                            x.mean(0, keepdim=True)
                            .mean(2, keepdim=True)
                            .mean(3, keepdim=True)
                        )  # 1, C, 1, 1
                        batch_var = (x - batch_mean) * (x - batch_mean)
                        batch_var = (
                            batch_var.mean(0, keepdim=True)
                            .mean(2, keepdim=True)
                            .mean(3, keepdim=True)
                        )

                    batch_mean = torch.squeeze(batch_mean)
                    batch_var = torch.squeeze(batch_var)
                    # accumulate the estimates weighted by batch size
                    mean_est.update(batch_mean.data, x.size(0))
                    var_est.update(batch_var.data, x.size(0))

                    # bn forward using calculated mean & var
                    _feature_dim = batch_mean.shape[0]
                    return F.batch_norm(
                        x,
                        batch_mean,
                        batch_var,
                        bn.weight[:_feature_dim],
                        bn.bias[:_feature_dim],
                        False,
                        0.0,
                        bn.eps,
                    )

                return lambda_forward

            m.forward = new_forward(m, bn_mean[name], bn_var[name])
    # skip if there are no batch normalization layers in the network
    if len(bn_mean) == 0:
        return
    tmp_model.eval()
    with torch.no_grad():
        with tqdm(
            total=len(data_loader),
            desc="reset bn",
            disable=not progress_bar or not is_master(),
        ) as t:
            for images in data_loader:
                images = images.to(get_device(tmp_model))
                tmp_model(images)
                t.set_postfix(
                    {
                        "bs": images.size(0),
                        "res": list_join(images.shape[-2:], "x"),
                    }
                )
                t.update()
    # copy the accumulated estimates back into the original model's running buffers
    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            feature_dim = bn_mean[name].avg.size(0)
            assert isinstance(m, _BatchNorm)
            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
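
# Illustrative usage sketch (an assumption, not part of the original file): reset_bn
# recalibrates BatchNorm running statistics from a handful of image batches, e.g.
# after modifying a model or changing its input resolution. `calib_batches` is any
# iterable of NCHW tensors; sync=False is the single-process setting.
def _example_reset_bn(model: nn.Module, calib_batches: list) -> None:
    reset_bn(model, calib_batches, sync=False, progress_bar=True)
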
def set_norm_eps(model: nn.Module, eps: Optional[float] = None) -> None:
    # overwrite the eps of every GroupNorm/LayerNorm/BatchNorm layer in the model
    for m in model.modules():
        if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
            if eps is not None:
                m.eps = eps
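
# Illustrative usage sketch (not part of the original file): a common use is to
# enlarge eps before low-precision inference so the normalization stays stable.
def _example_set_norm_eps(model: nn.Module) -> None:
    set_norm_eps(model, eps=1e-5)  # passing None would leave every eps unchanged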