File size: 1,611 Bytes
88b0dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
@Date: 2021/09/14
@description: Learning-rate scheduler with linear warmup followed by polynomial decay.
"""


class WarmupScheduler:
    """Linear warmup followed by polynomial decay of the learning rate.

    For ``cur_step < warmup_step`` the learning rate is linearly interpolated
    from ``init_lr`` towards ``warmup_lr``.  From ``warmup_step`` onwards it
    decays as ``warmup_lr * max(1 - frac, 0) ** lr_pow``, where ``frac`` is
    the fraction of the way from ``warmup_step`` to ``max_step``.  The rate
    therefore reaches 0 at ``max_step`` (for ``lr_pow > 0``) and stays there.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dict-like
            groups whose ``'lr'`` entry is overwritten on every update
            (presumably a ``torch.optim`` optimizer), or ``None`` to only
            compute the value into ``running_lr``.
        lr_pow: Exponent of the polynomial decay.
        init_lr: Learning rate at step 0 of the warmup.
        warmup_lr: Peak learning rate, reached at the end of warmup.
        warmup_step: Number of warmup steps.  ``<= 0`` disables warmup.
        max_step: Step at which the decayed learning rate reaches 0.
        **kwargs: Accepted and ignored (presumably so a config dict with extra
            keys can be unpacked directly into the constructor).
    """

    def __init__(self, optimizer, lr_pow, init_lr, warmup_lr, warmup_step, max_step, **kwargs):
        self.lr_pow = lr_pow
        self.init_lr = init_lr
        # Most recently computed learning rate; equals init_lr until the
        # first call to step_update.
        self.running_lr = init_lr
        self.warmup_lr = warmup_lr
        self.warmup_step = warmup_step
        self.max_step = max_step
        self.optimizer = optimizer

    def step_update(self, cur_step):
        """Recompute the learning rate for ``cur_step`` and apply it.

        Every param group of ``self.optimizer`` (if any) receives the same
        learning rate.

        Args:
            cur_step: Current training step index.

        Returns:
            The new learning rate, which is also stored in ``running_lr``.
        """
        if self.warmup_step > 0 and cur_step < self.warmup_step:
            # Linear interpolation init_lr -> warmup_lr over the warmup steps.
            frac = cur_step / self.warmup_step
            self.running_lr = self.init_lr + (self.warmup_lr - self.init_lr) * frac
        else:
            decay_span = self.max_step - self.warmup_step
            if decay_span > 0:
                frac = (float(cur_step) - self.warmup_step) / decay_span
            else:
                # Zero-length (or misconfigured negative) decay phase: the
                # decay is already complete, matching the behaviour at or past
                # max_step, instead of dividing by zero.
                frac = 1.0
            # Clamp at 0 so steps beyond max_step keep the rate at its floor.
            self.running_lr = self.warmup_lr * max(1. - frac, 0.) ** self.lr_pow

        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.running_lr

        return self.running_lr


if __name__ == '__main__':
    # Visual sanity check: plot the learning-rate curve of a sample config.
    import matplotlib.pyplot as plt

    scheduler = WarmupScheduler(optimizer=None,
                                lr_pow=4,
                                init_lr=0.0000003,
                                warmup_lr=0.00003,
                                warmup_step=10000,
                                max_step=100000)

    # Sample every step from 0 up to (but excluding) max_step, so the plot
    # covers the full warmup and decay phases of this configuration.
    steps = list(range(scheduler.max_step))
    lrs = []
    for step in steps:
        scheduler.step_update(step)
        lrs.append(scheduler.running_lr)

    plt.plot(steps, lrs, linewidth=1)
    plt.xlabel('step')
    plt.ylabel('learning rate')
    plt.show()