venite's picture
initial
f670afc
raw
history blame
2.01 kB
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch.optim.optimizer import Optimizer, required
class Madam(Optimizer):
    r"""MADAM optimizer implementation (https://arxiv.org/abs/2006.14560).

    Madam applies multiplicative updates. Each parameter element is scaled by
    ``exp(-lr * g_normed * sign(p))``. Here ``g_normed`` is the gradient
    divided by the square root of a bias-corrected exponential moving average
    (decay 0.999) of squared gradients.

    After every update, the parameter is clamped to ``[-max, max]``. ``max``
    is ``scale`` times the RMS of the parameter at the moment it is first
    stepped.

    Because the update is multiplicative, elements that are exactly zero stay
    zero. A parameter that is all zeros at its first step gets ``max == 0``.

    Args:
        params (iterable): Parameters to optimize, or dicts defining
            parameter groups.
        lr (float): Learning rate. Required.
        scale (float): Multiplier on the parameter's initial RMS; the product
            is the per-parameter clamp bound.
        g_bound (float, optional): If given, normalized gradients are clamped
            to ``[-g_bound, g_bound]`` before the update.
        momentum (float): Stored in the parameter-group defaults but not read
            by :meth:`step`. Kept for interface compatibility.
    """

    def __init__(self, params, lr=required, scale=3.0,
                 g_bound=None, momentum=0):
        self.scale = scale
        self.g_bound = g_bound
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        r"""Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It is called with gradient tracking
                enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure usually needs
            # autograd to recompute gradients.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    # The clamp bound is fixed from the parameter's RMS at
                    # the first step it receives a gradient.
                    state['max'] = self.scale * (p * p).mean().sqrt().item()
                    state['step'] = 0
                    state['exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1
                bias_correction = 1 - 0.999 ** state['step']

                # EMA of squared gradients, updated in place (same arithmetic
                # as 0.999 * v + 0.001 * g**2, without reallocating).
                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(0.999).add_(0.001 * grad ** 2)

                g_normed = grad / (exp_avg_sq / bias_correction).sqrt()
                # The denominator can be zero in two ways:
                # - grad == 0 gives 0/0 = NaN.
                # - grad**2 underflowing to zero gives +-inf.
                # An inf would either collapse the element to 0 via exp(-inf)
                # or produce NaN when multiplied by sign(0). Treat every
                # non-finite entry as "no update".
                g_normed[~torch.isfinite(g_normed)] = 0
                if self.g_bound is not None:
                    g_normed.clamp_(-self.g_bound, self.g_bound)

                p.mul_(torch.exp(-group['lr'] * g_normed * torch.sign(p)))
                p.clamp_(-state['max'], state['max'])

        return loss