File size: 8,135 Bytes
6a62ffb |
|
#!/usr/bin/env python3
"""
This module has the EMA class used to store a copy of the exponentially decayed
model params.
Typical usage of the EMA class involves initializing an object from an existing
model (random or from a seed model) and setting config options such as
ema_decay and ema_start_update, which determine how the EMA model is updated.
After every update of the model, i.e. at the end of train_step, the EMA should
be updated by passing the new model to EMA.step. The EMA model state dict
can be stored in the extra state under the key of "ema" and dumped
into a checkpoint and loaded. The EMA object can be passed to tasks
by setting task.uses_ema property.
EMA is a smoothed/ensemble model which might have better performance
when used for inference or further fine-tuning. EMA class has a
reverse function to load the EMA params into a model and use it
like a regular model.
This implementation is used for trainer-level ema tracking. For EMA tracking
inside the model, please use fairseq/modules/ema_module.py instead.
"""
import copy
import logging
import torch
from fairseq import checkpoint_utils
class EMA(object):
    """Exponential Moving Average of Fairseq Models

    EMA keeps a copy of the exponentially decayed model params.
    The set of params should include both gradient-descent and
    non-gradient descent params, such as batch mean/var and buffers.
    This is a modified implementation of
    the open source code in https://github.com/zhawe01/fairseq-gec.git,
    and internal source code in
    fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py.

    Similar to TF EMA.
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    EMA provides an averaged and smoothed set of model weights, and has been shown
    to improve vision models. EMA class does all necessary functions to update,
    reload, or init EMA methods.

    EMA object is initialized from an arbitrary model. By default, it is stored on
    the same device (unless device specified at initialization) and with the
    same precision as the model (unless ema_fp32 is True). ema_fp32 is recommended.
    This stores the EMA parameters in fp32 only for the EMA update step, and
    is used at the default precision otherwise.
    EMA is usually enabled using EMAConfig with store_ema=True. Some important
    parameters to configure EMA are
    1) ema_decay - The decay of EMA
    2) ema_update_freq - EMA is updated every this many model updates.
    3) ema_start_update - Start EMA update after this many model updates [default 0]

    Tensors that cannot be averaged (non floating-point / non complex tensors,
    e.g. integer buffers such as BatchNorm's ``num_batches_tracked``) and keys
    listed in ``skip_keys`` are copied from the new model instead of decayed.

    Key methods:
    1) step - One update of EMA using new model
    2) restore - Update EMA from a state dict
    3) reverse - Load EMA into a model
    4) get_decay, _set_decay - Used to get or set the decay. Note _set_decay is
    called from step.
    5) build_fp32_params - Used to initialize or update the fp32 copy of EMA params.
    Note this is enabled only when ema_fp32=True
    """

    def __init__(self, model, config, device=None, skip_keys=None):
        """
        @param model model to initialize the EMA with (deep-copied, with
            requires_grad disabled on the copy)
        @param config EMAConfig object with configuration like
            ema_decay, ema_update_freq, ema_fp32, ema_start_update,
            ema_seed_model
        @param device If provided, copy EMA to this device (e.g. gpu).
            Otherwise EMA is in the same device as the model.
        @param skip_keys Optional collection of state-dict keys whose values
            are copied from the new model on every step instead of averaged.
        """
        self.decay = config.ema_decay
        self.model = copy.deepcopy(model)
        self.model.requires_grad_(False)
        self.config = config
        self.skip_keys = skip_keys or set()
        self.fp32_params = {}

        # Optionally seed the EMA weights from an existing checkpoint.
        if self.config.ema_seed_model is not None:
            state = checkpoint_utils.load_ema_from_checkpoint(
                self.config.ema_seed_model
            )
            self.model.load_state_dict(state["model"], strict=True)

        if device is not None:
            logging.info("Copying EMA model to device %s", device)
            self.model = self.model.to(device=device)

        # Build the fp32 copy only after the device move so it lives on the
        # same device as the EMA model.
        if self.config.ema_fp32:
            self.build_fp32_params()

        self.update_freq_counter = 0

    def get_model(self):
        """Return the EMA model."""
        return self.model

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If state dict is passed, the EMA params is copied from
        the provided state dict. Otherwise, it is copied from the
        current EMA model parameters. Non floating-point tensors are
        stored with their original dtype.

        @raises RuntimeError if called while config.ema_fp32 is False.
        """
        if not self.config.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            return t.float() if torch.is_floating_point(t) else t

        for param_key in state_dict:
            if param_key in self.fp32_params:
                # Existing entries are updated in place.
                self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])

    def restore(self, state_dict, build_fp32_params=False):
        """Load data from a model spec into EMA model (non-strict).

        @param build_fp32_params if True, also refresh the fp32 copy from
            ``state_dict``.
        """
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def _set_decay(self, decay):
        self.decay = decay

    def get_decay(self):
        return self.decay

    def _step_internal(self, new_model, updates=None):
        """One update of the EMA model based on new model weights.

        For each tensor in ``new_model.state_dict()``:
            ema = decay * ema + (1 - decay) * new
        except that keys containing "version" are left untouched, and keys in
        ``self.skip_keys`` as well as tensors that are neither floating-point
        nor complex are copied from ``new_model`` as-is.

        @param new_model model whose current weights are folded into the EMA
        @param updates not used here; accepted to mirror ``step``
        @raises ValueError if a tensor's shape differs from its EMA counterpart
        """
        decay = self.decay

        ema_state_dict = {}
        # In fp32 mode the running average lives in self.fp32_params. Otherwise
        # the state_dict() tensors share storage with the EMA model, so the
        # in-place updates below modify the EMA model directly.
        ema_params = (
            self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
        )
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param: "
                    + "{} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param.
                # NOTE(review): substring match - any key containing "version"
                # is skipped, not only "<module>.version" entries.
                continue

            # The EMA may live on a different device / dtype than the training
            # model (see ``device`` and ``ema_fp32``), so align both first.
            new_value = param.to(device=ema_param.device, dtype=ema_param.dtype)

            averageable = torch.is_floating_point(ema_param) or torch.is_complex(
                ema_param
            )
            if key in self.skip_keys or not averageable:
                # Integer/bool buffers cannot be scaled by a fractional decay in
                # place (PyTorch raises), so they are copied like skip_keys.
                ema_param = new_value.clone()
            else:
                ema_param.mul_(decay)
                ema_param.add_(new_value, alpha=1 - decay)
            ema_state_dict[key] = ema_param
        self.restore(ema_state_dict, build_fp32_params=False)

    def step(self, new_model, updates=None):
        """
        One update of EMA which is done every self.config.ema_update_freq
        updates of the model.

        @param updates The current number of model updates done.
        Decay is set to 0 if model updates < ema_start_update, which means
        the model will be simply copied over to the EMA.
        When model updates >= ema_start_update, then EMA is updated with
        a decay of self.config.ema_decay.
        """
        if updates is not None:
            self._set_decay(
                0 if updates < self.config.ema_start_update else self.config.ema_decay
            )
        if self.config.ema_update_freq > 1:
            self.update_freq_counter += 1
            if self.update_freq_counter >= self.config.ema_update_freq:
                self._step_internal(new_model, updates)
                self.update_freq_counter = 0
        else:
            self._step_internal(new_model, updates)

    def reverse(self, model):
        """
        Load the model parameters from EMA model.
        Useful for inference or fine-tuning from the EMA model.

        A top-level "_ema" entry, if present, is dropped; loading is
        non-strict. Returns ``model`` for convenience.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]

        model.load_state_dict(d, strict=False)
        return model
|