# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> sfd2
@IDE    PyCharm
@Author fx221@cam.ac.uk
@Date   07/02/2024 14:53
=================================================='''
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as tvf

# Per-channel (R, G, B) mean and standard deviation handed to torchvision's
# Normalize transform below.
# NOTE(review): these look like the usual ImageNet statistics — confirm.
RGB_mean = [0.485, 0.456, 0.406]
RGB_std = [0.229, 0.224, 0.225]

# Normalisation transform applied to the input image tensor by
# extract_sfd2_return before it is fed to the network.
norm_RGB = tvf.Compose([tvf.Normalize(mean=RGB_mean, std=RGB_std)])


def simple_nms(scores, nms_radius: int):
    """Suppress non-maximal scores with repeated max-pooling.

    A score survives when it equals the maximum of the square window of side
    ``2 * nms_radius + 1`` centred on it. Two refinement rounds then recover
    peaks whose only stronger neighbours were themselves suppressed.

    :param scores: score tensor accepted by ``max_pool2d``
    :param nms_radius: non-negative half-size of the suppression window
    :return: tensor shaped like ``scores``, zero everywhere except at the
        retained peaks, which keep their original value
    """
    assert (nms_radius >= 0)

    window = 2 * nms_radius + 1

    def local_max(t):
        # stride 1 + symmetric padding keeps the spatial size unchanged
        return torch.nn.functional.max_pool2d(
            t, kernel_size=window, stride=1, padding=nms_radius)

    blank = torch.zeros_like(scores)
    keep = local_max(scores) == scores
    for _ in range(2):
        # everything within one window of an accepted peak is suppressed
        suppressed = local_max(keep.float()) > 0
        remaining = torch.where(suppressed, blank, scores)
        # peaks among what is left, outside every suppressed zone
        fresh_peaks = local_max(remaining) == remaining
        keep = keep | (fresh_peaks & ~suppressed)
    return torch.where(keep, scores, blank)


def remove_borders(keypoints, scores, border: int, height: int, width: int):
    """Drop keypoints that lie within ``border`` pixels of the image edge.

    :param keypoints: (N, 2) tensor; column 0 is compared against ``height``
        and column 1 against ``width``
    :param scores: per-keypoint scores, filtered with the same mask
    :return: ``(keypoints, scores)`` restricted to the interior points
    """
    rows, cols = keypoints[:, 0], keypoints[:, 1]
    inside = (
        (rows >= border) & (rows < (height - border))
        & (cols >= border) & (cols < (width - border))
    )
    return keypoints[inside], scores[inside]


def top_k_keypoints(keypoints, scores, k: int):
    """Keep the ``k`` highest-scoring keypoints.

    When there are at most ``k`` keypoints the inputs are returned untouched;
    otherwise the result is ordered as ``torch.topk`` returns it.
    """
    if len(keypoints) <= k:
        return keypoints, scores
    best_scores, best_idx = torch.topk(scores, k, dim=0)
    return keypoints[best_idx], best_scores


def sample_descriptors(keypoints, descriptors, s: int = 8):
    """Bilinearly sample a dense descriptor map at keypoint locations.

    :param keypoints: tensor viewable as (b, 1, N, 2); the first coordinate is
        scaled by the map width and the second by the map height
    :param descriptors: (b, c, h, w) dense descriptor map
    :param s: scale factor used to map keypoint coordinates onto the map
    :return: (b, c, N) descriptors, L2-normalised along the channel axis
    """
    b, c, h, w = descriptors.shape

    # Shift first: the result is a new floating tensor, so the caller's
    # keypoints are never modified and the divisor is cast to its dtype/device.
    shifted = keypoints - s / 2 + 0.5
    extent = torch.tensor(
        [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
    ).to(shifted)[None]
    grid = (shifted / extent) * 2 - 1  # normalize to (-1, 1)

    sampled = torch.nn.functional.grid_sample(
        descriptors, grid.view(b, 1, -1, 2), mode='bilinear', align_corners=True)
    return torch.nn.functional.normalize(sampled.reshape(b, c, -1), p=2, dim=1)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **options)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) convolution."""
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)


def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bn=False, groups=1, dilation=1):
    """Build a ``Conv2d -> [BatchNorm2d] -> ReLU`` stack.

    The batch-norm layer is inserted only when ``use_bn`` is true; the ReLU
    operates in place.
    """
    layers = [
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                  kernel_size=kernel_size, stride=stride, padding=padding,
                  groups=groups, dilation=dilation),
    ]
    if use_bn:
        layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


class ResBlock(nn.Module):
    """Bottleneck-style residual block: 1x1 -> grouped 3x3 -> 1x1 convolutions.

    Each convolution is followed by a normalisation layer. The block input is
    added in place to the output of the last normalisation, so it must be
    addable to that tensor.
    """

    def __init__(self, inplanes, outplanes, stride=1, groups=32, dilation=1, norm_layer=None):
        super().__init__()
        # BatchNorm2d unless the caller supplies another normalisation factory
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        self.conv1 = conv1x1(inplanes, outplanes)
        self.bn1 = make_norm(outplanes)
        self.conv2 = conv3x3(outplanes, outplanes, stride, groups, dilation)
        self.bn2 = make_norm(outplanes)
        self.conv3 = conv1x1(outplanes, outplanes)
        self.bn3 = make_norm(outplanes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn3(conv3(...)) + x)``."""
        shortcut = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)


class ResNet4x(nn.Module):
    """Joint keypoint detector / descriptor network.

    A convolutional backbone containing two stride-2 stages feeds two heads:

    * a detector head (``convPa`` -> ``convPb``) that applies one more
      stride-2 convolution and predicts 65 logits per cell. After a softmax
      the last channel is dropped and the remaining 64 are unfolded into an
      8x8 pixel patch per cell, giving a dense score map;
    * a descriptor head (``convDa`` -> ``convDb``) that produces an
      L2-normalised ``outdim``-channel descriptor map at backbone resolution.
    """

    # Defaults for extract_local_global; caller-supplied keys override these.
    default_config = {
        'conf_th': 0.005,
        'remove_borders': 4,
        'min_keypoints': 128,
        'max_keypoints': 4096,
    }

    def __init__(self, inputdim=3, outdim=128, desc_compressor=None):
        """
        :param inputdim: number of input image channels
        :param outdim: number of descriptor channels
        :param desc_compressor: stored on the instance; not used by any method
            of this class
        """
        super().__init__()
        self.outdim = outdim
        self.desc_compressor = desc_compressor

        d1, d2, d3 = 64, 128, 256
        self.conv1a = conv(in_channels=inputdim, out_channels=d1, kernel_size=3, use_bn=True)
        self.conv1b = conv(in_channels=d1, out_channels=d1, kernel_size=3, stride=2, use_bn=True)

        self.conv2a = conv(in_channels=d1, out_channels=d2, kernel_size=3, use_bn=True)
        self.conv2b = conv(in_channels=d2, out_channels=d2, kernel_size=3, stride=2, use_bn=True)

        self.conv3a = conv(in_channels=d2, out_channels=d3, kernel_size=3, use_bn=True)
        self.conv3b = conv(in_channels=d3, out_channels=d3, kernel_size=3, use_bn=True)

        self.conv4 = nn.Sequential(
            ResBlock(inplanes=d3, outplanes=d3, groups=32),
            ResBlock(inplanes=d3, outplanes=d3, groups=32),
            ResBlock(inplanes=d3, outplanes=d3, groups=32),
        )

        self.convPa = nn.Sequential(
            torch.nn.Conv2d(d3, d3, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(d3),
            nn.ReLU(inplace=True),
            torch.nn.Conv2d(d3, d3, kernel_size=3, stride=1, padding=1),
        )
        self.convDa = nn.Sequential(
            nn.Conv2d(d3, d3, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(d3),
            nn.ReLU(inplace=True),
            nn.Conv2d(d3, d3, kernel_size=3, stride=1, padding=1)
        )

        self.convPb = torch.nn.Conv2d(d3, 65, kernel_size=1, stride=1, padding=0)
        self.convDb = torch.nn.Conv2d(d3, outdim, kernel_size=1, stride=1, padding=0)

    def _backbone(self, image):
        """Run the shared encoder; return ``(out1b, out2b, out3b, out4)``."""
        out1b = self.conv1b(self.conv1a(image))
        out2b = self.conv2b(self.conv2a(out1b))
        out3b = self.conv3b(self.conv3a(out2b))
        out4 = self.conv4(out3b)
        return out1b, out2b, out3b, out4

    def _detector_head(self, feat):
        """Return ``(logits, semi, score)`` computed from backbone features."""
        logits = self.convPb(self.convPa(feat))
        # Drop the last softmax channel; the remaining 64 form an 8x8 patch.
        semi = torch.softmax(logits, dim=1)[:, :-1, :, :]
        n, hc, wc = semi.size(0), semi.size(2), semi.size(3)
        score = semi.permute([0, 2, 3, 1]).view(n, hc, wc, 8, 8)
        # Interleave cell and in-cell axes: (n, hc, 8, wc, 8) -> (n, hc*8, wc*8)
        score = score.permute([0, 1, 3, 2, 4]).contiguous().view(n, hc * 8, wc * 8)
        return logits, semi, score

    def _descriptor_head(self, feat):
        """Return the channel-wise L2-normalised dense descriptor map."""
        return F.normalize(self.convDb(self.convDa(feat)), dim=1)

    def _dense_outputs(self, image):
        """Shared implementation of ``forward`` and ``extract_patches``."""
        out4 = self._backbone(image)[-1]
        logits, semi, score = self._detector_head(out4)
        return {
            'dense_features': self._descriptor_head(out4),
            'scores': score,
            'logits': logits,
            'semi_map': semi,
        }

    def det(self, x):
        """Return ``(score, desc)`` for an image tensor ``x``.

        ``score`` is the unfolded (n, Hc * 8, Wc * 8) score map and ``desc``
        the normalised dense descriptor map.
        """
        out4 = self._backbone(x)[-1]
        score = self._detector_head(out4)[2]
        return score, self._descriptor_head(out4)

    def forward(self, batch):
        """Compute dense outputs for ``batch['image']``.

        :return: dict with keys ``'dense_features'``, ``'scores'``,
            ``'logits'`` and ``'semi_map'``
        """
        return self._dense_outputs(batch['image'])

    def extract_patches(self, batch):
        """Compute the same dense outputs as ``forward`` for ``batch['image']``."""
        return self._dense_outputs(batch['image'])

    def extract_local_global(self, data, config=None):
        """Extract keypoints, their scores and descriptors from ``data['image']``.

        :param data: dict holding a 4-D ``'image'`` tensor
        :param config: optional overrides for ``default_config``:
            ``conf_th`` (score threshold), ``remove_borders`` (margin in
            pixels), ``min_keypoints`` (at or below this count the threshold
            is halved) and ``max_keypoints`` (top-k cap, negative disables it)
        :return: dict with ``'score_map'``, ``'desc_map'``, ``'mid_features'``,
            ``'global_descriptors'`` (the four intermediate feature maps) and
            the per-image sequences ``'keypoints'`` (x, y order),
            ``'scores'`` and ``'descriptors'``
        """
        # None instead of a dict literal: avoids a shared mutable default.
        config = {**self.default_config, **(config or {})}

        _, _, ih, iw = data['image'].shape
        out1b, out2b, out3b, out4 = self._backbone(data['image'])

        _, _, score = self._detector_head(out4)
        if score.size(1) != ih or score.size(2) != iw:
            # Image size is not a multiple of the network stride: resize the
            # score map back to the input resolution.
            score = F.interpolate(score.unsqueeze(1), size=[ih, iw], align_corners=True, mode='bilinear')
            score = score.squeeze(1)

        nms_scores = simple_nms(scores=score, nms_radius=4)

        keypoints, scores = [], []
        for img_scores in nms_scores:
            kpts = torch.nonzero(img_scores >= config['conf_th'])
            # Decide per image (not from the first image only) whether the
            # relaxed threshold is needed.
            if len(kpts) <= config['min_keypoints']:
                kpts = torch.nonzero(img_scores >= config['conf_th'] * 0.5)
            kpt_scores = img_scores[tuple(kpts.t())]

            # Discard keypoints near the image borders
            kpts, kpt_scores = remove_borders(kpts, kpt_scores, config['remove_borders'], ih, iw)

            # Keep the k keypoints with highest score
            if config['max_keypoints'] >= 0:
                kpts, kpt_scores = top_k_keypoints(kpts, kpt_scores, config['max_keypoints'])

            # Convert (h, w) to (x, y)
            keypoints.append(torch.flip(kpts, [1]).float())
            scores.append(kpt_scores)
        scores = tuple(scores)  # callers previously received a tuple here

        desc_map = self._descriptor_head(out4)
        # The descriptor map comes from out4, after two stride-2 stages, hence s=4.
        descriptors = [sample_descriptors(k[None], d[None], 4)[0]
                       for k, d in zip(keypoints, desc_map)]

        return {
            'score_map': score,
            'desc_map': desc_map,
            'mid_features': out4,
            'global_descriptors': [out1b, out2b, out3b, out4],
            'keypoints': keypoints,
            'scores': scores,
            'descriptors': descriptors,
        }

    def sample(self, score_map, semi_descs, kpts, s=4, norm_desc=True):
        """Look up scores and bilinearly sample descriptors at ``kpts``.

        :param score_map: score tensor indexed as ``[0, y, x]``; only the
            first entry along dimension 0 is read
        :param semi_descs: (b, c, h, w) dense descriptor map
        :param kpts: keypoints indexed as ``kpts[:, 0]`` (x) and
            ``kpts[:, 1]`` (y) — presumably an (N, 2) tensor, confirm with callers
        :param s: scale factor between keypoint coordinates and the map
        :param norm_desc: L2-normalise the sampled descriptors when true
        :return: ``(scores, descriptors)`` with dimension 0 of the descriptors
            squeezed away when it has size 1
        """
        b, c, h, w = semi_descs.shape
        norm_kpts = kpts - s / 2 + 0.5
        norm_kpts = norm_kpts / torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
                                             ).to(norm_kpts)[None]
        norm_kpts = norm_kpts * 2 - 1  # map to the (-1, 1) range of grid_sample
        descriptors = torch.nn.functional.grid_sample(
            semi_descs, norm_kpts.view(b, 1, -1, 2), mode='bilinear', align_corners=True)

        descriptors = descriptors.reshape(b, c, -1)
        if norm_desc:
            descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)

        scores = score_map[0, kpts[:, 1].long(), kpts[:, 0].long()]

        return scores, descriptors.squeeze(0)

class DescriptorCompressor(nn.Module):
    """Project descriptors to ``outdim`` channels and L2-normalise them.

    The projection is a kernel-size-1 ``Conv1d`` with bias, applied to the
    channel axis (dimension 1) of the input.
    """

    def __init__(self, inputdim: int, outdim: int):
        super().__init__()
        self.inputdim, self.outdim = inputdim, outdim
        self.conv = nn.Conv1d(
            in_channels=inputdim,
            out_channels=outdim,
            kernel_size=1,
            padding=0,
            bias=True,
        )

    def forward(self, x):
        """Return the projected input, unit-normalised along dimension 1."""
        return F.normalize(self.conv(x), p=2, dim=1)


def _sfd2_model_device(model):
    """Return the device of the model's parameters (CUDA when it has none)."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        return torch.device('cuda')


def _sfd2_detect_at_scale(model, image, conf_th, nms_radius=3, border=4):
    """Detect keypoints and sample their descriptors on one image.

    Returns ``(pts, desc)``: ``pts`` is a (3, N) array with rows x, y, score
    sorted by decreasing score, and ``desc`` is the matching (D, N) array.
    """
    nh, nw = image.shape[2:]
    heatmap, coarse_desc = model.det(image)

    # Bring the heatmap to (B, 1, H, W) whatever rank the model returned.
    if heatmap.dim() == 3:
        heatmap = heatmap.unsqueeze(1)
    elif heatmap.dim() == 2:
        heatmap = heatmap[None, None]
    if heatmap.size(2) != nh or heatmap.size(3) != nw:
        heatmap = F.interpolate(heatmap, size=[nh, nw], mode='bilinear', align_corners=True)

    # Only the first image of the batch is processed.
    nms_map = simple_nms(heatmap, nms_radius=nms_radius)[0]  # (1, H, W)
    idx = torch.nonzero(nms_map > conf_th)  # rows of (channel, y, x)
    kp_scores = nms_map[tuple(idx.t())].detach().cpu().numpy()

    # Flip to (x, y, channel). reshape (not squeeze) keeps the array 2-D when
    # exactly one keypoint is found; the channel row is then overwritten with
    # the scores.
    pts = torch.flip(idx, [1]).float().detach().cpu().numpy().reshape(-1, 3).T
    pts[2, :] = kp_scores

    # Sort by confidence, strongest first.
    pts = pts[:, np.argsort(pts[2, :])[::-1]]

    # Remove points along the border of the image they were detected in.
    keep = ((pts[0, :] >= border) & (pts[0, :] < (nw - border))
            & (pts[1, :] >= border) & (pts[1, :] < (nh - border)))
    pts = pts[:, keep]

    D = coarse_desc.size(1)
    if pts.shape[1] == 0:
        return pts, np.zeros((D, 0))

    if coarse_desc.size(2) == nh and coarse_desc.size(3) == nw:
        # Full-resolution descriptor map: direct lookup with integer indices.
        xs = torch.from_numpy(pts[0, :].astype(np.int64)).to(coarse_desc.device)
        ys = torch.from_numpy(pts[1, :].astype(np.int64)).to(coarse_desc.device)
        desc = coarse_desc[:, :, ys, xs]
        desc = desc.detach().cpu().numpy().reshape(D, -1)
    else:
        # Interpolate into descriptor map using 2D point locations.
        samp_pts = torch.from_numpy(pts[:2, :].copy())
        samp_pts[0, :] = (samp_pts[0, :] / (float(nw) / 2.)) - 1.
        samp_pts[1, :] = (samp_pts[1, :] / (float(nh) / 2.)) - 1.
        samp_pts = samp_pts.transpose(0, 1).contiguous().view(1, 1, -1, 2)
        samp_pts = samp_pts.float().to(coarse_desc.device)
        desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts, mode='bilinear', align_corners=True)
        desc = desc.detach().cpu().numpy().reshape(D, -1)
        desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :]
    return pts, desc


def _sfd2_select_with_mask(keypoints, scores, descriptors, mask, topK):
    """Pick keypoints, preferring those on labelled mask pixels.

    Returns ``(keypoints, scores, descriptors, labels)`` with one label per
    returned keypoint (0 for keypoints on unlabelled pixels).
    """
    # Decode the per-pixel id stored across the three mask channels.
    id_img = np.int32(mask[:, :, 2]) * 256 * 256 + np.int32(mask[:, :, 1]) * 256 + np.int32(mask[:, :, 0])
    xs = keypoints[:, 0].astype(np.int64)
    ys = keypoints[:, 1].astype(np.int64)
    gids = id_img[ys, xs]

    labelled = np.flatnonzero(gids != 0)
    unlabelled = np.flatnonzero(gids == 0)

    if topK <= 0:
        # No cap: keep everything, with labels aligned to the keypoints.
        chosen = np.arange(gids.shape[0])
    elif topK <= len(labelled):
        order = scores[labelled].astype(float).argsort()[::-1][:topK]
        chosen = labelled[order]
    elif topK >= len(labelled) + len(unlabelled):
        chosen = np.concatenate([labelled, unlabelled])
    else:
        # All labelled keypoints, topped up with the best unlabelled ones.
        n = topK - len(labelled)
        order = scores[unlabelled].astype(float).argsort()[::-1][:n]
        chosen = np.concatenate([labelled, unlabelled[order]])

    return keypoints[chosen], scores[chosen], descriptors[chosen], gids[chosen]


def extract_sfd2_return(model, img, conf_th=0.001,
                        mask=None,
                        topK=-1,
                        min_keypoints=0,
                        **kwargs):
    """Extract keypoints, scores and descriptors from a single image.

    :param model: object exposing ``det(image) -> (heatmap, descriptor_map)``
    :param img: image tensor; it is squeezed, normalised with ``norm_RGB``
        and given a batch dimension — presumably (1, 3, H, W) or (3, H, W),
        confirm with callers
    :param conf_th: keypoints must score strictly above this threshold
    :param mask: optional (H, W, 3) array whose channels encode an integer id
        as ``c2 * 65536 + c1 * 256 + c0``; id 0 means "unlabelled"
    :param topK: keep at most this many keypoints when positive
    :param min_keypoints: accepted for interface compatibility; unused
    :param kwargs: ``scales`` — iterable of scale factors (default ``[1.0]``)
    :return: ``(None, None, None)`` when nothing is detected, otherwise a dict
        with float arrays ``'keypoints'`` (N, 2), ``'descriptors'`` (N, D),
        ``'scores'`` (N,) and, when ``mask`` is given, int32 ``'labels'`` (N,)
    """
    old_bm = torch.backends.cudnn.benchmark
    torch.backends.cudnn.benchmark = False
    try:
        img = norm_RGB(img.squeeze())[None]
        # Run wherever the model lives instead of assuming CUDA.
        img = img.to(_sfd2_model_device(model))
        H, W = img.shape[2:]

        scales = kwargs.get('scales', [1.0])

        all_pts = []
        all_descs = []
        for s in scales:
            if s == 1.0:
                new_img = img
            else:
                new_img = F.interpolate(img, size=(int(H * s), int(W * s)),
                                        mode='bilinear', align_corners=True)
            nh, nw = new_img.shape[2:]

            with torch.no_grad():
                pts, desc = _sfd2_detect_at_scale(model, new_img, conf_th)
            if pts.shape[1] == 0:
                continue

            # Map coordinates back to the original image resolution.
            pts[0, :] = pts[0, :] * W / nw
            pts[1, :] = pts[1, :] * H / nh
            all_pts.append(np.transpose(pts, [1, 0]))
            all_descs.append(np.transpose(desc, [1, 0]))
    finally:
        # Restore the caller's cudnn setting even if extraction fails.
        torch.backends.cudnn.benchmark = old_bm

    # np.vstack raises on an empty list, so bail out before stacking.
    if not all_pts:
        return None, None, None

    all_pts = np.vstack(all_pts)
    all_descs = np.vstack(all_descs)

    keypoints = all_pts[:, 0:2]
    scores = all_pts[:, 2]
    descriptors = all_descs

    if mask is not None:
        keypoints, scores, descriptors, labels = _sfd2_select_with_mask(
            keypoints, scores, descriptors, mask, topK)
        return {"keypoints": np.array(keypoints, float),
                "descriptors": np.array(descriptors, float),
                "scores": np.array(scores, float),
                "labels": np.array(labels, np.int32),
                }

    if topK > 0:
        idxes = np.array(scores, dtype=float).argsort()[::-1][:topK]
        keypoints = keypoints[idxes]
        scores = scores[idxes]
        descriptors = descriptors[idxes]

    return {"keypoints": np.array(keypoints, dtype=float),
            "descriptors": np.array(descriptors, dtype=float),
            "scores": np.array(scores, dtype=float),
            }


def load_sfd2(weight_path):
    """Build a 3-channel-in, 128-dim ``ResNet4x`` and load its weights.

    The checkpoint at ``weight_path`` is read onto the CPU and its
    ``'state_dict'`` entry is loaded with strict key matching.
    """
    model = ResNet4x(inputdim=3, outdim=128)
    checkpoint = torch.load(weight_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    return model