"""
Helpers to train with 16-bit precision.
"""

import numpy as np
import torch as th
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from . import logger

# Default starting value for MixedPrecisionTrainer.lg_loss_scale, i.e. the
# log2 of the loss scale used by the fp16 path (loss is multiplied by
# 2**lg_loss_scale before backward).
INITIAL_LOG_LOSS_SCALE = 20.0


def convert_module_to_f16(l):
    """
    Cast a Conv1d/Conv2d/Conv3d layer's weight (and bias, if it has one) to
    float16 in place. Modules of any other type are left untouched.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    for tensor in (l.weight, l.bias):
        if tensor is not None:
            tensor.data = tensor.data.half()


def convert_module_to_f32(l):
    """
    Inverse of convert_module_to_f16(): cast a Conv1d/Conv2d/Conv3d layer's
    weight (and bias, if it has one) back to float32 in place. Other module
    types are ignored.
    """
    if not isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    for tensor in (l.weight, l.bias):
        if tensor is not None:
            tensor.data = tensor.data.float()


def make_master_params(param_groups_and_shapes):
    """
    Build one full-precision master parameter per ``(param_group, shape)``
    entry.

    Every tensor in ``param_group`` (a sequence of ``(name, tensor)`` pairs) is
    detached, cast to float32 and flattened into a single buffer, which is then
    viewed as ``shape`` and wrapped in an ``nn.Parameter`` requiring grad.

    Returns:
        list of nn.Parameter, in the same order as ``param_groups_and_shapes``.
    """

    def _build(param_group, shape):
        flat = _flatten_dense_tensors(
            [tensor.detach().float() for (_, tensor) in param_group])
        master = nn.Parameter(flat.view(shape))
        master.requires_grad = True
        return master

    return [_build(group, shape) for group, shape in param_groups_and_shapes]


def model_grads_to_master_grads(param_groups_and_shapes, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().

    For each ``(param_group, shape)`` entry, the gradients of the group's
    parameters are flattened into one buffer. A parameter whose ``.grad`` is
    None contributes zeros. The buffer is viewed as ``shape``, cast to the
    matching master parameter's dtype, and assigned to ``master_param.grad``.
    """
    for master_param, (param_group, shape) in zip(master_params,
                                                  param_groups_and_shapes):
        grads = [
            param.grad.data.detach()
            if param.grad is not None else th.zeros_like(param)
            for (_, param) in param_group
        ]
        # PyTorch refuses to assign a .grad whose dtype differs from its
        # tensor. Model grads may be float16 while master params are float32.
        # .to() returns the tensor unchanged when the dtype already matches.
        master_param.grad = _flatten_dense_tensors(grads).view(shape).to(
            master_param.dtype)


def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.

    Each flat master parameter is split back into per-parameter tensors
    shaped like the group's parameters. Each piece is copied in place into
    the corresponding model parameter; ``copy_`` casts to that parameter's
    dtype.
    """
    for master_param, (param_group, _) in zip(master_params,
                                              param_groups_and_shapes):
        # The group is read twice: once to unflatten, once to pair tensors
        # with pieces. Materialize it so a generator is not used up by the
        # first pass, which would silently copy nothing.
        params = [param for (_, param) in list(param_group)]
        pieces = _unflatten_dense_tensors(master_param.view(-1), params)
        for param, unflat_master_param in zip(params, pieces):
            param.detach().copy_(unflat_master_param)


def unflatten_master_params(param_group, master_param):
    """
    Split the flat ``master_param`` into tensors shaped like the tensors in
    ``param_group`` (a sequence of ``(name, tensor)`` pairs), in order.
    """
    templates = [tensor for _, tensor in param_group]
    return _unflatten_dense_tensors(master_param, templates)


def get_param_groups_and_shapes(named_model_params):
    """
    Partition ``(name, param)`` pairs into two groups by dimensionality.

    Returns:
        ``[(scalars_and_vectors, -1), (matrices, (1, -1))]``. The first group
        holds params with ``ndim <= 1`` and the second those with
        ``ndim > 1``. Each group keeps the input order, and the second element
        of each pair is the shape its flattened master copy is viewed as.
    """
    low_rank, high_rank = [], []
    for name, param in named_model_params:
        target = low_rank if param.ndim <= 1 else high_rank
        target.append((name, param))
    return [(low_rank, -1), (high_rank, (1, -1))]


def master_params_to_state_dict(model, param_groups_and_shapes, master_params,
                                use_fp16):
    """
    Return ``model.state_dict()`` with its parameter entries replaced by the
    corresponding master parameters.

    With ``use_fp16`` the flat master params are unflattened per group and
    stored under each parameter's name. Otherwise ``master_params[i]`` is
    stored under the i-th name of ``model.named_parameters()``.
    """
    state_dict = model.state_dict()

    if not use_fp16:
        for idx, (name, _value) in enumerate(model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[idx]
        return state_dict

    for master_param, (param_group, _) in zip(master_params,
                                              param_groups_and_shapes):
        flat = master_param.view(-1)
        pieces = unflatten_master_params(param_group, flat)
        for (name, _), piece in zip(param_group, pieces):
            assert name in state_dict
            state_dict[name] = piece
    return state_dict


def state_dict_to_master_params(model, state_dict, use_fp16):
    """
    Build master parameters from ``state_dict``, using the parameter names and
    order of ``model.named_parameters()``.

    With ``use_fp16`` the tensors are grouped and flattened into float32 master
    params (see make_master_params). Otherwise the state-dict tensors are
    returned directly, in parameter order.
    """
    names = [name for name, _ in model.named_parameters()]
    if not use_fp16:
        return [state_dict[name] for name in names]
    named_tensors = [(name, state_dict[name]) for name in names]
    return make_master_params(get_param_groups_and_shapes(named_tensors))


def zero_master_grads(master_params):
    """Release the gradient of every master parameter by setting ``.grad`` to None."""
    for master_param in master_params:
        master_param.grad = None


def zero_grad(model_params):
    """
    Zero every existing ``.grad`` in ``model_params`` in place, detaching it
    from any autograd graph first. Parameters without a gradient are skipped.
    """
    # Same idea as the in-place (non set_to_none) path of
    # torch.optim.Optimizer.zero_grad.
    for grad in (param.grad for param in model_params):
        if grad is None:
            continue
        grad.detach_()
        grad.zero_()


def param_grad_or_zeros(param):
    """
    Return a detached view of ``param.grad``. If the parameter has no gradient,
    return a zero tensor shaped like ``param``.
    """
    if param.grad is None:
        return th.zeros_like(param)
    return param.grad.data.detach()


class MixedPrecisionTrainer:
    """
    Drives ``backward()`` and optimizer steps for a model in one of three modes:

    * ``use_fp16``: manual loss scaling (loss * 2**lg_loss_scale) with float32
      master copies of the parameters (see make_master_params). The model is
      converted with ``model.convert_to_fp16()``.
    * ``use_amp``: loss scaling through ``torch.cuda.amp.GradScaler``. The
      parameters are optimized directly.
    * neither: plain full-precision training.

    ``use_fp16`` takes precedence over ``use_amp`` in ``backward`` and
    ``optimize``.
    """

    def __init__(self,
                 *,
                 model,
                 use_fp16=False,
                 use_amp=False,
                 fp16_scale_growth=1e-3,
                 initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
                 model_name='ddpm',
                 submodule_name='',
                 model_params=None):
        """
        Args:
            model: module being trained. It must provide ``convert_to_fp16()``
                when ``use_fp16`` is True.
            use_fp16: enable manual fp16 loss scaling with master params.
            use_amp: enable GradScaler-based AMP. ``self.scaler`` exists only
                when this is True.
            fp16_scale_growth: amount added to ``lg_loss_scale`` after every
                successful fp16 step.
            initial_lg_loss_scale: starting log2 loss scale for the fp16 path.
            model_name: label used in log messages.
            submodule_name: accepted for backward compatibility; unused here.
            model_params: parameters to optimize. Defaults to
                ``model.parameters()``. A list is stored as-is (not copied);
                any other iterable is materialized into a list. In fp16 mode
                the master copies are always built from
                ``model.named_parameters()``.
        """
        self.model_name = model_name
        self.model = model
        self.use_fp16 = use_fp16
        self.use_amp = use_amp
        if self.use_amp:
            # https://github.com/pytorch/pytorch/issues/40497#issuecomment-1262373602
            # https://github.com/pytorch/pytorch/issues/111739
            self.scaler = th.cuda.amp.GradScaler(enabled=use_amp, init_scale=2**15, growth_interval=100)
            logger.log(model_name, 'enables AMP to accelerate training')
        else:
            logger.log(model_name, 'not enables AMP to accelerate training')

        self.fp16_scale_growth = fp16_scale_growth

        if model_params is None:
            self.model_params = list(self.model.parameters())
        elif isinstance(model_params, list):
            self.model_params = model_params
        else:
            self.model_params = list(model_params)
        self.master_params = self.model_params
        self.param_groups_and_shapes = None
        self.lg_loss_scale = initial_lg_loss_scale

        if self.use_fp16:
            self.param_groups_and_shapes = get_param_groups_and_shapes(
                self.model.named_parameters())
            self.master_params = make_master_params(
                self.param_groups_and_shapes)
            self.model.convert_to_fp16()

    def zero_grad(self):
        """Zero the gradients of ``self.model_params`` in place."""
        zero_grad(self.model_params)

    def backward(self, loss: th.Tensor, disable_amp=False, **kwargs):
        """
        Backpropagate ``loss`` with the scaling required by the active mode.

        Args:
            loss: scalar loss tensor.
            disable_amp: in AMP mode, call ``loss.backward`` without the
                GradScaler. NOTE(review): ``optimize`` still calls
                ``scaler.unscale_`` in AMP mode, which divides the gradients by
                the current scale. Confirm that gradients produced this way are
                meant to be rescaled.
            **kwargs: forwarded to ``Tensor.backward`` (e.g. ``retain_graph=True``).
        """
        if self.use_fp16:
            loss_scale = 2**self.lg_loss_scale
            (loss * loss_scale).backward(**kwargs)
        elif self.use_amp and not disable_amp:
            self.scaler.scale(loss).backward(**kwargs)
        else:
            loss.backward(**kwargs)

    def optimize(self, opt: th.optim.Optimizer, clip_grad=True):
        """
        Take one optimizer step in the active mode.

        Args:
            opt: optimizer over ``self.master_params``.
            clip_grad: clip the gradient L2 norm to 5.0 before stepping
                (AMP and full-precision modes; the fp16 path does not clip).

        Returns:
            True if ``opt.step()`` was applied. False if the step was skipped
            because non-finite gradients were found.
        """
        if self.use_fp16:
            return self._optimize_fp16(opt)
        elif self.use_amp:
            return self._optimize_amp(opt, clip_grad)
        else:
            return self._optimize_normal(opt, clip_grad)

    def _optimize_fp16(self, opt: th.optim.Optimizer):
        """
        Step using the manual fp16 scheme. On overflow, lower the log2 loss
        scale by 1, drop the master grads and return False. Otherwise unscale
        the grads, step, copy the master params back into the model, grow the
        scale, and return True.
        """
        logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
        model_grads_to_master_grads(self.param_groups_and_shapes,
                                    self.master_params)
        grad_norm, param_norm = self._compute_norms(
            grad_scale=2**self.lg_loss_scale)
        if check_overflow(grad_norm):
            self.lg_loss_scale -= 1
            logger.log(
                f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            zero_master_grads(self.master_params)
            return False

        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        for p in self.master_params:
            p.grad.mul_(1.0 / (2**self.lg_loss_scale))
        opt.step()
        zero_master_grads(self.master_params)
        master_params_to_model_params(self.param_groups_and_shapes,
                                      self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth
        return True

    def _optimize_amp(self, opt: th.optim.Optimizer, clip_grad=False):
        """
        Step through the GradScaler. Returns False when the scaler skipped
        ``opt.step()`` because it found inf/NaN gradients.
        """
        # https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping
        # Unscale first so clipping and the logged norms see the true gradients.
        self.scaler.unscale_(opt)

        if clip_grad:
            self._clip_grad_norm()  # clip before compute_norm

        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)

        scale_before = self.scaler.get_scale()
        self.scaler.step(opt)
        self.scaler.update()
        # update() only lowers the scale when unscale_ found inf/NaN grads.
        # In that case step() did not call opt.step().
        return self.scaler.get_scale() >= scale_before

    def _optimize_normal(self, opt: th.optim.Optimizer, clip_grad: bool = False):
        """Plain full-precision step, optionally clipping gradients first."""
        if clip_grad:
            self._clip_grad_norm()  # clip before compute_norm

        grad_norm, param_norm = self._compute_norms()
        logger.logkv_mean("grad_norm", grad_norm)
        logger.logkv_mean("param_norm", param_norm)
        opt.step()
        return True

    def _clip_grad_norm(self):
        """Clip the global L2 norm of the master-param gradients to 5.0 in place."""
        th.nn.utils.clip_grad_norm_(  # type: ignore
            self.master_params,
            5.0,
            norm_type=2,
            error_if_nonfinite=False,
            foreach=True,
        )

    def _compute_norms(self, grad_scale=1.0):
        """
        Return ``(grad_norm, param_norm)``: global L2 norms over
        ``self.master_params``, computed in float32. ``grad_norm`` is divided
        by ``grad_scale``. Parameters without a gradient are skipped for
        ``grad_norm``.
        """
        grad_norm = 0.0
        param_norm = 0.0
        with th.no_grad():
            for p in self.master_params:
                param_norm += th.norm(p, p=2, dtype=th.float32).item()**2
                if p.grad is not None:
                    grad_norm += th.norm(p.grad, p=2,
                                         dtype=th.float32).item()**2
        return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)

    def master_params_to_state_dict(self, master_params, model=None):
        """Build a state dict for ``model`` (default: ``self.model``) from ``master_params``."""
        if model is None:
            model = self.model
        return master_params_to_state_dict(model, self.param_groups_and_shapes,
                                           master_params, self.use_fp16)

    def state_dict_to_master_params(self, state_dict, model=None):
        """Build master params for ``model`` (default: ``self.model``) from ``state_dict``."""
        if model is None:
            model = self.model
        return state_dict_to_master_params(model, state_dict, self.use_fp16)

    def state_dict_to_master_params_given_submodule_name(
            self, state_dict, submodule_name):
        """Like state_dict_to_master_params, for the submodule ``self.model.<submodule_name>``."""
        return state_dict_to_master_params(getattr(self.model, submodule_name),
                                           state_dict, self.use_fp16)


def check_overflow(value):
    """Return a truthy value when ``value`` is +inf, -inf or NaN."""
    if value != value:  # NaN is the only value that is unequal to itself
        return True
    return abs(value) == float("inf")