import torch
from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl
from torch.optim.lr_scheduler import _LRScheduler
class BaseScheduler(object):
    """Base class for the step-wise scheduler logic.

    Args:
        optimizer (Optimizer): Optimizer instance to apply the lr schedule to.

    Subclass this and overwrite ``_get_lr`` to write your own step-wise scheduler.
    """
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.step_num = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def _get_lr(self):
        raise NotImplementedError

    def _set_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def step(self, metrics=None, epoch=None):
        """Update step-wise learning rate before optimizer.step."""
        self.step_num += 1
        lr = self._get_lr()
        self._set_lr(lr)

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def state_dict(self):
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
    def as_tensor(self, start=0, stop=100_000):
        """Returns the scheduler values from ``start`` to ``stop``."""
        # Start counting from ``start`` so the returned values actually cover
        # steps ``start``..``stop`` rather than the first ``stop - start`` steps.
        self.step_num = start
        lr_list = []
        for _ in range(start, stop):
            self.step_num += 1
            lr_list.append(self._get_lr())
        self.step_num = 0
        return torch.tensor(lr_list)
    def plot(self, start=0, stop=100_000):  # noqa
        """Plot the scheduler values from start to stop."""
        import matplotlib.pyplot as plt

        all_lr = self.as_tensor(start=start, stop=stop)
        plt.plot(all_lr.numpy())
        plt.show()
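
# Illustrative only: a minimal sketch of how ``BaseScheduler`` is meant to be
# subclassed. The class name and hyperparameters below are hypothetical, not
# part of this module.
class ExampleLinearWarmup(BaseScheduler):
    """Linearly ramps the lr from 0 to ``max_lr`` over ``warmup_steps`` steps."""

    def __init__(self, optimizer, max_lr=1e-3, warmup_steps=1000):
        super().__init__(optimizer)
        self.max_lr = max_lr
        self.warmup_steps = warmup_steps

    def _get_lr(self):
        # ``step`` increments ``self.step_num`` before calling ``_get_lr``.
        return self.max_lr * min(1.0, self.step_num / self.warmup_steps)
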
class DPTNetScheduler(BaseScheduler):
    """Dual-path transformer scheduler used in [1].

    Args:
        optimizer (Optimizer): Optimizer instance to apply the lr schedule to.
        steps_per_epoch (int): Number of steps per epoch.
        d_model (int): Number of units in the layer output.
        warmup_steps (int): Number of steps in the warmup stage of training.
        noam_scale (float): Linear increase rate in the first phase.
        exp_max (float): Max learning rate in the second phase.
        exp_base (float): Exponential decay base in the second phase.

    Schedule:
        This scheduler increases the learning rate linearly for the first
        ``warmup_steps``, then decays it by ``exp_base`` every two epochs.

    References
        [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct Context-
        Aware Modeling for End-to-End Monaural Speech Separation", Interspeech 2020.
    """
    def __init__(
        self,
        optimizer,
        steps_per_epoch,
        d_model,
        warmup_steps=4000,
        noam_scale=1.0,
        exp_max=0.0004,
        exp_base=0.98,
    ):
        super().__init__(optimizer)
        self.noam_scale = noam_scale
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.exp_max = exp_max
        self.exp_base = exp_base
        self.steps_per_epoch = steps_per_epoch
        self.epoch = 0
    def _get_lr(self):
        # ``step_num`` was already incremented by ``step``, so a new epoch
        # starts whenever it is a multiple of ``steps_per_epoch``.
        if self.step_num % self.steps_per_epoch == 0:
            self.epoch += 1

        if self.step_num > self.warmup_steps:
            # Exponential decay: multiply ``exp_max`` by ``exp_base`` every two epochs.
            lr = self.exp_max * (self.exp_base ** ((self.epoch - 1) // 2))
        else:
            # Noam warmup, as in "Attention Is All You Need".
            lr = (
                self.noam_scale
                * self.d_model ** (-0.5)
                * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))
            )
        return lr
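
# Illustrative only: a minimal usage sketch for ``DPTNetScheduler``. The model,
# data, and hyperparameter values below are hypothetical.
def _example_dptnet_usage():
    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = DPTNetScheduler(optimizer, steps_per_epoch=100, d_model=64)
    for _ in range(5):
        loss = model(torch.randn(8, 10)).pow(2).mean()
        scheduler.zero_grad()
        loss.backward()
        scheduler.step()  # updates the lr; called before ``optimizer.step``
        optimizer.step()
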
class CustomExponentialLR(_LRScheduler):
    """Exponential decay by ``gamma`` applied once every ``step_size`` epochs."""

    def __init__(self, optimizer, gamma, step_size, last_epoch=-1):
        self.gamma = gamma
        self.step_size = step_size
        # ``base_lrs`` is populated by the parent class from the optimizer's
        # param groups, so it does not need to be set here.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Keep the current lrs between decay points; multiplying the *current*
        # values (not ``base_lrs``) makes the decay compound over time.
        if self.last_epoch == 0 or self.last_epoch % self.step_size != 0:
            return [group["lr"] for group in self.optimizer.param_groups]
        return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
# Backward compat
_BaseScheduler = BaseScheduler
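
# Illustrative only: a minimal usage sketch for ``CustomExponentialLR``; the
# values below are hypothetical. With ``gamma=0.5`` and ``step_size=2``, the
# learning rate halves every two epochs.
def _example_custom_exponential_usage():
    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = CustomExponentialLR(optimizer, gamma=0.5, step_size=2)
    for _ in range(6):
        optimizer.step()  # one epoch of training would happen here
        scheduler.step()  # standard epoch-wise ``_LRScheduler`` step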