File size: 664 Bytes
404d2af
 
 
 
 
 
 
8b973ee
404d2af
 
 
 
 
 
8b973ee
404d2af
 
 
 
 
8b973ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""
This file implements different learning rate schedulers
"""
import torch


def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer):
    """Build the learning rate scheduler described by the config.

    Args:
        lr_decay: Flag enabling learning rate decay. Any falsy value
            (``False``, ``None``, ``0``) disables it.
        lr_decay_cfg: Scheduler configuration, or ``None`` for no scheduler.
            It is read through ``lr_decay_cfg["policy"]`` and, for the
            ``"exp"`` policy, ``lr_decay_cfg["gamma"]``.
        optimizer: Optimizer handed to the scheduler constructor.

    Returns:
        ``None`` when decay is disabled or no config is given; otherwise a
        ``torch.optim.lr_scheduler.ExponentialLR`` for the ``"exp"`` policy.

    Raises:
        ValueError: If the configured policy is not a supported one.
    """
    # Use truthiness instead of `== False` so that lr_decay=None also means
    # "no scheduler" rather than falling through to the unknown-policy error.
    if not lr_decay or lr_decay_cfg is None:
        return None

    policy = lr_decay_cfg["policy"]

    # Exponential decay: the learning rate is scaled by `gamma` on every step.
    if policy == "exp":
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=lr_decay_cfg["gamma"]
        )

    # Unknown policy: report which one was requested to ease debugging.
    raise ValueError(
        "[Error] Unknown learning rate decay policy: {!r}".format(policy)
    )