|
import torch |
|
|
|
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintain an exponential moving average (EMA) of a model's weights.

    Each averaging step applies::

        ema_avg = decay * avg_model_param + (1 - decay) * model_param

    The bookkeeping (holding the averaged copy of ``model`` and iterating
    over its tensors in ``update_parameters``) is delegated to
    `torch.optim.swa_utils.AveragedModel
    <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_;
    only the averaging rule above is supplied here.

    Args:
        model: The module whose weights are averaged.
        decay: Weight kept by the running average on each update; must lie
            in ``[0, 1]``. Values closer to 1 average over a longer history.
        device: Device on which the averaged copy is kept. Defaults to
            ``"cpu"``.
        use_buffers: If ``True``, buffers (e.g. BatchNorm running statistics)
            are averaged with the same rule as the parameters. If ``False``
            (the default, matching the previous behaviour of this class) only
            parameters are averaged; what happens to buffers is then decided
            by the installed PyTorch version's ``AveragedModel``.

    Raises:
        ValueError: If ``decay`` is outside ``[0, 1]``.
    """

    def __init__(self, model, decay, device="cpu", *, use_buffers=False):
        # A decay outside [0, 1] would extrapolate away from both the old
        # average and the new weights instead of interpolating between them.
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in [0, 1], got {decay!r}")

        # ``num_averaged`` is required by the AveragedModel ``avg_fn``
        # signature but unused: EMA weighting does not depend on the count.
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(model, device, ema_avg, use_buffers=use_buffers)