import math
from collections.abc import Iterable
from math import cos, floor, log, pi

import skorch


class CyclicCosineDecayLR(skorch.callbacks.Callback):
    def __init__(
        self,
        optimizer,
        init_interval,
        min_lr,
        len_param_groups,
        base_lrs,
        restart_multiplier=None,
        restart_interval=None,
        restart_lr=None,
        last_epoch=-1,
    ):
        """
        Initialize new CyclicCosineDecayLR object
        :param optimizer: (Optimizer) - Wrapped optimizer.
        :param init_interval: (int) - Initial decay cycle interval.
        :param min_lr: (float or iterable of floats) - Minimal learning rate.
        :param restart_multiplier: (float) - Multiplication coefficient for increasing cycle intervals,
            if this parameter is set, restart_interval must be None.
        :param restart_interval: (int) - Restart interval for fixed cycle intervals,
            if this parameter is set, restart_multiplier must be None.
        :param restart_lr: (float or iterable of floats) - Optional, the learning rate at cycle restarts,
            if not provided, initial learning rate will be used.
        :param last_epoch: (int) - Last epoch.
        """
        self.len_param_groups = len_param_groups
        if restart_interval is not None and restart_multiplier is not None:
            raise ValueError(
                "You can either set restart_interval or restart_multiplier but not both"
            )

        if isinstance(min_lr, Iterable) and len(min_lr) != self.len_param_groups:
            raise ValueError(
                "Expected len(min_lr) to be equal to len(optimizer.param_groups), "
                "got {} and {} instead".format(len(min_lr), self.len_param_groups)
            )

        if isinstance(restart_lr, Iterable) and len(restart_lr) != self.len_param_groups:
            raise ValueError(
                "Expected len(restart_lr) to be equal to len(optimizer.param_groups), "
                "got {} and {} instead".format(len(restart_lr), self.len_param_groups)
            )

        if init_interval <= 0:
            raise ValueError(
                "init_interval must be a positive number, got {} instead".format(
                    init_interval
                )
            )

        group_num = self.len_param_groups
        self._init_interval = init_interval
        self._min_lr = [min_lr] * group_num if isinstance(min_lr, float) else min_lr
        self._restart_lr = (
            [restart_lr] * group_num if isinstance(restart_lr, float) else restart_lr
        )
        self._restart_interval = restart_interval
        self._restart_multiplier = restart_multiplier
        self.last_epoch = last_epoch
        self.base_lrs = base_lrs
        super().__init__()

    def on_batch_end(self, net, training, **kwargs):
        # skorch ignores the return value of callback hooks, so the decayed
        # learning rates are written to the optimizer's param groups directly;
        # the schedule advances by one step per training batch.
        if not training:
            return
        self.last_epoch += 1

        if self.last_epoch < self._init_interval:
            # Still inside the initial decay cycle.
            lrs = self._calc(self.last_epoch, self._init_interval, self.base_lrs)
        elif self._restart_interval is not None:
            # Fixed-length restart cycles after the initial interval.
            cycle_epoch = (
                self.last_epoch - self._init_interval
            ) % self._restart_interval
            lrs = self._calc(
                cycle_epoch,
                self._restart_interval,
                self.base_lrs if self._restart_lr is None else self._restart_lr,
            )
        elif self._restart_multiplier is not None:
            # Geometrically growing restart cycles.
            n = self._get_n(self.last_epoch)
            cycle_epoch = self.last_epoch - self._partial_sum(n)
            interval = self._init_interval * self._restart_multiplier ** n
            lrs = self._calc(
                cycle_epoch,
                interval,
                self.base_lrs if self._restart_lr is None else self._restart_lr,
            )
        else:
            lrs = self._min_lr

        for param_group, lr in zip(net.optimizer_.param_groups, lrs):
            param_group["lr"] = lr

    def _calc(self, t, T, lrs):
        # Cosine annealing within a cycle of length T:
        # lr(t) = min_lr + (lr - min_lr) * (1 + cos(pi * t / T)) / 2
        return [
            min_lr + (lr - min_lr) * (1 + cos(pi * t / T)) / 2
            for lr, min_lr in zip(lrs, self._min_lr)
        ]

    def _get_n(self, epoch):
        # Index of the cycle that contains `epoch` when cycle lengths grow
        # geometrically (a, a*r, a*r^2, ...): the largest n with
        # _partial_sum(n) <= epoch.
        a = self._init_interval
        r = self._restart_multiplier
        _t = 1 - (1 - r) * epoch / a
        return floor(log(_t, r))

    def _partial_sum(self, n):
        # Total length of the first n cycles: a * (1 - r^n) / (1 - r)
        # (geometric series partial sum).
        a = self._init_interval
        r = self._restart_multiplier
        return a * (1 - r ** n) / (1 - r)

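# Usage sketch (illustrative, not part of the original module): wiring the callback
# into a skorch net. ``MyModule`` is a placeholder torch.nn.Module and the
# hyperparameter values are arbitrary; ``len_param_groups`` and ``base_lrs`` must
# match the optimizer that skorch builds (a single param group with lr=1e-3 here).
#
#     net = skorch.NeuralNetRegressor(
#         MyModule,
#         lr=1e-3,
#         callbacks=[
#             CyclicCosineDecayLR(
#                 optimizer=None,          # unused by the callback itself
#                 init_interval=100,       # length of the first decay cycle, in batches
#                 min_lr=1e-5,
#                 len_param_groups=1,
#                 base_lrs=[1e-3],
#                 restart_multiplier=1.5,  # each cycle 1.5x longer than the previous
#             )
#         ],
#     )
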

class LearningRateDecayCallback(skorch.callbacks.Callback):
    def __init__(self, config):
        """
        :param config: Configuration object exposing ``lr_warmup_start``,
            ``lr_warmup_end``, ``learning_rate``, ``warmup_epoch``,
            ``final_epoch`` and ``batch_per_epoch``.
        """
        super().__init__()
        self.lr_warmup_end = config.lr_warmup_end
        self.lr_warmup_start = config.lr_warmup_start
        self.learning_rate = config.learning_rate
        # The warmup and decay horizons are expressed in batches, not epochs.
        self.warmup_batch = config.warmup_epoch * config.batch_per_epoch
        self.final_batch = config.final_epoch * config.batch_per_epoch

        self.batch_idx = 0
    def on_batch_end(self, net, training, **kwargs):
        """Update the learning rate after each training batch.

        Applies a linear warmup from ``lr_warmup_start`` to ``lr_warmup_end`` over
        the first ``warmup_batch`` batches, then a cosine decay from
        ``lr_warmup_end`` towards ``learning_rate``.

        :param net: (NeuralNet) - The skorch net being trained.
        :param training: (bool) - Whether this was a training batch.
        """
        # Only advance the schedule on training batches, not validation batches.
        if training:
            if self.batch_idx < self.warmup_batch:
                # Linear warmup (from 0.1 to 1 in the paper) over warmup_batch batches.
                lr_mult = float(self.batch_idx) / float(max(1, self.warmup_batch))
                lr = self.lr_warmup_start + lr_mult * (
                    self.lr_warmup_end - self.lr_warmup_start
                )
            else:
                # Cosine decay from lr_warmup_end down to learning_rate.
                progress = float(self.batch_idx - self.warmup_batch) / float(
                    max(1, self.final_batch - self.warmup_batch)
                )
                lr = max(
                    self.learning_rate
                    + 0.5
                    * (1.0 + math.cos(math.pi * progress))
                    * (self.lr_warmup_end - self.learning_rate),
                    self.learning_rate,
                )
            net.lr = lr
            # Assigning to net.lr alone does not change an already-initialized
            # optimizer, so update its param groups as well.
            for param_group in net.optimizer_.param_groups:
                param_group["lr"] = lr

            self.batch_idx += 1


class LRAnnealing(skorch.callbacks.Callback):
    def on_epoch_end(self, net, **kwargs):
        # Quarter the learning rate whenever the epoch ends without a new best
        # validation loss, and keep the running optimizer in sync with it.
        if not net.history[-1]["valid_loss_best"]:
            net.lr /= 4.0
            for param_group in net.optimizer_.param_groups:
                param_group["lr"] = net.lr
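

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module):
    # trains a tiny regressor with LearningRateDecayCallback and LRAnnealing.
    # ``_TinyModule`` and ``_Config`` are hypothetical stand-ins; a real config
    # object only needs the attributes read in LearningRateDecayCallback.__init__.
    import numpy as np
    import torch

    class _TinyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 1)

        def forward(self, X):
            return self.lin(X)

    class _Config:
        lr_warmup_start = 1e-4  # lr at the first batch
        lr_warmup_end = 1e-3    # lr reached at the end of warmup
        learning_rate = 1e-4    # floor the cosine decay approaches
        warmup_epoch = 2
        final_epoch = 10
        batch_per_epoch = 4     # approximate number of training batches per epoch

    X = np.random.rand(64, 4).astype("float32")
    y = np.random.rand(64, 1).astype("float32")

    net = skorch.NeuralNetRegressor(
        _TinyModule,
        max_epochs=5,
        batch_size=16,
        lr=1e-3,
        callbacks=[LearningRateDecayCallback(_Config()), LRAnnealing()],
    )
    net.fit(X, y)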