Vincentqyw
fix: roma
358ab8f
raw
history blame
1.63 kB
import torch
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
def build_optimizer(model, config):
    """Create the optimizer named by ``config.TRAINER.OPTIMIZER`` for ``model``.

    Supported names:
        ``"adam"``  -> ``torch.optim.Adam``, weight decay from ``config.TRAINER.ADAM_DECAY``
        ``"adamw"`` -> ``torch.optim.AdamW``, weight decay from ``config.TRAINER.ADAMW_DECAY``

    The learning rate is always ``config.TRAINER.TRUE_LR``.

    Args:
        model: object exposing ``parameters()``; its parameters are optimized.
        config: configuration object with a ``TRAINER`` section.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: if ``config.TRAINER.OPTIMIZER`` is not one of the names above.
    """
    trainer = config.TRAINER
    optim_name = trainer.OPTIMIZER
    learning_rate = trainer.TRUE_LR

    # name -> (optimizer class, TRAINER field holding its weight decay).
    # The decay field is fetched only for the selected optimizer, so the
    # config need not define the other one.
    choices = {
        "adam": (torch.optim.Adam, "ADAM_DECAY"),
        "adamw": (torch.optim.AdamW, "ADAMW_DECAY"),
    }
    if optim_name not in choices:
        raise ValueError(f"TRAINER.OPTIMIZER = {optim_name} is not a valid optimizer!")

    optim_cls, decay_field = choices[optim_name]
    return optim_cls(
        model.parameters(),
        lr=learning_rate,
        weight_decay=getattr(trainer, decay_field),
    )
def build_scheduler(config, optimizer):
    """Build the learning-rate scheduler selected by ``config.TRAINER.SCHEDULER``.

    Supported names and the ``config.TRAINER`` fields each one reads:
        ``"MultiStepLR"``     -> ``MultiStepLR(optimizer, MSLR_MILESTONES, gamma=MSLR_GAMMA)``
        ``"CosineAnnealing"`` -> ``CosineAnnealingLR(optimizer, COSA_TMAX)``
        ``"ExponentialLR"``   -> ``ExponentialLR(optimizer, ELR_GAMMA)``

    Args:
        config: configuration object with a ``TRAINER`` section.
        optimizer: the torch optimizer whose learning rate is scheduled.

    Returns:
        dict: ``{"interval": config.TRAINER.SCHEDULER_INTERVAL,
        "scheduler": <lr scheduler>}``. The interval is passed through
        unchanged from the config (e.g. ``'step'`` or ``'epoch'``). The
        optional ``"monitor"`` (e.g. ``'val_f1'``) and ``"frequency"`` keys
        of this layout are not set here.

    Raises:
        NotImplementedError: if ``config.TRAINER.SCHEDULER`` is not one of the
            supported names above.
    """
    scheduler = {"interval": config.TRAINER.SCHEDULER_INTERVAL}
    name = config.TRAINER.SCHEDULER

    if name == "MultiStepLR":
        lr_scheduler = MultiStepLR(
            optimizer,
            config.TRAINER.MSLR_MILESTONES,
            gamma=config.TRAINER.MSLR_GAMMA,
        )
    elif name == "CosineAnnealing":
        lr_scheduler = CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)
    elif name == "ExponentialLR":
        lr_scheduler = ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)
    else:
        # Same exception type as before (callers may catch it), but now say
        # which value was rejected, matching build_optimizer's error style.
        raise NotImplementedError(
            f"TRAINER.SCHEDULER = {name} is not a valid scheduler!"
        )

    scheduler["scheduler"] = lr_scheduler
    return scheduler