"""
This file implements different learning rate schedulers
"""
import torch


def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer):
    """Get the learning rate scheduler according to the config.

    Args:
        lr_decay: Flag enabling learning rate decay. Evaluated by
            truthiness, so ``False``, ``None`` and ``0`` all disable decay.
        lr_decay_cfg: Mapping describing the decay policy, or ``None``.
            Must contain the key ``"policy"``; the ``"exp"`` policy
            additionally reads ``"gamma"``.
        optimizer: The optimizer instance handed to the scheduler.

    Returns:
        ``None`` when decay is disabled or no config is given, otherwise a
        ``torch.optim.lr_scheduler.ExponentialLR`` built with
        ``lr_decay_cfg["gamma"]``.

    Raises:
        ValueError: If decay is enabled but the policy is not supported.
        KeyError: If a required key is missing from ``lr_decay_cfg``.
    """
    # Decay disabled or not configured => no scheduler.  Truthiness is used
    # instead of ``== False`` so that e.g. ``lr_decay=None`` means "off"
    # rather than falling through to the unknown-policy error below.
    if not lr_decay or lr_decay_cfg is None:
        return None

    policy = lr_decay_cfg["policy"]

    # Exponential decay
    if policy == "exp":
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=lr_decay_cfg["gamma"]
        )

    # Unknown policy
    raise ValueError(
        f"[Error] Unknown learning rate decay policy: {policy!r}"
    )