import math
from collections.abc import Iterable
from math import cos, floor, log, pi

import skorch


class CyclicCosineDecayLR(skorch.callbacks.Callback):
    def __init__(
        self,
        optimizer,
        init_interval,
        min_lr,
        len_param_groups,
        base_lrs,
        restart_multiplier=None,
        restart_interval=None,
        restart_lr=None,
        last_epoch=-1,
    ):
        """
        Initialize a new CyclicCosineDecayLR object.

        :param optimizer: (Optimizer) - Wrapped optimizer.
        :param init_interval: (int) - Initial decay cycle interval.
        :param min_lr: (float or iterable of floats) - Minimal learning rate.
        :param len_param_groups: (int) - Number of parameter groups of the wrapped optimizer.
        :param base_lrs: (iterable of floats) - Initial learning rates of the parameter groups.
        :param restart_multiplier: (float) - Multiplication coefficient for increasing cycle intervals;
            if this parameter is set, restart_interval must be None.
        :param restart_interval: (int) - Restart interval for fixed cycle intervals;
            if this parameter is set, restart_multiplier must be None.
        :param restart_lr: (float or iterable of floats) - Optional learning rate at cycle restarts;
            if not provided, the initial learning rate will be used.
        :param last_epoch: (int) - Last epoch.
        """
        self.len_param_groups = len_param_groups
        if restart_interval is not None and restart_multiplier is not None:
            raise ValueError(
                "You can either set restart_interval or restart_multiplier but not both"
            )

        if isinstance(min_lr, Iterable) and len(min_lr) != self.len_param_groups:
            raise ValueError(
                "Expected len(min_lr) to be equal to len(optimizer.param_groups), "
                "got {} and {} instead".format(len(min_lr), self.len_param_groups)
            )

        if (
            isinstance(restart_lr, Iterable)
            and len(restart_lr) != self.len_param_groups
        ):
            raise ValueError(
                "Expected len(restart_lr) to be equal to len(optimizer.param_groups), "
                "got {} and {} instead".format(len(restart_lr), self.len_param_groups)
            )

        if init_interval <= 0:
            raise ValueError(
                "init_interval must be a positive number, got {} instead".format(
                    init_interval
                )
            )

        group_num = self.len_param_groups
        self._init_interval = init_interval
        self._min_lr = [min_lr] * group_num if isinstance(min_lr, float) else min_lr
        self._restart_lr = (
            [restart_lr] * group_num if isinstance(restart_lr, float) else restart_lr
        )
        self._restart_interval = restart_interval
        self._restart_multiplier = restart_multiplier
        self.last_epoch = last_epoch
        self.base_lrs = base_lrs
        super().__init__()

    def on_batch_end(self, net, training, **kwargs):
        # Computes and returns the per-group learning rates for the current value
        # of self.last_epoch; the caller is responsible for applying them to the
        # optimizer and for advancing last_epoch.
        if self.last_epoch < self._init_interval:
            # Still inside the initial decay cycle.
            return self._calc(self.last_epoch, self._init_interval, self.base_lrs)
        elif self._restart_interval is not None:
            # Fixed-length restart cycles after the initial interval.
            cycle_epoch = (
                self.last_epoch - self._init_interval
            ) % self._restart_interval
            lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
            return self._calc(cycle_epoch, self._restart_interval, lrs)
        elif self._restart_multiplier is not None:
            # Geometrically growing cycles: cycle n has length
            # init_interval * restart_multiplier ** n.
            n = self._get_n(self.last_epoch)
            sn_prev = self._partial_sum(n)
            cycle_epoch = self.last_epoch - sn_prev
            interval = self._init_interval * self._restart_multiplier ** n
            lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
            return self._calc(cycle_epoch, interval, lrs)
        else:
            return self._min_lr

    def _calc(self, t, T, lrs):
        # Cosine interpolation from lr down to min_lr over a cycle of length T.
        return [
            min_lr + (lr - min_lr) * (1 + cos(pi * t / T)) / 2
            for lr, min_lr in zip(lrs, self._min_lr)
        ]

    def _get_n(self, epoch):
        # Index of the cycle containing `epoch`, obtained by inverting the
        # geometric partial sum a * (1 - r ** n) / (1 - r) <= epoch.
        a = self._init_interval
        r = self._restart_multiplier
        _t = 1 - (1 - r) * epoch / a
        return floor(log(_t, r))

    def _partial_sum(self, n):
        # Total length of the first n cycles (geometric series).
        a = self._init_interval
        r = self._restart_multiplier
        return a * (1 - r ** n) / (1 - r)


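# Illustrative sketch, not part of the original module: construct the callback above
# for a single parameter group and inspect the learning rates it would compute over a
# few epochs. The hyperparameter values are assumptions chosen only for illustration.
# Note that on_batch_end returns the rates based on self.last_epoch rather than
# applying them, so this sketch drives last_epoch manually.
def _example_cyclic_cosine_schedule():
    scheduler = CyclicCosineDecayLR(
        optimizer=None,  # the optimizer itself is not used by the computation
        init_interval=10,
        min_lr=1e-5,
        len_param_groups=1,
        base_lrs=[1e-3],
        restart_interval=5,
    )
    schedule = []
    for epoch in range(20):
        scheduler.last_epoch = epoch
        schedule.append(scheduler.on_batch_end(net=None, training=True))
    return schedule

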
class LearningRateDecayCallback(skorch.callbacks.Callback):
    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.lr_warmup_end = config.lr_warmup_end
        self.lr_warmup_start = config.lr_warmup_start
        self.learning_rate = config.learning_rate
        self.warmup_batch = config.warmup_epoch * config.batch_per_epoch
        self.final_batch = config.final_epoch * config.batch_per_epoch

        self.batch_idx = 0

    def on_batch_end(self, net, training, **kwargs):
        """
        Update the learning rate after every training batch.

        Linearly warms up from lr_warmup_start to lr_warmup_end over the first
        warmup_batch batches, then applies a cosine decay towards learning_rate
        until final_batch.

        :param net: (NeuralNet) - The skorch net being trained.
        :param training: (bool) - Whether the batch was a training batch.
        """
        if training:
            if self.batch_idx < self.warmup_batch:
                # Linear warmup phase.
                lr_mult = float(self.batch_idx) / float(max(1, self.warmup_batch))
                lr = self.lr_warmup_start + lr_mult * (
                    self.lr_warmup_end - self.lr_warmup_start
                )
            else:
                # Cosine decay phase, floored at learning_rate.
                progress = float(self.batch_idx - self.warmup_batch) / float(
                    max(1, self.final_batch - self.warmup_batch)
                )
                lr = max(
                    self.learning_rate
                    + 0.5
                    * (1.0 + math.cos(math.pi * progress))
                    * (self.lr_warmup_end - self.learning_rate),
                    self.learning_rate,
                )
            net.lr = lr

            self.batch_idx += 1


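# Illustrative sketch, not part of the original module: the `config` object only needs
# the attributes read in LearningRateDecayCallback.__init__, so a simple namespace is
# enough. The concrete numbers below are assumptions made up for illustration.
def _example_lr_decay_config():
    from types import SimpleNamespace

    config = SimpleNamespace(
        lr_warmup_start=1e-6,
        lr_warmup_end=1e-3,
        learning_rate=1e-4,
        warmup_epoch=2,
        final_epoch=50,
        batch_per_epoch=100,
    )
    return LearningRateDecayCallback(config)

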
class LRAnnealing(skorch.callbacks.Callback):
    def on_epoch_end(self, net, **kwargs):
        # Quarter the learning rate whenever the epoch did not improve the best
        # validation loss recorded in the net's history.
        if not net.history[-1]["valid_loss_best"]:
            net.lr /= 4.0


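# Illustrative sketch, under assumptions: `MyModule` is a hypothetical torch module and
# `config` a namespace like the one sketched above; only skorch's NeuralNetClassifier
# and its `callbacks` argument are relied on. Shows how the callbacks in this module
# might be attached to a net.
def _example_attach_callbacks(MyModule, config):
    from skorch import NeuralNetClassifier

    return NeuralNetClassifier(
        MyModule,
        lr=1e-3,
        max_epochs=10,
        callbacks=[
            ("lr_decay", LearningRateDecayCallback(config)),
            ("lr_annealing", LRAnnealing()),
        ],
    )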