""" PyTorch MADGRAD optimizer | |
MADGRAD: https://arxiv.org/abs/2101.11075 | |
Code from: https://github.com/facebookresearch/madgrad | |
""" | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import torch.optim

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any

class MADGRAD(torch.optim.Optimizer):
    """
    MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
    Optimization.

    .. _MADGRAD: https://arxiv.org/abs/2101.11075

    MADGRAD is a general purpose optimizer that can be used in place of SGD or
    Adam, and may converge faster and generalize better. Currently GPU-only.
    Typically, the same learning rate schedule that is used for SGD or Adam may
    be used. The overall learning rate is not comparable to either method and
    should be determined by a hyper-parameter sweep.

    MADGRAD requires less weight decay than other methods, often as little as
    zero. Momentum values used for SGD or Adam's beta1 should work here also.

    On sparse problems both weight_decay and momentum should be set to 0.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate (default: 1e-2).
        momentum (float):
            Momentum value in the range [0,1) (default: 0.9).
        weight_decay (float):
            Weight decay, i.e. an L2 penalty (default: 0).
        eps (float):
            Term added to the denominator outside of the root operation to improve numerical stability (default: 1e-6).
        decoupled_decay (bool):
            Apply decoupled (AdamW-style) weight decay, scaling the weights by ``1 - lr * weight_decay``
            instead of adding the penalty to the gradient (default: False).
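
    Example (a minimal training-loop sketch; ``model``, ``data_loader`` and
    ``loss_fn`` are placeholders, not part of this module)::

        optimizer = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0)
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()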
""" | |

    def __init__(
            self,
            params: _params_t,
            lr: float = 1e-2,
            momentum: float = 0.9,
            weight_decay: float = 0,
            eps: float = 1e-6,
            decoupled_decay: bool = False,
    ):
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0,1)")
        if lr <= 0:
            raise ValueError(f"Learning rate {lr} must be positive")
        if weight_decay < 0:
            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
        if eps < 0:
            raise ValueError(f"Eps {eps} must be non-negative")

        defaults = dict(
            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self) -> bool:
        return False

    @property
    def supports_flat_params(self) -> bool:
        return True

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
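
        Example (a sketch of closure-based use; ``model``, ``inputs``, ``targets`` and
        ``criterion`` are placeholders, not part of this module)::

            def closure():
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                return loss

            loss = optimizer.step(closure)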
""" | |
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            eps = group['eps']
            lr = group['lr'] + eps
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            ck = 1 - momentum
for p in group["params"]: | |
if p.grad is None: | |
continue | |
grad = p.grad | |
if momentum != 0.0 and grad.is_sparse: | |
raise RuntimeError("momentum != 0 is not compatible with sparse gradients") | |
state = self.state[p] | |
if len(state) == 0: | |
state['step'] = 0 | |
state['grad_sum_sq'] = torch.zeros_like(p) | |
state['s'] = torch.zeros_like(p) | |
if momentum != 0: | |
state['x0'] = torch.clone(p).detach() | |
state['step'] += 1 | |
grad_sum_sq = state['grad_sum_sq'] | |
s = state['s'] | |
lamb = lr * math.sqrt(state['step']) | |
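
                # The remainder of the loop implements the MADGRAD update. In the
                # paper's notation, with k = state['step'], g_k = grad and ck = 1 - momentum:
                #   lamb_k = lr * sqrt(k)
                #   s_k    = s_{k-1}  + lamb_k * g_k        (dual averaging of gradients)
                #   nu_k   = nu_{k-1} + lamb_k * g_k**2     (grad_sum_sq)
                #   z_k    = x_0 - s_k / (nu_k**(1/3) + eps)
                #   x_k    = (1 - ck) * x_{k-1} + ck * z_k  (momentum as averaging of iterates)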

                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        if grad.is_sparse:
                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
                        grad.add_(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.add_(grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.mul_(1 - ck).add_(z, alpha=ck)

        return loss
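

# Minimal smoke-test sketch (not part of the upstream module): run a few MADGRAD
# steps on a toy least-squares problem to check that the update runs end to end.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)
    w = torch.nn.Parameter(torch.randn(10, device=device))
    target = torch.randn(10, device=device)
    optimizer = MADGRAD([w], lr=1e-1, momentum=0.9, weight_decay=0.0)
    for _ in range(100):
        optimizer.zero_grad()
        loss = ((w - target) ** 2).mean()
        loss.backward()
        optimizer.step()
    print(f"final loss: {loss.item():.6f}")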