import copy
from collections import OrderedDict

import torch


def get_ema_model(model):
    """
    Create a frozen, eval-mode deep copy of the model to hold the EMA weights.
    """
    ema_model = copy.deepcopy(model)
    ema_model.eval()
    for param in ema_model.parameters():
        param.requires_grad = False
    return ema_model


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model:
    ema = decay * ema + (1 - decay) * model.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
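
if __name__ == "__main__":
    # Hypothetical smoke test (not part of the original module): a tiny
    # nn.Linear stands in for a real model to check that update_ema
    # interpolates between the old EMA weights and the current model weights.
    # Note that only named_parameters are tracked; buffers (e.g. BatchNorm
    # running stats) are left untouched by update_ema.
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 4)
    ema_model = get_ema_model(model)
    old_ema_weight = ema_model.weight.clone()

    with torch.no_grad():
        model.weight.add_(1.0)  # perturb the online model

    update_ema(ema_model, model, decay=0.5)

    # With decay=0.5 the EMA weight should land exactly halfway between
    # the old EMA weight and the perturbed model weight.
    expected = 0.5 * old_ema_weight + 0.5 * model.weight.detach()
    assert torch.allclose(ema_model.weight, expected)
    print("update_ema smoke test passed")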