import torch
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR


def build_optimizer(model, config):
    """Create the training optimizer described by ``config.TRAINER``.

    Settings read from ``config.TRAINER``:
        - ``OPTIMIZER``: ``"adam"`` or ``"adamw"``.
        - ``TRUE_LR``: learning rate passed as ``lr``.
        - ``ADAM_DECAY`` / ``ADAMW_DECAY``: ``weight_decay``, read only for the
          matching optimizer.

    Args:
        model: Object exposing ``parameters()``, e.g. a ``torch.nn.Module``.
        config: Configuration object with a ``TRAINER`` namespace.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` over ``model.parameters()``.

    Raises:
        ValueError: If ``OPTIMIZER`` is not one of the supported names.
    """
    trainer_cfg = config.TRAINER
    optimizer_name = trainer_cfg.OPTIMIZER
    learning_rate = trainer_cfg.TRUE_LR

    # Select the class and its decay setting first, then construct once.
    if optimizer_name == "adam":
        optimizer_cls, decay = torch.optim.Adam, trainer_cfg.ADAM_DECAY
    elif optimizer_name == "adamw":
        optimizer_cls, decay = torch.optim.AdamW, trainer_cfg.ADAMW_DECAY
    else:
        raise ValueError(
            f"TRAINER.OPTIMIZER = {optimizer_name} is not a valid optimizer!"
        )

    return optimizer_cls(model.parameters(), lr=learning_rate, weight_decay=decay)


def build_scheduler(config, optimizer):
    """Build a learning-rate scheduler config dict from ``config.TRAINER``.

    Supported ``config.TRAINER.SCHEDULER`` values and the settings each reads:
        - ``"MultiStepLR"``: ``MSLR_MILESTONES``, ``MSLR_GAMMA``
        - ``"CosineAnnealing"``: ``COSA_TMAX`` (passed as ``T_max``)
        - ``"ExponentialLR"``: ``ELR_GAMMA``

    Args:
        config: Configuration object with a ``TRAINER`` namespace.
        optimizer: The optimizer whose learning rate the scheduler drives.

    Returns:
        scheduler (dict):{
            'scheduler': lr_scheduler,
            'interval': 'step',  # or 'epoch'
            'monitor': 'val_f1', (optional)
            'frequency': x, (optional)
        }
        Only 'interval' (copied from ``TRAINER.SCHEDULER_INTERVAL``) and
        'scheduler' are populated here. NOTE(review): this layout looks like a
        PyTorch Lightning lr-scheduler config; confirm against the caller.

    Raises:
        NotImplementedError: If ``TRAINER.SCHEDULER`` is not a supported name.
    """
    trainer_cfg = config.TRAINER
    scheduler = {"interval": trainer_cfg.SCHEDULER_INTERVAL}
    name = trainer_cfg.SCHEDULER

    if name == "MultiStepLR":
        scheduler["scheduler"] = MultiStepLR(
            optimizer,
            trainer_cfg.MSLR_MILESTONES,
            gamma=trainer_cfg.MSLR_GAMMA,
        )
    elif name == "CosineAnnealing":
        scheduler["scheduler"] = CosineAnnealingLR(optimizer, trainer_cfg.COSA_TMAX)
    elif name == "ExponentialLR":
        scheduler["scheduler"] = ExponentialLR(optimizer, trainer_cfg.ELR_GAMMA)
    else:
        # Keep NotImplementedError for callers that catch it, but say what was wrong.
        raise NotImplementedError(
            f"TRAINER.SCHEDULER = {name} is not a supported scheduler! "
            "Expected one of 'MultiStepLR', 'CosineAnnealing', 'ExponentialLR'."
        )

    return scheduler