File size: 852 Bytes
8bf4dee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torch.optim import Optimizer
from typing import Callable

Schedule = Callable[[float], float]


def linear_schedule(
    start_val: float, end_val: float, end_fraction: float = 1.0
) -> Schedule:
    """Build a schedule that moves linearly from ``start_val`` to ``end_val``.

    The returned callable takes a progress fraction and returns the
    scheduled value. While the progress is below ``end_fraction`` the value
    is interpolated between ``start_val`` (at progress 0) and ``end_val``
    (at progress ``end_fraction``); from ``end_fraction`` onward it stays at
    ``end_val``.

    Args:
        start_val: Value returned at progress 0.
        end_val: Value returned once progress reaches ``end_fraction``.
        end_fraction: Progress at which the ramp is complete.

    Returns:
        A callable mapping a progress fraction to the scheduled value.
    """

    def interpolate(progress_fraction: float) -> float:
        # Ramp is finished: hold the final value.
        if progress_fraction >= end_fraction:
            return end_val
        span = end_val - start_val
        scaled = span * progress_fraction / end_fraction
        return start_val + scaled

    return interpolate


def constant_schedule(val: float) -> Schedule:
    """Build a schedule that always returns ``val``.

    Args:
        val: The value the schedule yields for every progress fraction.

    Returns:
        A callable that accepts a progress fraction, ignores it, and
        returns ``val``.
    """

    def fixed(progress_fraction: float) -> float:
        # The progress argument is accepted only to satisfy the Schedule shape.
        return val

    return fixed


def schedule(name: str, start_val: float) -> Schedule:
    """Create a schedule by name.

    Args:
        name: ``"linear"`` selects a linear ramp from ``start_val`` down to
            0; any other value selects a constant schedule.
        start_val: Initial value of the linear ramp, or the fixed value of
            the constant schedule.

    Returns:
        The schedule callable built by ``linear_schedule`` or
        ``constant_schedule``.
    """
    return (
        linear_schedule(start_val, 0)
        if name == "linear"
        else constant_schedule(start_val)
    )


def update_learning_rate(optimizer: Optimizer, learning_rate: float) -> None:
    """Set the ``"lr"`` entry of every parameter group to ``learning_rate``.

    Args:
        optimizer: Object whose ``param_groups`` are updated in place.
        learning_rate: New value stored under the ``"lr"`` key of each group.
    """
    groups = optimizer.param_groups
    for group in groups:
        group["lr"] = learning_rate