import math
from collections.abc import Iterable
from math import cos, floor, log, pi
import skorch
class CyclicCosineDecayLR(skorch.callbacks.Callback):
    """Cyclic cosine decay learning rate schedule with restarts, implemented as a skorch callback."""

def __init__(
self,
optimizer,
init_interval,
min_lr,
len_param_groups,
base_lrs,
restart_multiplier=None,
restart_interval=None,
restart_lr=None,
last_epoch=-1,
):
"""
Initialize new CyclicCosineDecayLR object
:param optimizer: (Optimizer) - Wrapped optimizer.
:param init_interval: (int) - Initial decay cycle interval.
:param min_lr: (float or iterable of floats) - Minimal learning rate.
:param restart_multiplier: (float) - Multiplication coefficient for increasing cycle intervals,
if this parameter is set, restart_interval must be None.
:param restart_interval: (int) - Restart interval for fixed cycle intervals,
if this parameter is set, restart_multiplier must be None.
:param restart_lr: (float or iterable of floats) - Optional, the learning rate at cycle restarts,
if not provided, initial learning rate will be used.
:param last_epoch: (int) - Last epoch.
"""
self.len_param_groups = len_param_groups
if restart_interval is not None and restart_multiplier is not None:
raise ValueError(
"You can either set restart_interval or restart_multiplier but not both"
)
if isinstance(min_lr, Iterable) and len(min_lr) != self.len_param_groups:
raise ValueError(
"Expected len(min_lr) to be equal to len(optimizer.param_groups), "
"got {} and {} instead".format(len(min_lr), self.len_param_groups)
)
        if (
            isinstance(restart_lr, Iterable)
            and len(restart_lr) != self.len_param_groups
        ):
raise ValueError(
"Expected len(restart_lr) to be equal to len(optimizer.param_groups), "
"got {} and {} instead".format(len(restart_lr), self.len_param_groups)
)
if init_interval <= 0:
raise ValueError(
"init_interval must be a positive number, got {} instead".format(
init_interval
)
)
group_num = self.len_param_groups
self._init_interval = init_interval
self._min_lr = [min_lr] * group_num if isinstance(min_lr, float) else min_lr
self._restart_lr = (
[restart_lr] * group_num if isinstance(restart_lr, float) else restart_lr
)
self._restart_interval = restart_interval
self._restart_multiplier = restart_multiplier
self.last_epoch = last_epoch
self.base_lrs = base_lrs
super().__init__()
    def on_batch_end(self, net, training, **kwargs):
        """Return the decayed learning rates for the current position in the cycle."""
        if self.last_epoch < self._init_interval:
            # Still inside the initial decay interval.
            return self._calc(self.last_epoch, self._init_interval, self.base_lrs)
        elif self._restart_interval is not None:
            # Fixed-length restart cycles.
            cycle_epoch = (
                self.last_epoch - self._init_interval
            ) % self._restart_interval
            lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
            return self._calc(cycle_epoch, self._restart_interval, lrs)
        elif self._restart_multiplier is not None:
            # Cycles whose length grows geometrically by restart_multiplier.
            n = self._get_n(self.last_epoch)
            sn_prev = self._partial_sum(n)
            cycle_epoch = self.last_epoch - sn_prev
            interval = self._init_interval * self._restart_multiplier ** n
            lrs = self.base_lrs if self._restart_lr is None else self._restart_lr
            return self._calc(cycle_epoch, interval, lrs)
        else:
            return self._min_lr
    def _calc(self, t, T, lrs):
        # Cosine interpolation from each base lr down to its min_lr over a cycle of length T.
        return [
            min_lr + (lr - min_lr) * (1 + cos(pi * t / T)) / 2
            for lr, min_lr in zip(lrs, self._min_lr)
        ]

    def _get_n(self, epoch):
        # Index of the geometric cycle that `epoch` falls into.
        a = self._init_interval
        r = self._restart_multiplier
        _t = 1 - (1 - r) * epoch / a
        return floor(log(_t, r))

    def _partial_sum(self, n):
        # Total number of epochs covered by the first n geometric cycles.
        a = self._init_interval
        r = self._restart_multiplier
        return a * (1 - r ** n) / (1 - r)
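
# A minimal usage sketch: inspect the schedule this callback produces. The values
# below (10-epoch cycles, base lr 1e-3, floor 1e-5) are illustrative assumptions,
# not taken from any particular experiment; `optimizer` is accepted but not used
# by the callback, so None is passed here.
def _demo_cyclic_cosine_decay():
    scheduler = CyclicCosineDecayLR(
        optimizer=None,
        init_interval=10,
        min_lr=1e-5,
        len_param_groups=1,
        base_lrs=[1e-3],
        restart_interval=10,
    )
    schedule = []
    for epoch in range(30):
        # The callback does not advance last_epoch itself, so drive it manually.
        scheduler.last_epoch = epoch
        schedule.append(scheduler.on_batch_end(net=None, training=True)[0])
    return schedule  # decays from 1e-3 toward 1e-5 within each 10-epoch cycle, then restarts
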
class LearningRateDecayCallback(skorch.callbacks.Callback):
    def __init__(self, config):
        """
        :param config: Configuration object exposing lr_warmup_start, lr_warmup_end,
            learning_rate, warmup_epoch, final_epoch and batch_per_epoch.
        """
super().__init__()
self.lr_warmup_end = config.lr_warmup_end
self.lr_warmup_start = config.lr_warmup_start
self.learning_rate = config.learning_rate
self.warmup_batch = config.warmup_epoch * config.batch_per_epoch
self.final_batch = config.final_epoch * config.batch_per_epoch
self.batch_idx = 0
def on_batch_end(self, net, training, **kwargs):
"""
:param trainer:
:type trainer:
:param pl_module:
:type pl_module:
:param batch:
:type batch:
:param batch_idx:
:type batch_idx:
:param dataloader_idx:
:type dataloader_idx:
"""
# to avoid updating after validation batch
if training:
if self.batch_idx < self.warmup_batch:
                # Linear warmup from lr_warmup_start to lr_warmup_end over warmup_batch
                # batches (0.1 to 1 in the paper).
lr_mult = float(self.batch_idx) / float(max(1, self.warmup_batch))
lr = self.lr_warmup_start + lr_mult * (
self.lr_warmup_end - self.lr_warmup_start
)
else:
                # Cosine decay from lr_warmup_end down to learning_rate (the floor)
progress = float(self.batch_idx - self.warmup_batch) / float(
max(1, self.final_batch - self.warmup_batch)
)
lr = max(
self.learning_rate
+ 0.5
* (1.0 + math.cos(math.pi * progress))
* (self.lr_warmup_end - self.learning_rate),
self.learning_rate,
)
net.lr = lr
# for param_group in net.optimizer.param_groups:
# param_group["lr"] = lr
self.batch_idx += 1
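
# A minimal sketch of the configuration object LearningRateDecayCallback expects.
# Only the attribute names are taken from the callback above; the values are
# illustrative assumptions.
class _ExampleLRDecayConfig:
    lr_warmup_start = 1e-4   # lr at the very first training batch
    lr_warmup_end = 1e-3     # peak lr reached at the end of warmup
    learning_rate = 1e-5     # final lr, which also acts as the decay floor
    warmup_epoch = 1         # number of epochs spent in linear warmup
    final_epoch = 50         # epoch by which the cosine decay reaches the floor
    batch_per_epoch = 100    # number of training batches per epoch


# Example: lr_decay_cb = LearningRateDecayCallback(_ExampleLRDecayConfig())
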
class LRAnnealing(skorch.callbacks.Callback):
    """Quarter the learning rate whenever an epoch ends without a new best validation loss."""

    def on_epoch_end(self, net, **kwargs):
if not net.history[-1]["valid_loss_best"]:
net.lr /= 4.0
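

# A minimal wiring sketch showing how these callbacks might be attached to a skorch
# net. The tiny module and all hyperparameters below are illustrative assumptions,
# not part of this file's API.
def _example_net_with_callbacks():
    import torch
    from skorch import NeuralNetClassifier

    class _TinyClassifier(torch.nn.Module):
        def __init__(self, n_features=20, n_classes=2):
            super().__init__()
            self.fc = torch.nn.Linear(n_features, n_classes)

        def forward(self, x):
            # Raw logits, paired with CrossEntropyLoss below.
            return self.fc(x)

    return NeuralNetClassifier(
        _TinyClassifier,
        criterion=torch.nn.CrossEntropyLoss,
        lr=1e-3,
        max_epochs=50,
        callbacks=[
            LearningRateDecayCallback(_ExampleLRDecayConfig()),
            LRAnnealing(),  # quarters net.lr whenever validation loss stops improving
        ],
    )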