from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

import torch
from fairseq.dataclass import FairseqDataclass
from omegaconf import II, DictConfig
from torch.optim.optimizer import Optimizer, required

from . import FairseqOptimizer, register_optimizer


@dataclass
class FairseqNAGConfig(FairseqDataclass):
    momentum: float = field(default=0.99, metadata={"help": "momentum factor"})
    weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
    lr: List[float] = II("optimization.lr")
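    # ``lr`` is filled in by OmegaConf variable interpolation (II) from the global
    # optimization.lr setting, so it arrives here as a list of floats.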
|
|
|
|
|
@register_optimizer("nag", dataclass=FairseqNAGConfig)
class FairseqNAG(FairseqOptimizer):
    def __init__(self, cfg: DictConfig, params):
        super().__init__(cfg)
        self._optimizer = NAG(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
|
        return {
            "lr": self.cfg.lr[0]
            if isinstance(self.cfg.lr, Collection)
            else self.cfg.lr,
            "momentum": self.cfg.momentum,
            "weight_decay": self.cfg.weight_decay,
        }
|
|
|
|
|
class NAG(Optimizer):
    def __init__(self, params, lr=required, momentum=0, weight_decay=0):
        defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
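        # ``lr_old`` remembers the learning rate used on the previous step; step()
        # uses it to rescale the momentum buffer whenever the learning rate changes.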
|
        super(NAG, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
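        """FP16/BF16 parameters can be updated directly; step() upcasts them to
        FP32 internally (used by fairseq's memory-efficient FP16 mode)."""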
|
        return True

    @property
    def supports_flat_params(self):
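        """The update can be applied to a single flattened parameter tensor."""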
|
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            lr = group["lr"]
            lr_old = group.get("lr_old", lr)
            lr_correct = lr / lr_old if lr_old > 0 else lr
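            # The momentum buffer has the previous learning rate folded into it;
            # lr_correct (= lr / lr_old) rescales it when the LR has changed.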
|
|
|
            for p in group["params"]:
                if p.grad is None:
                    continue

                p_data_fp32 = p.data
                if p_data_fp32.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()
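                # Half-precision parameters are updated in an FP32 copy and written
                # back at the end of the loop.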
|
|
|
                d_p = p.grad.data.float()
                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    param_state["momentum_buffer"] = torch.zeros_like(d_p)
                else:
                    param_state["momentum_buffer"] = param_state["momentum_buffer"].to(
                        d_p
                    )

                buf = param_state["momentum_buffer"]
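                # Nesterov momentum with the parameters stored at the "lookahead"
                # point p = theta + momentum * v, so the step can be applied in place:
                #   v_new = momentum * v - lr * grad
                #   p_new = p + momentum^2 * v - (1 + momentum) * lr * grad
                # Here v is ``buf`` (rescaled by lr_correct if the LR changed) and
                # weight decay is applied multiplicatively to the parameters first.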
|
|
|
                if weight_decay != 0:
                    p_data_fp32.mul_(1 - lr * weight_decay)
                p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct)
                p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr)

                buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

            group["lr_old"] = lr

        return loss
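

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of fairseq): fit y = 2x
    # with the raw NAG optimizer to exercise the step()/closure API. The names
    # below (w, xs, ys, closure) exist only for this example.
    w = torch.zeros(1, requires_grad=True)
    optimizer = NAG([w], lr=0.1, momentum=0.9, weight_decay=0.0)
    xs = torch.linspace(-1.0, 1.0, 16)
    ys = 2.0 * xs

    def closure():
        # step() accepts an optional closure that recomputes the loss and gradients.
        optimizer.zero_grad()
        loss = ((w * xs - ys) ** 2).mean()
        loss.backward()
        return loss

    for _ in range(100):
        loss = optimizer.step(closure)
    print(f"learned w = {w.item():.3f} (target 2.0), final loss = {loss.item():.3e}")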
|
|