""" Lookahead Optimizer Wrapper. | |
Implementation modified from: https://github.com/alphadl/lookahead.pytorch | |
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 | |
Hacked together by / Copyright 2020 Ross Wightman | |
""" | |
from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer
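
# Lookahead in brief: the wrapped (fast) base optimizer updates the model
# weights theta as usual, while a slow copy phi of each parameter is kept.
# Every k fast steps the slow weights are pulled toward the fast ones,
#   phi <- phi + alpha * (theta - phi),  then  theta <- phi,
# i.e. "k steps forward, 1 step back" (see the paper linked above).
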
class Lookahead(Optimizer):
    def __init__(self, base_optimizer, alpha=0.5, k=6):
        # NOTE super().__init__() not called on purpose
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self._base_optimizer = base_optimizer
        self.param_groups = base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self._base_optimizer.param_groups:
                group.setdefault(name, default)
    def update_slow(self, group):
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self._base_optimizer.state[fast_p]
            if 'lookahead_slow_buff' not in param_state:
                # initialize the slow buffer with the current fast weights
                param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
                param_state['lookahead_slow_buff'].copy_(fast_p)
            slow = param_state['lookahead_slow_buff']
            # slow += alpha * (fast - slow), then reset fast weights to slow
            slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
            fast_p.copy_(slow)
    def sync_lookahead(self):
        # force a slow-weight update on all groups, e.g. before evaluation
        for group in self._base_optimizer.param_groups:
            self.update_slow(group)
    def step(self, closure=None):
        # fast weights step via the wrapped optimizer
        loss = self._base_optimizer.step(closure)
        for group in self._base_optimizer.param_groups:
            group['lookahead_step'] += 1
            # every k fast steps, interpolate slow weights and sync fast ones
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss
    def state_dict(self):
        # slow buffers live in the base optimizer's per-param state, so
        # delegating saves/restores them along with everything else
        return self._base_optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self._base_optimizer.load_state_dict(state_dict)
        # re-alias param_groups, since load_state_dict rebuilds the groups
        self.param_groups = self._base_optimizer.param_groups
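
# A minimal usage sketch: wraps SGD with Lookahead and runs a few steps on
# random data. The model, data, and hyperparameters below are placeholders
# chosen for illustration only.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(10, 1)
    base = torch.optim.SGD(model.parameters(), lr=0.1)
    optimizer = Lookahead(base, alpha=0.5, k=6)

    for _ in range(12):  # 12 fast steps -> two slow-weight syncs with k=6
        inp = torch.randn(4, 10)
        loss = model(inp).pow(2).mean()
        # zero grads via the model; the wrapper skips Optimizer.__init__(),
        # so its inherited zero_grad() may rely on state that was never set
        model.zero_grad()
        loss.backward()
        optimizer.step()
    optimizer.sync_lookahead()  # final slow-weight sync, e.g. before eval
    print(f'final loss: {loss.item():.4f}')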