# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> metrics
@IDE    PyCharm
@Author fx221@cam.ac.uk
@Date   29/01/2024 16:32
=================================================='''
import torch
import numpy as np
import torch.nn.functional as F


class SeqIOU:
    '''
    Tracker for per-class IoU and per-call precision over label arrays.

    ``add`` overwrites the stored IoU of every class that occurs in the
    current ``pred`` or ``target`` (values are not averaged across calls) and
    appends one precision value per call. Classes in ``ignored_sids`` are
    excluded from the precision only; the IoU loop does not skip them.
    '''

    def __init__(self, n_class, ignored_sids=None):
        '''
        :param n_class: number of classes; class ids are 0 .. n_class - 1
        :param ignored_sids: ids excluded from the precision (default: none)
        '''
        self.n_class = n_class
        # A fresh list per instance: a literal [] default would be shared by
        # every SeqIOU created without this argument.
        self.ignored_sids = [] if ignored_sids is None else ignored_sids
        self.class_iou = np.zeros(n_class)  # latest IoU seen for each class
        self.precisions = []  # one precision value per add() call

    def add(self, pred, target):
        '''
        Update the per-class IoU and record the precision of one sample.

        :param pred: array of predicted class ids
        :param target: array of ground-truth class ids, same shape as pred
        '''
        for i in range(self.n_class):
            inter = np.sum((pred == target) * (target == i))
            union = np.sum(target == i) + np.sum(pred == i) - inter
            if union > 0:
                self.class_iou[i] = inter / union

        acc = (pred == target)
        if len(self.ignored_sids) == 0:
            acc_ratio = np.sum(acc) / pred.shape[0]
        else:
            # An element is ignored when its id matches ANY ignored id.
            # AND-ing the per-id comparisons (as done previously) produced an
            # empty mask as soon as two different ids were given.
            pred_ignored = np.isin(pred, self.ignored_sids)
            target_valid = ~np.isin(target, self.ignored_sids)
            n_valid = np.sum(target_valid)
            if n_valid == 0:
                acc_ratio = 0
            else:
                acc_ratio = np.sum(acc & ~pred_ignored) / n_valid

        self.precisions.append(acc_ratio)

    def get_mean_iou(self):
        '''Return the mean of the stored per-class IoU values.'''
        return np.mean(self.class_iou)

    def get_mean_precision(self):
        '''Return the mean of the precision values recorded by add().'''
        return np.mean(self.precisions)

    def clear(self):
        '''Reset the recorded precisions and per-class IoU values.'''
        self.precisions = []
        self.class_iou = np.zeros(self.n_class)


def compute_iou(pred: np.ndarray, target: np.ndarray, n_class: int, ignored_ids=[]) -> float:
    '''
    Mean intersection-over-union of predicted vs. ground-truth class ids.

    Classes listed in ``ignored_ids`` and classes that occur in neither array
    keep an IoU of 0 and still count in the mean over all ``n_class`` classes.

    :param pred: array of predicted class ids
    :param target: array of ground-truth class ids
    :param n_class: number of classes; ids are 0 .. n_class - 1
    :param ignored_ids: class ids whose IoU is not computed
    :return: mean of the per-class IoU values
    '''
    per_class = np.zeros(n_class)
    for cls_id in range(n_class):
        if cls_id in ignored_ids:
            continue
        in_target = (target == cls_id)
        in_pred = (pred == cls_id)
        n_inter = np.sum(in_pred & in_target)
        n_union = np.sum(in_target) + np.sum(in_pred) - n_inter
        if n_union > 0:
            per_class[cls_id] = n_inter / n_union

    return per_class.mean()


def compute_precision(pred: np.ndarray, target: np.ndarray, ignored_ids: list = None) -> float:
    '''
    Fraction of elements whose predicted id equals the ground-truth id.

    Without ignored ids the number of matches is divided by ``pred.shape[0]``.
    With ignored ids, matches whose predicted id is ignored are dropped and
    the result is normalised by the number of targets that are not ignored.

    :param pred: array of predicted class ids
    :param target: array of ground-truth class ids, same shape as pred
    :param ignored_ids: class ids excluded from the measure (default: none)
    :return: precision; 0 when every target is ignored
    '''
    acc = (pred == target)
    if ignored_ids is None or len(ignored_ids) == 0:
        return np.sum(acc) / pred.shape[0]

    # An element is ignored when its id matches ANY ignored id. AND-ing the
    # per-id comparisons (as done previously) gave an empty mask as soon as
    # two different ids were supplied, so nothing was ignored.
    pred_ignored = np.isin(pred, ignored_ids)
    target_valid = ~np.isin(target, ignored_ids)
    n_valid = np.sum(target_valid)
    if n_valid == 0:
        return 0.0
    return np.sum(acc & ~pred_ignored) / n_valid


def compute_cls_corr(pred: torch.Tensor, target: torch.Tensor, k: int = 20) -> torch.Tensor:
    '''
    Mean overlap between the top-k indices of pred and of target.

    For every batch row, the top-k indices along dim 1 of ``pred`` that also
    appear among the top-k indices of ``target`` are counted and divided by
    ``k``; the ratios are then averaged over the batch.

    :param pred: scores, top-k taken along dim 1 (presumably [B, C] - confirm with callers)
    :param target: reference scores with the same layout as pred
    :param k: number of top entries compared per row
    :return: tensor of shape [1] holding the mean overlap, on pred's device
    '''
    n_rows = pred.shape[0]
    gt_topk = torch.topk(target, k=k, dim=1)[1].cpu().numpy()
    pred_topk = torch.topk(pred, k=k, dim=1)[1].cpu().numpy()

    total = 0
    for row in range(n_rows):
        gt_row = gt_topk[row]
        n_shared = sum(1 for idx in pred_topk[row] if idx in gt_row and idx >= 0)
        total = total + n_shared / k

    mean_overlap = total / n_rows
    return torch.from_numpy(np.array([mean_overlap])).to(pred.device)


def compute_corr_incorr(pred: torch.Tensor, target: torch.Tensor, ignored_ids: list = None) -> tuple:
    '''
    Ratios of correctly and incorrectly classified elements.

    Elements whose target id is listed in ``ignored_ids`` are excluded from
    both ratios and from the normalisation, so the two ratios sum to 1
    whenever at least one element is not ignored.

    :param pred: [B, N, C] class scores; the predicted id is the argmax over the last dim
    :param target: [B, N] ground-truth class ids
    :param ignored_ids: target ids to exclude (default: none)
    :return: (acc_ratio, inacc_ratio) as 0-dim tensors; both 0 if every element is ignored
    '''
    if ignored_ids is None:
        ignored_ids = []

    pred_ids = torch.max(pred, dim=-1)[1]

    # Union of all ignored ids. Starting from an all-False mask and combining
    # with logical_and (as done previously) could never set any entry, so
    # ignored_ids had no effect at all.
    ignored = torch.zeros_like(target, dtype=torch.bool)
    for i in ignored_ids:
        ignored = torch.logical_or(ignored, target == i)
    valid = torch.logical_not(ignored)

    matches = (pred_ids == target)
    acc = torch.logical_and(matches, valid)
    inacc = torch.logical_and(torch.logical_not(matches), valid)

    # Clamp so a fully ignored input yields 0 instead of 0 / 0 = NaN.
    n_valid = torch.clamp(torch.sum(valid), min=1)
    acc_ratio = torch.sum(acc) / n_valid
    inacc_ratio = torch.sum(inacc) / n_valid

    return acc_ratio, inacc_ratio


def compute_seg_loss_weight(pred: torch.Tensor,
                            target: torch.Tensor,
                            background_id: int = 0,
                            weight_background: float = 0.1) -> torch.Tensor:
    '''
    Class-weighted cross-entropy with a reduced weight for the background class.

    :param pred: [B, N, C] unnormalised class scores (transposed internally to [B, C, N])
    :param target: [B, N] ground-truth class ids
    :param background_id: class id that receives weight_background
    :param weight_background: loss weight of the background class; all other classes use 1
    :return: scalar loss (weighted mean, the default reduction of F.cross_entropy)
    '''
    # F.cross_entropy expects the class scores on dim 1.
    logits = pred.transpose(-2, -1).contiguous()  # [B, N, C] -> [B, C, N]

    # The weight follows the logits' dtype so non-float32 predictions do not
    # hit a dtype mismatch inside the loss.
    weight = torch.ones(size=(logits.shape[1],), device=logits.device, dtype=logits.dtype)
    weight[background_id] = weight_background

    # cross_entropy applies log-softmax itself, so no explicit log_softmax is
    # needed beforehand (log-softmax of log-softmax is the same values).
    seg_loss = F.cross_entropy(logits, target.long(), weight=weight)
    return seg_loss


def compute_cls_loss_ce(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    '''
    Summed cross-entropy divided by the number of positive target entries.

    For a 2-D ``pred`` a single cross-entropy is computed; otherwise one
    cross-entropy per slice along the last dim is accumulated.

    :param pred: class scores, either 2-D or with an extra trailing dim that is looped over
    :param target: targets accepted by cross_entropy for pred (sliced the same way as pred)
    :return: scalar loss; the divisor is at least 1
    '''
    cross_entropy = torch.nn.functional.cross_entropy
    cls_loss = torch.zeros(size=[], device=pred.device)
    if len(pred.shape) == 2:
        cls_loss = cls_loss + cross_entropy(pred, target, reduction='sum')
    else:
        for i in range(pred.shape[-1]):
            cls_loss = cls_loss + cross_entropy(pred[..., i], target[..., i], reduction='sum')

    # Guard the normalisation: a batch without any positive target entry
    # would otherwise divide by zero and return NaN/inf.
    n_valid = torch.clamp(torch.sum(target > 0), min=1)
    return cls_loss / n_valid


def compute_cls_loss_kl(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    '''
    KL divergence between softmax(target) and softmax(pred) over the last dim.

    A 2-D ``pred`` yields the summed KL divergence. Otherwise the summed KL
    divergence of every slice along the last dim is accumulated and divided
    by the number of slices.

    :param pred: unnormalised scores
    :param target: unnormalised reference scores with the same layout as pred
    :return: scalar loss
    '''

    def _summed_kl(scores, ref_scores):
        # kl_div takes log-probabilities as input and probabilities as target.
        return torch.nn.functional.kl_div(torch.log_softmax(scores, dim=-1),
                                          torch.softmax(ref_scores, dim=-1),
                                          reduction='sum')

    total = torch.zeros(size=[], device=pred.device)
    if len(pred.shape) == 2:
        return total + _summed_kl(pred, target)

    n_slices = pred.shape[-1]
    for idx in range(n_slices):
        total = total + _summed_kl(pred[..., idx], target[..., idx])

    return total / n_slices


def compute_sc_loss_l1(pred: torch.Tensor, target: torch.Tensor, mean_xyz=None, scale_xyz=None, mask=None):
    '''
    L1 loss: absolute error averaged over dim 1, then over the (masked) rest.

    :param pred: predicted values (3-D; NOTE(review): axis order not fixed by this code - confirm with callers)
    :param target: reference values, same shape as pred
    :param mean_xyz: unused
    :param scale_xyz: unused
    :param mask: optional index/boolean mask applied to the dim-1-averaged error
    :return: scalar mean of the selected errors
    '''
    abs_err = (pred - target).abs()
    reduced = abs_err.mean(dim=1)
    selected = reduced if mask is None else reduced[mask]
    return selected.mean()


def compute_sc_loss_geo(pred: torch.Tensor, P, K, p2ds, mean_xyz, scale_xyz, max_value=20, mask=None):
    '''
    Clamped squared reprojection error of predicted 3D points.

    The normalised prediction is mapped back with ``scale_xyz`` / ``mean_xyz``,
    transformed by ``P``, projected with ``K`` and compared with ``p2ds``.

    :param pred: [B, 3, N] normalised 3D coordinates
    :param P: batched transform applied to the homogeneous points (presumably world-to-camera, [B, 4, 4] or [B, 3, 4] - confirm with callers)
    :param K: batched matrix applied after P (presumably [B, 3, 3] intrinsics - confirm with callers)
    :param p2ds: [B, N, 2] reference 2D points
    :param mean_xyz: [B, 3] offset used to de-normalise pred
    :param scale_xyz: [B, 3] scale used to de-normalise pred
    :param max_value: upper clamp for the per-point squared error
    :param mask: optional mask over the [B, N] per-point errors
    :return: scalar mean of the (masked) clamped squared errors
    '''
    # De-normalise; the trailing singleton dim broadcasts over the N points.
    p3ds = pred * scale_xyz[..., None] + mean_xyz[..., None]

    # The homogeneous points must be built from the de-normalised coordinates.
    # They were previously built from the raw normalised pred, so mean_xyz and
    # scale_xyz never influenced the projection.
    ones = torch.ones(size=(p3ds.shape[0], 1, p3ds.shape[2]), dtype=p3ds.dtype, device=p3ds.device)
    p3ds_homo = torch.cat([p3ds, ones], dim=1)  # [B, 4, N]

    p_cam = torch.matmul(K, torch.matmul(P, p3ds_homo)[:, :3, :])  # [B, 3, N]

    # Perspective division by the third coordinate.
    p2ds_ = p_cam[:, :2, :] / p_cam[:, 2:, :]

    loss = ((p2ds_ - p2ds.permute(0, 2, 1)) ** 2).sum(1)  # [B, N]
    loss = torch.clamp_max(loss, max=max_value)
    if mask is not None:
        return torch.mean(loss[mask])
    else:
        return torch.mean(loss)