""" @Date: 2021/09/14 @description: """ class WarmupScheduler: def __init__(self, optimizer, lr_pow, init_lr, warmup_lr, warmup_step, max_step, **kwargs): self.lr_pow = lr_pow self.init_lr = init_lr self.running_lr = init_lr self.warmup_lr = warmup_lr self.warmup_step = warmup_step self.max_step = max_step self.optimizer = optimizer def step_update(self, cur_step): if cur_step < self.warmup_step: frac = cur_step / self.warmup_step step = self.warmup_lr - self.init_lr self.running_lr = self.init_lr + step * frac else: frac = (float(cur_step) - self.warmup_step) / (self.max_step - self.warmup_step) scale_running_lr = max((1. - frac), 0.) ** self.lr_pow self.running_lr = self.warmup_lr * scale_running_lr if self.optimizer is not None: for param_group in self.optimizer.param_groups: param_group['lr'] = self.running_lr if __name__ == '__main__': import matplotlib.pyplot as plt scheduler = WarmupScheduler(optimizer=None, lr_pow=4, init_lr=0.0000003, warmup_lr=0.00003, warmup_step=10000, max_step=100000) x = [] y = [] for i in range(100000): if i == 10000-1: print() scheduler.step_update(i) x.append(i) y.append(scheduler.running_lr) plt.plot(x, y, linewidth=1) plt.show()