luxmorocco's picture
Upload 86 files
4efbc62 verified
raw
history blame
1.61 kB
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import copy
import math
import torch
import torch.nn as nn
from efficientvit.models.utils import is_parallel
__all__ = ["EMA"]
def update_ema(
    ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float
) -> None:
    """Blend ``new_state_dict`` into the floating-point state of ``ema`` in place.

    For every floating-point tensor ``v`` in ``ema.state_dict()`` with key
    ``name``, the update is::

        v <- v - (1 - decay) * (v - new_state_dict[name])

    Entries whose dtype is not floating point (e.g. integer buffers) are
    skipped and keep their current value.

    Args:
        ema: Module whose floating-point parameters/buffers are overwritten.
        new_state_dict: Source tensors, looked up with the same keys as
            ``ema.state_dict()``; they are detached before use.
        decay: Weight kept on the current ``ema`` value.
    """
    float_entries = (
        (name, shadow)
        for name, shadow in ema.state_dict().items()
        if shadow.dtype.is_floating_point
    )
    for name, shadow in float_entries:
        # ``state_dict()`` holds references to the module's tensors, so this
        # in-place subtraction writes directly back into ``ema``.
        shadow -= (1.0 - decay) * (shadow - new_state_dict[name].detach())
class EMA:
    """Exponential moving average (EMA) of a model's weights.

    Keeps a frozen deep copy of the model (``self.shadows``, in eval mode with
    ``requires_grad`` disabled) and blends the live model's floating-point
    state into it on every :meth:`step`.
    """

    def __init__(self, model: nn.Module, decay: float, warmup_steps: int = 2000):
        """Create the shadow copy of ``model``.

        Args:
            model: Model to track. When ``is_parallel(model)`` is true, the
                wrapped ``model.module`` is copied instead of the wrapper.
            decay: Target decay; the effective decay used by :meth:`step`
                approaches it as ``global_step`` grows.
            warmup_steps: Time constant of the exponential decay warmup.
                Values ``<= 0`` disable the warmup so ``decay`` is used as-is.
        """
        self.shadows = copy.deepcopy(
            model.module if is_parallel(model) else model
        ).eval()
        self.decay = decay
        self.warmup_steps = warmup_steps
        # The shadow is only updated through ``update_ema``; never train it.
        for p in self.shadows.parameters():
            p.requires_grad = False

    def _effective_decay(self, global_step: int) -> float:
        """Return the decay applied at ``global_step``.

        With warmup enabled this is
        ``decay * (1 - exp(-global_step / warmup_steps))``: 0 at step 0 (the
        shadow is overwritten by the live model) and approaching ``decay`` as
        ``global_step`` grows.
        """
        if self.warmup_steps <= 0:
            # Warmup disabled; also avoids dividing by zero when it is 0.
            return self.decay
        return self.decay * (1 - math.exp(-global_step / self.warmup_steps))

    def step(self, model: nn.Module, global_step: int) -> None:
        """Blend the current weights of ``model`` into the shadow.

        Args:
            model: Live model; unwrapped via ``.module`` when
                ``is_parallel(model)`` is true.
            global_step: Current training step, used for the decay warmup.
        """
        with torch.no_grad():
            msd = (model.module if is_parallel(model) else model).state_dict()
            update_ema(self.shadows, msd, self._effective_decay(global_step))

    def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
        """Return the shadow's state dict keyed by this EMA's ``decay``."""
        return {self.decay: self.shadows.state_dict()}

    def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
        """Load the shadow weights stored under ``self.decay``.

        Entries for other decay values are ignored. If no entry matches, the
        shadow is left unchanged and a ``UserWarning`` is emitted, because a
        silent mismatch would quietly restart the average from scratch.

        Args:
            state_dict: Mapping ``{decay: shadow_state_dict}`` as produced by
                :meth:`state_dict`.
        """
        if self.decay in state_dict:
            self.shadows.load_state_dict(state_dict[self.decay])
            return

        import warnings

        warnings.warn(
            f"EMA.load_state_dict: no entry for decay={self.decay!r} "
            f"(found {list(state_dict)!r}); shadow weights left unchanged.",
            stacklevel=2,
        )