|
|
|
|
|
""" |
|
This module provides the EMA class, which stores a copy of the exponentially
decayed model params.

Typical usage of the EMA class involves initializing an object from an existing
model (random or from a seed model) and setting config options such as
ema_decay and ema_start_update, which determine how the EMA model is updated.
After every model update, i.e. at the end of the train_step, the EMA should be
updated by passing the new model to EMA.step. The EMA model state dict can be
stored in the extra state under the key "ema", dumped into a checkpoint, and
loaded back. The EMA object can be passed to tasks by setting the
task.uses_ema property.

The EMA model is a smoothed/ensembled model that may perform better when used
for inference or further fine-tuning. The EMA class has a reverse function to
load the EMA params into a model so it can be used like a regular model.
|
|
|
This implementation is used for trainer-level EMA tracking. For EMA tracking
|
inside the model, please use fairseq/modules/ema_module.py instead. |
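
A minimal usage sketch (the model, data, and train_step below are assumed
context for illustration, not part of this module):

    ema = EMA(model, EMAConfig(ema_decay=0.9999, ema_fp32=True))
    for num_updates, batch in enumerate(data):
        train_step(model, batch)              # regular optimizer update
        ema.step(model, updates=num_updates)  # then update the EMA copy
    model = ema.reverse(model)                # load EMA weights for inference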
|
""" |
|
|
|
import copy |
|
import logging |
|
|
|
import torch |
|
|
|
from fairseq import checkpoint_utils |
|
|
|
|
|
class EMA(object): |
|
"""Exponential Moving Average of Fairseq Models |
|
EMA keeps a copy of the exponentially decayed model params. |
|
    The set of params should include both gradient-descent and
    non-gradient-descent params, such as batch mean/var and buffers.
|
This is a modified implementation of |
|
the open source code in https://github.com/zhawe01/fairseq-gec.git, |
|
and internal source code in |
|
fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py. |
|
|
|
    Similar to TF EMA:
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    EMA provides an averaged and smoothed set of model weights, and has been
    shown to improve the performance of vision models. The EMA class provides
    all the functionality needed to update, reload, and initialize the EMA
    model.
|
|
|
    The EMA object is initialized from an arbitrary model. By default, it is
    stored on the same device (unless a device is specified at initialization)
    and with the same precision as the model (unless ema_fp32 is True).
    Setting ema_fp32 is recommended: the EMA parameters are then kept in an
    fp32 copy used only for the EMA update step, while the EMA model itself
    stays at the default precision.
    EMA is usually enabled using EMAConfig with store_ema=True. Some important
    parameters to configure EMA are
    1) ema_decay - The decay rate of the EMA
    2) ema_update_freq - EMA is updated once every this many model updates
    3) ema_start_update - Start EMA updates after this many model updates [default 0]
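
    Each non-skipped parameter follows the standard exponential decay rule
    (this mirrors the mul_/add_ calls in _step_internal):

        ema_param = decay * ema_param + (1 - decay) * model_param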
|
|
|
    Key methods:
    1) step - One update of the EMA using the new model weights
    2) restore - Update the EMA from a state dict
    3) reverse - Load the EMA params into a model
    4) get_decay, _set_decay - Get or set the decay. Note _set_decay is
    called from step.
    5) build_fp32_params - Initialize or update the fp32 copy of the EMA
    params. Note this is enabled only when ema_fp32=True
|
""" |
|
|
|
def __init__(self, model, config, device=None, skip_keys=None): |
|
""" |
|
        @param model model to initialize the EMA with
        @param config EMAConfig object with configuration like
        ema_decay, ema_update_freq, ema_fp32
        @param device If provided, copy the EMA to this device (e.g. gpu).
        Otherwise the EMA stays on the same device as the model.
        @param skip_keys parameter keys that are copied over from the new
        model verbatim instead of being decayed (see _step_internal)
|
""" |
|
|
|
self.decay = config.ema_decay |
|
self.model = copy.deepcopy(model) |
|
self.model.requires_grad_(False) |
|
self.config = config |
|
self.skip_keys = skip_keys or set() |
|
self.fp32_params = {} |
|
|
|
if self.config.ema_seed_model is not None: |
|
state = checkpoint_utils.load_ema_from_checkpoint( |
|
self.config.ema_seed_model |
|
) |
|
self.model.load_state_dict(state["model"], strict=True) |
|
|
|
if device is not None: |
|
logging.info(f"Copying EMA model to device {device}") |
|
self.model = self.model.to(device=device) |
|
|
|
if self.config.ema_fp32: |
|
self.build_fp32_params() |
|
|
|
self.update_freq_counter = 0 |
|
|
|
def get_model(self): |
|
return self.model |
|
|
|
def build_fp32_params(self, state_dict=None): |
|
""" |
|
        Store a copy of the EMA params in fp32.
        If a state dict is passed, the EMA params are copied from it.
        Otherwise, they are copied from the current EMA model parameters.
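
        When reloading the EMA from a checkpoint, the fp32 copy can be rebuilt
        in the same call (a sketch; the checkpoint layout is an assumption):

            ema.restore(state["extra_state"]["ema"], build_fp32_params=True)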
|
""" |
|
if not self.config.ema_fp32: |
|
raise RuntimeError( |
|
"build_fp32_params should not be called if ema_fp32=False. " |
|
"Use ema_fp32=True if this is really intended." |
|
) |
|
|
|
if state_dict is None: |
|
state_dict = self.model.state_dict() |
|
|
|
        def _to_float(t):
            # Cast floating-point tensors to fp32; leave integer buffers
            # (e.g. step counters) at their original dtype.
            return t.float() if torch.is_floating_point(t) else t
|
|
|
for param_key in state_dict: |
|
if param_key in self.fp32_params: |
|
self.fp32_params[param_key].copy_(state_dict[param_key]) |
|
else: |
|
self.fp32_params[param_key] = _to_float(state_dict[param_key]) |
|
|
|
def restore(self, state_dict, build_fp32_params=False): |
|
"""Load data from a model spec into EMA model""" |
|
self.model.load_state_dict(state_dict, strict=False) |
|
if build_fp32_params: |
|
self.build_fp32_params(state_dict) |
|
|
|
def _set_decay(self, decay): |
|
self.decay = decay |
|
|
|
def get_decay(self): |
|
return self.decay |
|
|
|
def _step_internal(self, new_model, updates=None): |
|
"""One update of the EMA model based on new model weights""" |
|
decay = self.decay |
|
|
|
ema_state_dict = {} |
|
ema_params = ( |
|
self.fp32_params if self.config.ema_fp32 else self.model.state_dict() |
|
) |
|
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                # Key not tracked yet; start the EMA from a copy of the param.
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )
|
|
|
            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param "
                    "{} vs. {}".format(param.shape, ema_param.shape)
                )
|
|
|
if "version" in key: |
|
|
|
continue |
|
|
|
            if key in self.skip_keys:
                # Skipped keys are copied from the new model instead of decayed.
                ema_param = param.to(dtype=ema_param.dtype).clone()
            else:
                # ema_param = decay * ema_param + (1 - decay) * param
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param
|
self.restore(ema_state_dict, build_fp32_params=False) |
|
|
|
def step(self, new_model, updates=None): |
|
""" |
|
        One update of the EMA, performed once every self.config.ema_update_freq
        model updates.

        @param updates The current number of model updates done.
        Decay is set to 0 if updates < ema_start_update, which means
        the model is simply copied over to the EMA.
        Once updates >= ema_start_update, the EMA is updated with
        a decay of self.config.ema_decay.
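
        For example, with ema_start_update=1000 and ema_update_freq=4 (both
        hypothetical values), only every 4th call to step touches the EMA:
        during the first 1000 model updates those calls copy the model weights
        verbatim (decay 0), and later calls blend them in with decay ema_decay.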
|
""" |
|
if updates is not None: |
|
self._set_decay( |
|
0 if updates < self.config.ema_start_update else self.config.ema_decay |
|
) |
|
if self.config.ema_update_freq > 1: |
|
self.update_freq_counter += 1 |
|
if self.update_freq_counter >= self.config.ema_update_freq: |
|
self._step_internal(new_model, updates) |
|
self.update_freq_counter = 0 |
|
else: |
|
self._step_internal(new_model, updates) |
|
|
|
def reverse(self, model): |
|
""" |
|
        Load the model parameters from the EMA model.
|
Useful for inference or fine-tuning from the EMA model. |
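
        A common pattern (a sketch) is to load the smoothed weights right
        before validation:

            model = ema.reverse(model)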
|
""" |
|
        d = self.model.state_dict()
        if "_ema" in d:
            # Drop any nested "_ema" entry so it is not loaded into the model.
            del d["_ema"]
|
|
|
model.load_state_dict(d, strict=False) |
|
return model |
|
|