from collections import OrderedDict
import copy

import torch
import torch.nn as nn

from dockformerpp.utils.tensor_utils import tensor_tree_map


class ExponentialMovingAverage:
    """
    Maintains moving averages of parameters with exponential decay

    At each step, the stored copy `copy` of each parameter `param` is
    updated as follows:

        `copy = decay * copy + (1 - decay) * param`

    where `decay` is an attribute of the ExponentialMovingAverage object.
    """

    def __init__(self, model: nn.Module, decay: float):
        """
        Args:
            model:
                A torch.nn.Module whose parameters are to be tracked
            decay:
                A value (usually close to 1.) by which updates are weighted
                as part of the above formula
        """
        super(ExponentialMovingAverage, self).__init__()

        clone_param = lambda t: t.clone().detach()
        self.params = tensor_tree_map(clone_param, model.state_dict())
        self.decay = decay
        self.device = next(model.parameters()).device

    def to(self, device):
        self.params = tensor_tree_map(lambda t: t.to(device), self.params)
        self.device = device

    def _update_state_dict_(self, update, state_dict):
        with torch.no_grad():
            for k, v in update.items():
                stored = state_dict[k]
                if not isinstance(v, torch.Tensor):
                    self._update_state_dict_(v, stored)
                else:
                    diff = stored - v
                    diff *= 1 - self.decay
                    stored -= diff

    def update(self, model: torch.nn.Module) -> None:
        """
        Updates the stored parameters using the state dict of the provided
        module. The module should have the same structure as that used to
        initialize the ExponentialMovingAverage object.
        """
        self._update_state_dict_(model.state_dict(), self.params)

    def load_state_dict(self, state_dict: OrderedDict) -> None:
        for k in state_dict["params"].keys():
            self.params[k] = state_dict["params"][k].clone()
        self.decay = state_dict["decay"]

    def state_dict(self) -> OrderedDict:
        return OrderedDict(
            {
                "params": self.params,
                "decay": self.decay,
            }
        )
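

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the toy nn.Linear model and SGD
# optimizer below are assumptions standing in for the real dockformerpp model
# and training loop). It shows the intended call pattern: fold the current
# weights into the running average after every optimizer step, then load the
# smoothed weights into a separate copy of the model for evaluation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = nn.Linear(4, 4)
    ema = ExponentialMovingAverage(model, decay=0.999)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        optimizer.zero_grad()
        loss = model(torch.randn(2, 4)).sum()
        loss.backward()
        optimizer.step()
        ema.update(model)  # copy = decay * copy + (1 - decay) * param

    # Evaluate with the averaged weights without disturbing the live model.
    eval_model = copy.deepcopy(model)
    eval_model.load_state_dict(ema.state_dict()["params"])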