import math
import torch
import itertools as it
from torch.optim import Optimizer
from collections import defaultdict
class Lookahead(Optimizer):
    '''
    PyTorch implementation of the Lookahead optimizer wrapper.
    Lookahead Optimizer: https://arxiv.org/abs/1907.08610
    '''
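    # In the paper's terms: the inner ("fast") optimizer runs k steps, after which
    # the "slow" weights are pulled toward the fast weights by
    #     slow <- slow + alpha * (fast - slow)
    # and the fast weights are reset to the new slow weights (see step() below).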
    def __init__(self, optimizer, alpha=0.5, k=6, pullback_momentum="none"):
        '''
        :param optimizer: inner optimizer
        :param alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        :param k (int): number of lookahead steps
        :param pullback_momentum (str): change to inner optimizer momentum on interpolation
            update; one of "reset", "pullback", or "none"
        '''
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.optimizer = optimizer
        self.param_groups = self.optimizer.param_groups
        self.alpha = alpha
        self.k = k
        self.step_counter = 0
        assert pullback_momentum in ["reset", "pullback", "none"]
        self.pullback_momentum = pullback_momentum
        self.state = defaultdict(dict)

        # Cache the current optimizer parameters as the initial slow weights
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == "pullback":
                    # Initialize the cached momentum; otherwise the "pullback"
                    # branch in step() would hit a missing 'cached_mom' entry.
                    param_state['cached_mom'] = torch.zeros_like(p.data)
    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'alpha': self.alpha,
            'step_counter': self.step_counter,
            'k': self.k,
            'pullback_momentum': self.pullback_momentum
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    def state_dict(self):
        # Note: only the inner optimizer's state is serialized; the slow weights
        # live in self.state and are not included here.
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
    def _backup_and_load_cache(self):
        """Swap in the slow weights, which are useful for evaluation
        (they typically generalize better than the fast weights).
        """
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = torch.zeros_like(p.data)
                param_state['backup_params'].copy_(p.data)
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        """Restore the fast weights that were backed up before evaluation."""
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['backup_params'])
                del param_state['backup_params']
    def step(self, closure=None):
        """Performs a single Lookahead optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = self.optimizer.step(closure)
        self.step_counter += 1

        if self.step_counter >= self.k:
            self.step_counter = 0
            # Lookahead: interpolate toward the slow weights and cache the result
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    # slow <- slow + alpha * (fast - slow); crucial line
                    p.data.mul_(self.alpha).add_(param_state['cached_params'], alpha=1.0 - self.alpha)
                    param_state['cached_params'].copy_(p.data)

                    if self.pullback_momentum == "pullback":
                        # Assumes the inner optimizer keeps a 'momentum_buffer'
                        # in its per-parameter state (e.g. SGD with momentum).
                        internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                        self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.alpha).add_(
                            param_state["cached_mom"], alpha=1.0 - self.alpha)
                        param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
                    elif self.pullback_momentum == "reset":
                        self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

        return loss
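

# --- Usage sketch (illustrative, not part of the original module) ---
# Minimal example of wrapping SGD with Lookahead; the model, data, and
# hyperparameters below are placeholders chosen only for demonstration.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(10, 1)
    base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer = Lookahead(base_optimizer, alpha=0.5, k=6, pullback_momentum="pullback")
    criterion = nn.MSELoss()

    # Dummy training loop on random data; every k-th call to step() also
    # performs the slow-weight interpolation.
    for _ in range(20):
        inputs, targets = torch.randn(32, 10), torch.randn(32, 1)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

    # Evaluate on the slow weights, then restore the fast weights for training.
    optimizer._backup_and_load_cache()
    with torch.no_grad():
        eval_loss = criterion(model(torch.randn(32, 10)), torch.randn(32, 1))
    optimizer._clear_and_load_backup()
    print(f"eval loss on slow weights: {eval_loss.item():.4f}")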