"""
Loss function implementations.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.geometry import warp_perspective

from ..misc.geometry_utils import keypoints_to_grid, get_dist_mask, get_common_line_mask


def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
    """Get loss functions and either static or dynamic weighting."""
    # Get the global weighting policy
    w_policy = model_cfg.get("weighting_policy", "static")
    if not w_policy in ["static", "dynamic"]:
        raise ValueError("[Error] Not supported weighting policy.")

    loss_func = {}
    loss_weight = {}
    # Get junction loss function and weight
    w_junc, junc_loss_func = get_junction_loss_and_weight(model_cfg, w_policy)
    loss_func["junc_loss"] = junc_loss_func.to(device)
    loss_weight["w_junc"] = w_junc

    # Get heatmap loss function and weight
    w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight(
        model_cfg, w_policy, device
    )
    loss_func["heatmap_loss"] = heatmap_loss_func.to(device)
    loss_weight["w_heatmap"] = w_heatmap

    # [Optionally] get descriptor loss function and weight
    if model_cfg.get("descriptor_loss_func", None) is not None:
        w_descriptor, descriptor_loss_func = get_descriptor_loss_and_weight(
            model_cfg, w_policy
        )
        loss_func["descriptor_loss"] = descriptor_loss_func.to(device)
        loss_weight["w_desc"] = w_descriptor

    return loss_func, loss_weight


def get_junction_loss_and_weight(model_cfg, global_w_policy):
    """Get the junction loss function and weight."""
    junction_loss_cfg = model_cfg.get("junction_loss_cfg", {})

    # Get the junction loss weight
    w_policy = junction_loss_cfg.get("policy", global_w_policy)
    if w_policy == "static":
        w_junc = torch.tensor(model_cfg["w_junc"], dtype=torch.float32)
    elif w_policy == "dynamic":
        w_junc = nn.Parameter(
            torch.tensor(model_cfg["w_junc"], dtype=torch.float32), requires_grad=True
        )
    else:
        raise ValueError("[Error] Unknown weighting policy for junction loss weight.")

    # Get the junction loss function
    junc_loss_name = model_cfg.get("junction_loss_func", "superpoint")
    if junc_loss_name == "superpoint":
        junc_loss_func = JunctionDetectionLoss(
            model_cfg["grid_size"], model_cfg["keep_border_valid"]
        )
    else:
        raise ValueError("[Error] Not supported junction loss function.")

    return w_junc, junc_loss_func


def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device):
    """Get the heatmap loss function and weight."""
    heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {})

    # Get the heatmap loss weight
    w_policy = heatmap_loss_cfg.get("policy", global_w_policy)
    if w_policy == "static":
        w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32)
    elif w_policy == "dynamic":
        w_heatmap = nn.Parameter(
            torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32),
            requires_grad=True,
        )
    else:
        raise ValueError("[Error] Unknown weighting policy for junction loss weight.")

    # Get the corresponding heatmap loss based on the config
    heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy")
    if heatmap_loss_name == "cross_entropy":
        # Get the heatmap class weight (always static)
        heatmap_class_w = model_cfg.get("w_heatmap_class", 1.0)
        class_weight = (
            torch.tensor(np.array([1.0, heatmap_class_w])).to(torch.float).to(device)
        )
        heatmap_loss_func = HeatmapLoss(class_weight=class_weight)
    else:
        raise ValueError("[Error] Not supported heatmap loss function.")

    return w_heatmap, heatmap_loss_func


def get_descriptor_loss_and_weight(model_cfg, global_w_policy):
    """Get the descriptor loss function and weight."""
    descriptor_loss_cfg = model_cfg.get("descriptor_loss_cfg", {})

    # Get the descriptor loss weight
    w_policy = descriptor_loss_cfg.get("policy", global_w_policy)
    if w_policy == "static":
        w_descriptor = torch.tensor(model_cfg["w_desc"], dtype=torch.float32)
    elif w_policy == "dynamic":
        w_descriptor = nn.Parameter(
            torch.tensor(model_cfg["w_desc"], dtype=torch.float32), requires_grad=True
        )
    else:
        raise ValueError("[Error] Unknown weighting policy for descriptor loss weight.")

    # Get the descriptor loss function
    descriptor_loss_name = model_cfg.get("descriptor_loss_func", "regular_sampling")
    if descriptor_loss_name == "regular_sampling":
        descriptor_loss_func = TripletDescriptorLoss(
            descriptor_loss_cfg["grid_size"],
            descriptor_loss_cfg["dist_threshold"],
            descriptor_loss_cfg["margin"],
        )
    else:
        raise ValueError("[Error] Not supported descriptor loss function.")

    return w_descriptor, descriptor_loss_func


def space_to_depth(input_tensor, grid_size):
    """PixelUnshuffle for pytorch."""
    N, C, H, W = input_tensor.size()
    # (N, C, H//bs, bs, W//bs, bs)
    x = input_tensor.view(N, C, H // grid_size, grid_size, W // grid_size, grid_size)
    # (N, bs, bs, C, H//bs, W//bs)
    x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
    # (N, C*bs^2, H//bs, W//bs)
    x = x.view(N, C * (grid_size**2), H // grid_size, W // grid_size)
    return x


def junction_detection_loss(
    junction_map, junc_predictions, valid_mask=None, grid_size=8, keep_border=True
):
    """Junction detection loss."""
    # Convert junc_map to channel tensor
    junc_map = space_to_depth(junction_map, grid_size)
    map_shape = junc_map.shape[-2:]
    batch_size = junc_map.shape[0]
    dust_bin_label = (
        torch.ones([batch_size, 1, map_shape[0], map_shape[1]])
        .to(junc_map.device)
        .to(torch.int)
    )
    junc_map = torch.cat([junc_map * 2, dust_bin_label], dim=1)
    labels = torch.argmax(
        junc_map.to(torch.float)
        + torch.distributions.Uniform(0, 0.1)
        .sample(junc_map.shape)
        .to(junc_map.device),
        dim=1,
    )

    # Also convert the valid mask to channel tensor
    valid_mask = torch.ones(junction_map.shape) if valid_mask is None else valid_mask
    valid_mask = space_to_depth(valid_mask, grid_size)

    # Compute junction loss on the border patch or not
    if keep_border:
        valid_mask = (
            torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True) > 0
        )
    else:
        valid_mask = (
            torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True)
            >= grid_size * grid_size
        )

    # Compute the classification loss
    loss_func = nn.CrossEntropyLoss(reduction="none")
    # The loss still need NCHW format
    loss = loss_func(input=junc_predictions, target=labels.to(torch.long))

    # Weighted sum by the valid mask
    loss_ = torch.sum(
        loss * torch.squeeze(valid_mask.to(torch.float), dim=1), dim=[0, 1, 2]
    )
    loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float), dim=1))

    return loss_final


def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, class_weight=None):
    """Heatmap prediction loss."""
    # Compute the classification loss on each pixel
    if class_weight is None:
        loss_func = nn.CrossEntropyLoss(reduction="none")
    else:
        loss_func = nn.CrossEntropyLoss(class_weight, reduction="none")

    loss = loss_func(
        input=heatmap_pred, target=torch.squeeze(heatmap_gt.to(torch.long), dim=1)
    )

    # Weighted sum by the valid mask
    # Sum over H and W
    loss_spatial_sum = torch.sum(
        loss * torch.squeeze(valid_mask.to(torch.float), dim=1), dim=[1, 2]
    )
    valid_spatial_sum = torch.sum(
        torch.squeeze(valid_mask.to(torch.float32), dim=1), dim=[1, 2]
    )
    # Mean to single scalar over batch dimension
    loss = torch.sum(loss_spatial_sum) / torch.sum(valid_spatial_sum)

    return loss


class JunctionDetectionLoss(nn.Module):
    """Junction detection loss."""

    def __init__(self, grid_size, keep_border):
        super(JunctionDetectionLoss, self).__init__()
        self.grid_size = grid_size
        self.keep_border = keep_border

    def forward(self, prediction, target, valid_mask=None):
        return junction_detection_loss(
            target, prediction, valid_mask, self.grid_size, self.keep_border
        )


class HeatmapLoss(nn.Module):
    """Heatmap prediction loss."""

    def __init__(self, class_weight):
        super(HeatmapLoss, self).__init__()
        self.class_weight = class_weight

    def forward(self, prediction, target, valid_mask=None):
        return heatmap_loss(target, prediction, valid_mask, self.class_weight)


class RegularizationLoss(nn.Module):
    """Module for regularization loss."""

    def __init__(self):
        super(RegularizationLoss, self).__init__()
        self.name = "regularization_loss"
        self.loss_init = torch.zeros([])

    def forward(self, loss_weights):
        # Place it to the same device
        loss = self.loss_init.to(loss_weights["w_junc"].device)
        for _, val in loss_weights.items():
            if isinstance(val, nn.Parameter):
                loss += val

        return loss


def triplet_loss(
    desc_pred1,
    desc_pred2,
    points1,
    points2,
    line_indices,
    epoch,
    grid_size=8,
    dist_threshold=8,
    init_dist_threshold=64,
    margin=1,
):
    """Regular triplet loss for descriptor learning."""
    b_size, _, Hc, Wc = desc_pred1.size()
    img_size = (Hc * grid_size, Wc * grid_size)
    device = desc_pred1.device

    # Extract valid keypoints
    n_points = line_indices.size()[1]
    valid_points = line_indices.bool().flatten()
    n_correct_points = torch.sum(valid_points).item()
    if n_correct_points == 0:
        return torch.tensor(0.0, dtype=torch.float, device=device)

    # Check which keypoints are too close to be matched
    # dist_threshold is decreased at each epoch for easier training
    dist_threshold = max(dist_threshold, 2 * init_dist_threshold // (epoch + 1))
    dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold)

    # Additionally ban negative mining along the same line
    common_line_mask = get_common_line_mask(line_indices, valid_points)
    dist_mask = dist_mask | common_line_mask

    # Convert the keypoints to a grid suitable for interpolation
    grid1 = keypoints_to_grid(points1, img_size)
    grid2 = keypoints_to_grid(points2, img_size)

    # Extract the descriptors
    desc1 = (
        F.grid_sample(desc_pred1, grid1)
        .permute(0, 2, 3, 1)
        .reshape(b_size * n_points, -1)[valid_points]
    )
    desc1 = F.normalize(desc1, dim=1)
    desc2 = (
        F.grid_sample(desc_pred2, grid2)
        .permute(0, 2, 3, 1)
        .reshape(b_size * n_points, -1)[valid_points]
    )
    desc2 = F.normalize(desc2, dim=1)
    desc_dists = 2 - 2 * (desc1 @ desc2.t())

    # Positive distance loss
    pos_dist = torch.diag(desc_dists)

    # Negative distance loss
    max_dist = torch.tensor(4.0, dtype=torch.float, device=device)
    desc_dists[
        torch.arange(n_correct_points, dtype=torch.long),
        torch.arange(n_correct_points, dtype=torch.long),
    ] = max_dist
    desc_dists[dist_mask] = max_dist
    neg_dist = torch.min(
        torch.min(desc_dists, dim=1)[0], torch.min(desc_dists, dim=0)[0]
    )

    triplet_loss = F.relu(margin + pos_dist - neg_dist)
    return triplet_loss, grid1, grid2, valid_points


class TripletDescriptorLoss(nn.Module):
    """Triplet descriptor loss."""

    def __init__(self, grid_size, dist_threshold, margin):
        super(TripletDescriptorLoss, self).__init__()
        self.grid_size = grid_size
        self.init_dist_threshold = 64
        self.dist_threshold = dist_threshold
        self.margin = margin

    def forward(self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch):
        return self.descriptor_loss(
            desc_pred1, desc_pred2, points1, points2, line_indices, epoch
        )

    # The descriptor loss based on regularly sampled points along the lines
    def descriptor_loss(
        self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch
    ):
        return torch.mean(
            triplet_loss(
                desc_pred1,
                desc_pred2,
                points1,
                points2,
                line_indices,
                epoch,
                self.grid_size,
                self.dist_threshold,
                self.init_dist_threshold,
                self.margin,
            )[0]
        )


class TotalLoss(nn.Module):
    """Total loss summing junction, heatma, descriptor
    and regularization losses."""

    def __init__(self, loss_funcs, loss_weights, weighting_policy):
        super(TotalLoss, self).__init__()
        # Whether we need to compute the descriptor loss
        self.compute_descriptors = "descriptor_loss" in loss_funcs.keys()

        self.loss_funcs = loss_funcs
        self.loss_weights = loss_weights
        self.weighting_policy = weighting_policy

        # Always add regularization loss (it will return zero if not used)
        self.loss_funcs["reg_loss"] = RegularizationLoss().cuda()

    def forward(
        self, junc_pred, junc_target, heatmap_pred, heatmap_target, valid_mask=None
    ):
        """Detection only loss."""
        # Compute the junction loss
        junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target, valid_mask)
        # Compute the heatmap loss
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            heatmap_pred, heatmap_target, valid_mask
        )

        # Compute the total loss.
        if self.weighting_policy == "dynamic":
            reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
            total_loss = (
                junc_loss * torch.exp(-self.loss_weights["w_junc"])
                + heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"])
                + reg_loss
            )

            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "reg_loss": reg_loss,
                "w_junc": torch.exp(-self.loss_weights["w_junc"]).item(),
                "w_heatmap": torch.exp(-self.loss_weights["w_heatmap"]).item(),
            }

        elif self.weighting_policy == "static":
            total_loss = (
                junc_loss * self.loss_weights["w_junc"]
                + heatmap_loss * self.loss_weights["w_heatmap"]
            )

            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
            }

        else:
            raise ValueError("[Error] Unknown weighting policy.")

    def forward_descriptors(
        self,
        junc_map_pred1,
        junc_map_pred2,
        junc_map_target1,
        junc_map_target2,
        heatmap_pred1,
        heatmap_pred2,
        heatmap_target1,
        heatmap_target2,
        line_points1,
        line_points2,
        line_indices,
        desc_pred1,
        desc_pred2,
        epoch,
        valid_mask1=None,
        valid_mask2=None,
    ):
        """Loss for detection + description."""
        # Compute junction loss
        junc_loss = self.loss_funcs["junc_loss"](
            torch.cat([junc_map_pred1, junc_map_pred2], dim=0),
            torch.cat([junc_map_target1, junc_map_target2], dim=0),
            torch.cat([valid_mask1, valid_mask2], dim=0),
        )
        # Get junction loss weight (dynamic or not)
        if isinstance(self.loss_weights["w_junc"], nn.Parameter):
            w_junc = torch.exp(-self.loss_weights["w_junc"])
        else:
            w_junc = self.loss_weights["w_junc"]

        # Compute heatmap loss
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            torch.cat([heatmap_pred1, heatmap_pred2], dim=0),
            torch.cat([heatmap_target1, heatmap_target2], dim=0),
            torch.cat([valid_mask1, valid_mask2], dim=0),
        )
        # Get heatmap loss weight (dynamic or not)
        if isinstance(self.loss_weights["w_heatmap"], nn.Parameter):
            w_heatmap = torch.exp(-self.loss_weights["w_heatmap"])
        else:
            w_heatmap = self.loss_weights["w_heatmap"]

        # Compute the descriptor loss
        descriptor_loss = self.loss_funcs["descriptor_loss"](
            desc_pred1, desc_pred2, line_points1, line_points2, line_indices, epoch
        )
        # Get descriptor loss weight (dynamic or not)
        if isinstance(self.loss_weights["w_desc"], nn.Parameter):
            w_descriptor = torch.exp(-self.loss_weights["w_desc"])
        else:
            w_descriptor = self.loss_weights["w_desc"]

        # Update the total loss
        total_loss = (
            junc_loss * w_junc
            + heatmap_loss * w_heatmap
            + descriptor_loss * w_descriptor
        )
        outputs = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "w_junc": w_junc.item() if isinstance(w_junc, nn.Parameter) else w_junc,
            "w_heatmap": w_heatmap.item()
            if isinstance(w_heatmap, nn.Parameter)
            else w_heatmap,
            "descriptor_loss": descriptor_loss,
            "w_desc": w_descriptor.item()
            if isinstance(w_descriptor, nn.Parameter)
            else w_descriptor,
        }

        # Compute the regularization loss
        reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
        total_loss += reg_loss
        outputs.update({"reg_loss": reg_loss, "total_loss": total_loss})

        return outputs