# Provenance (copied from a file-hosting page): uploaded by gartajackhats1985
# in commit 681fa96 ("Upload 1633 files"); original size 2.52 kB.
""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict
from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer
class Lookahead(Optimizer):
    """Wrap a base optimizer with Lookahead "slow weights".

    Every ``k`` calls to ``step()`` the slow weight of each parameter (that has a
    gradient) is moved a fraction ``alpha`` of the way toward the current fast
    weight, and the fast weight is then reset to the slow weight.

    Slow weights are stored in the base optimizer's per-parameter state under
    ``'lookahead_slow_buff'``. The per-group settings ``'lookahead_alpha'``,
    ``'lookahead_k'`` and ``'lookahead_step'`` are stored in the base optimizer's
    param groups. Because of this, ``state_dict()`` / ``load_state_dict()`` simply
    delegate to the base optimizer.

    Args:
        base_optimizer: Optimizer that performs the fast (inner) updates.
        alpha: Slow update rate, must be in ``[0, 1]``.
        k: Number of fast steps between slow updates, must be ``>= 1``.

    Raises:
        ValueError: If ``alpha`` or ``k`` is out of range.
    """

    _LOOKAHEAD_KEYS = ('lookahead_alpha', 'lookahead_k', 'lookahead_step')

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        # NOTE super().__init__() not called on purpose: param_groups and defaults
        # are shared with the base optimizer rather than rebuilt here.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        # Optimizer.__init__ would normally create these. PyTorch >= 2.0 reads them
        # from the step wrapper installed by zero_grad() and from
        # register_step_{pre,post}_hook(). An OrderedDict is used (not a plain dict)
        # because hook handles keep a weakref to the mapping.
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self._base_optimizer = base_optimizer
        # Alias (same list object) of the base optimizer's groups.
        self.param_groups = base_optimizer.param_groups
        # Same dict object as base_optimizer.defaults, so the update below is
        # visible on the base optimizer as well.
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        # Not used by this wrapper: slow buffers live in the base optimizer's state.
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        self._add_lookahead_defaults()

    def _add_lookahead_defaults(self):
        """Fill in any missing lookahead settings on the base optimizer's param groups."""
        for group in self._base_optimizer.param_groups:
            for name in self._LOOKAHEAD_KEYS:
                # setdefault: never overwrite values already present (e.g. a restored step count).
                group.setdefault(name, self.defaults[name])

    @torch.no_grad()
    def update_slow(self, group):
        """Perform one slow-weight update for every parameter in ``group`` that has a gradient.

        The slow buffer is created lazily as a copy of the current fast weight.
        As a result, the first slow update of a parameter leaves it unchanged.
        """
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self._base_optimizer.state[fast_p]
            if 'lookahead_slow_buff' not in param_state:
                param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
                param_state['lookahead_slow_buff'].copy_(fast_p)
            slow = param_state['lookahead_slow_buff']
            # slow <- slow + alpha * (fast - slow); then reset fast to slow.
            slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
            fast_p.copy_(slow)

    def sync_lookahead(self):
        """Force a slow-weight update on every param group, regardless of step count."""
        for group in self._base_optimizer.param_groups:
            self.update_slow(group)

    @torch.no_grad()
    def step(self, closure=None):
        """Run one base-optimizer step, then a slow update on groups that hit a multiple of k.

        Returns:
            Whatever the base optimizer's ``step(closure)`` returned.
        """
        loss = self._base_optimizer.step(closure)
        for group in self._base_optimizer.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the base optimizer's state dict (it already contains the lookahead data)."""
        return self._base_optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the base optimizer and restore lookahead group settings."""
        self._base_optimizer.load_state_dict(state_dict)
        # The base optimizer may replace its param_groups list on load; re-bind the alias.
        self.param_groups = self._base_optimizer.param_groups
        # A checkpoint saved from the bare base optimizer lacks the lookahead keys.
        # Without defaults, step() would raise KeyError.
        self._add_lookahead_defaults()