File size: 2,637 Bytes
1cc0836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
# reference: https://github.com/lifeiteng/vall-e
import math

import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import Adam


class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine decay, with an optional constant-LR lock.

    The underlying schedule, as a function of the step counter, is:

    * ``step < warmup_steps``: linear ramp from ``init_lr`` towards ``peak_lr``;
    * ``warmup_steps <= step < total_steps``: cosine decay from ``peak_lr``
      down to ``end_lr``;
    * ``step >= total_steps``: constant ``end_lr``.

    When ``fixed_lr`` is not ``None`` (the default is ``0.002``) the schedule
    is bypassed and every parameter group is pinned to ``fixed_lr``. This
    keeps the behaviour of the previous implementation, which hard-coded that
    value; pass ``fixed_lr=None`` to use the warmup/cosine schedule.

    The same learning rate is written to every parameter group of
    ``optimizer``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts with an
            ``"lr"`` entry (e.g. a ``torch.optim.Optimizer``).
        init_lr: Learning rate at step 0 of the warmup.
        peak_lr: Learning rate reached at the end of the warmup.
        end_lr: Learning rate at and after ``total_steps``.
        warmup_steps: Number of warmup steps.
        total_steps: Step at which the decay reaches ``end_lr``.
        current_step: Initial value of the step counter (for resuming).
        fixed_lr: Constant learning rate that overrides the schedule, or
            ``None`` to disable the override.
    """

    def __init__(
        self,
        optimizer,
        init_lr,
        peak_lr,
        end_lr,
        warmup_steps=10000,
        total_steps=400000,
        current_step=0,
        fixed_lr=0.002,
    ):
        # NOTE: the base-class __init__ is not called, matching the previous
        # implementation; all state used by this class is set up here.
        # NOTE(review): the stock initialiser appears to perform an initial
        # step(), which would shift the step counter - confirm before adding it.
        self.init_lr = init_lr
        self.peak_lr = peak_lr
        self.end_lr = end_lr
        self.optimizer = optimizer
        self.fixed_lr = fixed_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

        # Guard both slopes so degenerate settings (no warmup, or no decay
        # phase) do not raise ZeroDivisionError at construction time.
        decay_steps = total_steps - warmup_steps
        self._warmup_rate = (peak_lr - init_lr) / warmup_steps if warmup_steps else 0.0
        self._decay_rate = (end_lr - peak_lr) / decay_steps if decay_steps else 0.0

        self._current_step = current_step
        self.lr = init_lr
        self._last_lr = [self.lr]

    def _scheduled_lr(self, step):
        """Return the warmup/cosine learning rate for ``step`` (lock ignored)."""
        if step < self.warmup_steps:
            return self.init_lr + self._warmup_rate * step

        # ``>=`` (not ``>``) yields the same value at step == total_steps,
        # where the cosine coefficient is exactly 0, and it avoids a 0/0
        # below when total_steps == warmup_steps.
        if step >= self.total_steps:
            return self.end_lr

        # Here warmup_steps <= step < total_steps, so the ratio lies in [0, 1).
        decay_ratio = (step - self.warmup_steps) / (
            self.total_steps - self.warmup_steps
        )
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return self.end_lr + coeff * (self.peak_lr - self.end_lr)

    def set_lr(self, lr):
        """Write a learning rate into every parameter group of the optimizer.

        While the constant-LR lock is active (``fixed_lr`` is not ``None``)
        the ``lr`` argument is ignored and ``fixed_lr`` is applied instead.
        ``_last_lr`` is refreshed afterwards so it reflects the values that
        are now in the optimizer.
        """
        if self.fixed_lr is not None:
            lr = self.fixed_lr
        for group in self.optimizer.param_groups:
            group["lr"] = lr
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def step(self):
        """Apply the learning rate for the current step and advance the counter.

        Returns:
            The learning rate that was written to the optimizer.
        """
        if self.fixed_lr is not None:
            # Lock active: skip the schedule entirely. The configured
            # ``end_lr`` is left untouched so the schedule stays usable.
            lr = self.fixed_lr
        else:
            lr = self._scheduled_lr(self._current_step)

        self.set_lr(lr)
        self.lr = lr
        self._current_step += 1
        return self.lr


if __name__ == "__main__":
    # Manual smoke test: drive the scheduler past its configured total step
    # count, print the learning rate reported at every step, and plot it.
    num_steps = 25000

    model = nn.Linear(10, 10)
    optimizer = Adam(model.parameters(), lr=1e-4)
    scheduler = WarmupCosineLRSchedule(
        optimizer, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
    )

    lrs = []
    for _ in range(num_steps):
        scheduler.step()
        lrs.append(scheduler.lr)
        print(scheduler.lr)

    # A single plot call suffices; the previous code drew the same curve
    # twice (once against the implicit index, once against an explicit range).
    plt.plot(range(num_steps), lrs)
    plt.show()