""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2020 Ross Wightman
"""
import re
import torch
from torch import optim as optim
from utils.distributed import is_main_process
import logging
logger = logging.getLogger(__name__)
try:
    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
    has_apex = True
except ImportError:
    has_apex = False


def add_weight_decay(model, weight_decay, no_decay_list=(), filter_bias_and_bn=True):
    """Pair every trainable parameter of ``model`` with its weight decay.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs.
        weight_decay: decay value assigned to parameters that are not exempt.
        no_decay_list: container of parameter names that always get decay 0.
        filter_bias_and_bn: when truthy, parameters with a 1-D shape or a
            name ending in ``.bias`` also get decay 0.

    Returns:
        List of ``[name, param, weight_decay]`` entries, in
        ``named_parameters()`` order. Parameters whose ``requires_grad``
        is false are left out.
    """

    def _decay_for(name, param):
        # Shape/name exemption is only consulted when the filter is enabled;
        # the explicit no-decay list applies regardless.
        exempt_by_shape = filter_bias_and_bn and (
            len(param.shape) == 1 or name.endswith(".bias")
        )
        if exempt_by_shape or name in no_decay_list:
            return 0
        return weight_decay

    return [
        [name, param, _decay_for(name, param)]
        for name, param in model.named_parameters()
        if param.requires_grad  # skip frozen weights
    ]


def add_different_lr(named_param_tuples_or_model, diff_lr_names, diff_lr, default_lr):
    """Attach a learning rate to every ``[name, param, weight_decay]`` entry.

    A parameter gets ``diff_lr`` when any pattern in ``diff_lr_names``
    matches its name via ``re.search``; otherwise it gets ``default_lr``.

    Args:
        named_param_tuples_or_model: iterable of ``[name, param, weight_decay]``
            triples (the loop below unpacks exactly three items per entry).
        diff_lr_names: List(str), regular-expression patterns tested against
            each parameter name.
        diff_lr: float, learning rate for matching parameters.
        default_lr: float, learning rate for all other parameters.
    Returns:
        named_param_tuples_with_lr: List([name, param, weight_decay, lr])
    """
    logger.info(f"diff_names: {diff_lr_names}, diff_lr: {diff_lr}")

    with_lr = []
    for name, param, wd in named_param_tuples_or_model:
        # any() stops at the first matching pattern, like an explicit break.
        matched = any(
            re.search(pattern, name) is not None for pattern in diff_lr_names
        )
        if matched:
            logger.info(f"param {name} use different_lr: {diff_lr}")
            lr = diff_lr
        else:
            lr = default_lr
        with_lr.append([name, param, wd, lr])

    # Per-parameter summary is only emitted on the main process.
    if is_main_process():
        for name, _, wd, lr in with_lr:
            logger.info(f"param {name}: wd: {wd}, lr: {lr}")

    return with_lr


def create_optimizer_params_group(named_param_tuples_with_lr):
    """Bucket parameters into optimizer param groups by weight decay and lr.

    Args:
        named_param_tuples_with_lr: List([name, param, weight_decay, lr])

    Returns:
        List of dicts with keys ``params``, ``weight_decay`` and ``lr``; one
        dict per distinct (weight_decay, lr) combination. Groups are ordered
        by first-seen weight decay, then by first-seen lr within that decay.
    """
    # weight_decay -> lr -> [params]; nesting (rather than a flat tuple key)
    # keeps all groups of one weight decay adjacent in the output.
    buckets = {}
    for _name, param, wd, lr in named_param_tuples_with_lr:
        buckets.setdefault(wd, {}).setdefault(lr, []).append(param)

    param_groups = []
    for wd, per_lr in buckets.items():
        for lr, p in per_lr.items():
            param_groups.append({"params": p, "weight_decay": wd, "lr": lr})
            logger.info(f"optimizer -- lr={lr} wd={wd} len(p)={len(p)}")
    return param_groups


def create_optimizer(args, model, filter_bias_and_bn=True):
    """Build a ``torch.optim`` optimizer for ``model`` from settings on ``args``.

    Trainable parameters are split into param groups by weight decay and
    learning rate (via ``add_weight_decay``, ``add_different_lr`` and
    ``create_optimizer_params_group``) before the optimizer is constructed.

    Args:
        args: config object. Always read: ``opt`` (optimizer name),
            ``weight_decay`` and ``lr``. Read for the SGD variants:
            ``momentum``. Optional: ``opt_eps``, ``opt_betas``, ``opt_args``
            (a dict merged into the optimizer kwargs), and ``different_lr``
            (an object with ``enable``, ``module_names`` and ``lr``).
        model: object exposing ``named_parameters()``. If it defines
            ``no_weight_decay()``, the names it returns get weight decay 0.
        filter_bias_and_bn: forwarded to ``add_weight_decay``.

    Returns:
        The constructed optimizer. The token after the last ``_`` in
        ``args.opt`` (case-insensitive) selects it: ``sgd`` / ``nesterov``
        give SGD with Nesterov momentum, ``momentum`` gives plain-momentum
        SGD, ``adam`` gives Adam and ``adamw`` gives AdamW.

    Raises:
        ValueError: if the optimizer name is not one of the above.
        AssertionError: if the name contains ``fused`` while APEX or CUDA
            is unavailable.
    """
    opt_lower = args.opt.lower()
    if 'fused' in opt_lower:
        # NOTE(review): a 'fused' name only triggers this availability check;
        # the branches below still construct the regular torch.optim classes.
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    # Only the token after the last '_' selects the optimizer. Validate it
    # before doing any work so a typo fails fast, with a real message, and
    # regardless of whether asserts are stripped by ``python -O``.
    opt_name = opt_lower.split('_')[-1]
    sgd_variants = ('sgd', 'nesterov', 'momentum')
    if opt_name not in sgd_variants + ('adam', 'adamw'):
        raise ValueError(
            f"Invalid optimizer: {args.opt!r} "
            "(expected one of: sgd, nesterov, momentum, adam, adamw)"
        )

    weight_decay = args.weight_decay
    # check for modules that requires different lr
    if hasattr(args, "different_lr") and args.different_lr.enable:
        diff_lr_module_names = args.different_lr.module_names
        diff_lr = args.different_lr.lr
    else:
        diff_lr_module_names = []
        diff_lr = None

    no_decay = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay = model.no_weight_decay()
    named_param_tuples = add_weight_decay(
        model, weight_decay, no_decay, filter_bias_and_bn)
    named_param_tuples = add_different_lr(
        named_param_tuples, diff_lr_module_names, diff_lr, args.lr)
    parameters = create_optimizer_params_group(named_param_tuples)

    # Defaults for the optimizer; every param group built above carries its
    # own lr and weight_decay entries as well.
    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if getattr(args, 'opt_eps', None) is not None:
        opt_args['eps'] = args.opt_eps
    if getattr(args, 'opt_betas', None) is not None:
        opt_args['betas'] = args.opt_betas
    if getattr(args, 'opt_args', None) is not None:
        opt_args.update(args.opt_args)

    if opt_name in sgd_variants:
        # SGD accepts neither eps nor betas; drop them so Adam-style
        # settings left in the config do not raise a TypeError.
        opt_args.pop('eps', None)
        opt_args.pop('betas', None)
        optimizer = optim.SGD(
            parameters,
            momentum=args.momentum,
            nesterov=(opt_name != 'momentum'),
            **opt_args,
        )
    elif opt_name == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    else:  # 'adamw' -- the only remaining name after validation
        optimizer = optim.AdamW(parameters, **opt_args)
    return optimizer