File size: 679 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from ganime.trainer.warmup.cosine import WarmUpCosine


def create_warmup_scheduler(trainer_config, num_devices):
    """Build a ``WarmUpCosine`` learning-rate schedule from a trainer config.

    The total step count is derived from the training-set length, batch size
    and number of epochs, then divided by ``num_devices``. The warm-up length
    is a fraction (``warmup_epoch_percentage``) of that total. Both counts are
    truncated to integers.

    Args:
        trainer_config: Mapping that must provide ``"len_x_train"``,
            ``"batch_size"``, ``"n_epochs"``, ``"warmup_epoch_percentage"``,
            ``"lr_start"`` and ``"lr_max"``; a missing key raises ``KeyError``.
        num_devices: Number of devices the step count is divided across.

    Returns:
        A ``WarmUpCosine`` built with the configured start/max learning rates
        and the computed warm-up and total step counts.
    """
    cfg = trainer_config
    dataset_size, per_step_batch, epochs = (
        cfg["len_x_train"],
        cfg["batch_size"],
        cfg["n_epochs"],
    )

    # Computed as ((size / batch) * epochs) / devices, then truncated toward
    # zero, so any fractional trailing step is dropped.
    # NOTE(review): if the input pipeline keeps partial final batches, the
    # real number of optimizer steps may exceed this count — confirm.
    steps_overall = int(dataset_size / per_step_batch * epochs / num_devices)
    steps_warmup = int(steps_overall * cfg["warmup_epoch_percentage"])

    return WarmUpCosine(
        lr_start=cfg["lr_start"],
        lr_max=cfg["lr_max"],
        warmup_steps=steps_warmup,
        total_steps=steps_overall,
    )