"""
This file implements different learning rate schedulers
"""
import torch


def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer):
    """Get the learning rate scheduler according to the config."""
    # If no lr_decay is specified => return None
    if (not lr_decay) or (lr_decay_cfg is None):
        scheduler = None
    # Exponential decay: lr is multiplied by gamma at every scheduler.step()
    elif lr_decay and (lr_decay_cfg["policy"] == "exp"):
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=lr_decay_cfg["gamma"]
        )
    # Unknown policy
    else:
        raise ValueError("[Error] Unknown learning rate decay policy!")
    return scheduler
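

# Minimal usage sketch (illustrative assumptions: the Linear model, the
# learning rate, and the config values below are not part of this module).
# With the "exp" policy, each scheduler.step() multiplies the learning rate
# of every param group by gamma.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = get_lr_scheduler(
        lr_decay=True,
        lr_decay_cfg={"policy": "exp", "gamma": 0.9},
        optimizer=optimizer,
    )
    for epoch in range(3):
        # ... one training epoch would run here ...
        if scheduler is not None:
            scheduler.step()
        print(epoch, optimizer.param_groups[0]["lr"])  # 9e-4, 8.1e-4, 7.29e-4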