"""
This file implements different learning rate schedulers
"""
import torch


def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer):
    """Get the learning rate scheduler according to the config."""
    # No decay requested or no config provided => return None
    if not lr_decay or lr_decay_cfg is None:
        scheduler = None
    # Exponential decay: scale the learning rate by gamma at every step
    elif lr_decay_cfg["policy"] == "exp":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=lr_decay_cfg["gamma"]
        )
    # Unknown policy
    else:
        raise ValueError("[Error] Unknown learning rate decay policy!")
    return scheduler
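

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the model, optimizer, and
    # config values below are assumptions, not part of this repo's code).
    # It requests exponential decay and steps the scheduler once per epoch.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = get_lr_scheduler(
        lr_decay=True,
        lr_decay_cfg={"policy": "exp", "gamma": 0.95},
        optimizer=optimizer,
    )
    for epoch in range(3):
        # ... one training epoch would run here ...
        optimizer.step()  # step the optimizer before the scheduler
        if scheduler is not None:
            scheduler.step()
        print(f"epoch {epoch}: lr = {optimizer.param_groups[0]['lr']:.6f}")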