import logging
import math
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import Any, List

import torch
import torch.distributed as dist
import torch.optim
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer
from fairseq.optim.fused_adam import get_fused_adam_class
from omegaconf import II, OmegaConf


logger = logging.getLogger(__name__)


@dataclass
class FairseqAdamConfig(FairseqDataclass):
    adam_betas: Any = field(
        default=(0.9, 0.999), metadata={"help": "betas for Adam optimizer"}
    )
    adam_eps: float = field(
        default=1e-8, metadata={"help": "epsilon for Adam optimizer"}
    )
    weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
    use_old_adam: bool = field(
        default=False, metadata={"help": "Use fairseq.optim.adam.Adam"}
    )
    fp16_adam_stats: bool = field(
        default=False, metadata={"help": "use FP16 stats (with automatic scaling)"}
    )
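    # Shared options interpolated from the top-level config (common.tpu and
    # optimization.lr) via OmegaConf II references.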
    tpu: bool = II("common.tpu")
    lr: List[float] = II("optimization.lr")


@register_optimizer("adam", dataclass=FairseqAdamConfig)
class FairseqAdam(FairseqOptimizer):
    """Adam optimizer for fairseq.

    Important note: this optimizer corresponds to the "AdamW" variant of
    Adam in its weight decay behavior. As such, it is most closely
    analogous to torch.optim.AdamW from PyTorch.
    """

    def __init__(self, cfg: FairseqAdamConfig, params):
        super().__init__(cfg)
        fused_adam_cls = get_fused_adam_class()
        use_fused_adam = (
            not getattr(cfg, "use_old_adam", False)
            and fused_adam_cls is not None
            and torch.cuda.is_available()
        )
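        # Prefer the fused CUDA kernel when one is available and not explicitly
        # disabled; otherwise fall back to the pure-PyTorch Adam defined below.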
        if getattr(cfg, "tpu", False):
            if self.cfg.fp16_adam_stats:
                raise NotImplementedError("--fp16-adam-stats is only supported on GPU")
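            # On TPU, use the Adam implementation defined in this module; it
            # keeps its statistics in FP32 and casts FP16/BF16 grads up itself.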
            self._optimizer = Adam(params, **self.optimizer_config)
        elif use_fused_adam:
            logger.info("using FusedAdam")
            self._optimizer = fused_adam_cls(
                params, use_fp16_stats=self.cfg.fp16_adam_stats, **self.optimizer_config
            )
        else:
            if self.cfg.fp16_adam_stats:
                raise NotImplementedError(
                    "--fp16-adam-stats is only supported with FusedAdamV1"
                )
            self._optimizer = Adam(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
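        # adam_betas may be a string such as "(0.9, 0.98)" when set on the
        # command line, or a list/ListConfig when set from a YAML config.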
        return {
            "lr": self.cfg.lr[0]
            if isinstance(self.cfg.lr, Collection)
            else self.cfg.lr,
            "betas": eval(self.cfg.adam_betas)
            if isinstance(self.cfg.adam_betas, str)
            else OmegaConf.to_container(self.cfg.adam_betas),
            "eps": self.cfg.adam_eps,
            "weight_decay": self.cfg.weight_decay,
        }

    def average_params(self):
        """Parameter averaging is only used during BMUF distributed training."""
        state_dict = self.optimizer.state_dict()
        total_gpus = float(dist.get_world_size())

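        # Divide the local Adam moments by the world size and sum across
        # workers, which leaves every worker with the averaged statistics.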
        for _, value in state_dict["state"].items():
            value["exp_avg"] /= total_gpus
            value["exp_avg_sq"] /= total_gpus
            dist.all_reduce(value["exp_avg"], op=dist.ReduceOp.SUM)
            dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM)


class Adam(torch.optim.Optimizer):
    r"""Implements Adam algorithm.

    This implementation is modified from torch.optim.Adam based on:
    `Decoupled Weight Decay Regularization`
    (see https://arxiv.org/abs/1711.05101)

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): decoupled (AdamW-style) weight decay
            coefficient (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
    ):
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
        )
        super(Adam, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )
                amsgrad = group.get("amsgrad", False)

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state = self.state[p]

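                # Lazily initialize the optimizer state for this parameter on
                # its first update.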
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of the gradient values
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    # Exponential moving average of the squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                    if amsgrad:
                        # Maximum of all second-moment running averages so far
                        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
                    if amsgrad:
                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
                            p_data_fp32
                        )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

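                # Update the biased first and second moment estimates:
                #   m <- beta1 * m + (1 - beta1) * g
                #   v <- beta2 * v + (1 - beta2) * g^2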
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
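                    # AMSGrad: keep the elementwise maximum of all second-moment
                    # estimates seen so far and normalize by it instead.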
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

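                # Fold both bias corrections into a single scalar step size.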
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1

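                # Decoupled (AdamW-style) weight decay: shrink the parameters
                # directly rather than adding an L2 term to the gradient.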
                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
                    )

                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)

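                # Copy the FP32 result back into the original FP16/BF16 parameter.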
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
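

# Minimal usage sketch, not part of fairseq itself: it only illustrates that the
# pure-PyTorch Adam above can be driven like any torch.optim optimizer. The toy
# linear model, random tensors, and hyperparameters below are hypothetical.
if __name__ == "__main__":
    model = torch.nn.Linear(8, 2)
    optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)

    x, y = torch.randn(4, 8), torch.randn(4, 2)
    loss = ((model(x) - y) ** 2).mean()  # simple MSE loss
    loss.backward()
    optimizer.step()       # applies the AdamW-style update implemented above
    optimizer.zero_grad()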