Spaces:
Running
on
Zero
Running
on
Zero
import copy | |
import torch | |
from collections import OrderedDict | |
def get_ema_model(model): | |
ema_model = copy.deepcopy(model) | |
ema_model.eval() | |
for param in ema_model.parameters(): | |
param.requires_grad = False | |
return ema_model | |
def update_ema(ema_model, model, decay=0.9999): | |
""" | |
Step the EMA model towards the current model. | |
""" | |
ema_params = OrderedDict(ema_model.named_parameters()) | |
model_params = OrderedDict(model.named_parameters()) | |
for name, param in model_params.items(): | |
# TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed | |
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) | |