import torch
from torch import nn
from torch.nn.parameter import Parameter
import torchvision.transforms as tvf
import torch.nn.functional as F
import numpy as np


def gather_nd(params, indices):
    """Gather entries/slices of ``params`` addressed by the last axis of ``indices``.

    PyTorch counterpart of TensorFlow's ``tf.gather_nd`` (no batch dims): each
    length-``m`` vector ``v`` along the last axis of ``indices`` selects
    ``params[v[0], ..., v[m-1]]``.

    Args:
        params: tensor of rank ``n``.
        indices: integer tensor of shape ``[..., m]`` with ``m <= n``. It may
            live on a different device than ``params``.

    Returns:
        Contiguous tensor of shape ``indices.shape[:-1] + params.shape[m:]``.

    Raises:
        ValueError: if ``m`` exceeds the rank of ``params``.
        IndexError: if ``indices`` is not an integer tensor.
    """
    orig_shape = list(indices.shape)
    m = orig_shape[-1]
    n = params.dim()

    if m > n:
        raise ValueError(
            f"the last dimension of indices must be less than or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}"
        )
    if indices.is_floating_point() or indices.is_complex() or indices.dtype == torch.bool:
        raise IndexError(f"indices must be an integer tensor, got dtype {indices.dtype}")

    out_shape = orig_shape[:-1] + list(params.shape)[m:]
    # int(): np.prod([]) is the *float* 1.0 for 1-D indices, which reshape rejects.
    num_samples = int(np.prod(orig_shape[:-1]))

    # One long index tensor per leading axis of params. Staying on params'
    # device avoids the host round-trip of .tolist(); .long() also accepts
    # int32 indices (as produced by interpolate) on every torch version.
    flat = indices.reshape(num_samples, m).to(device=params.device, dtype=torch.long)
    output = params[tuple(flat.unbind(dim=1))]  # (num_samples, ...)
    return output.reshape(out_shape).contiguous()


# input: pos [kpt_n, 2]; inputs [H, W, 128] / [H, W]
# output: [kpt_n, 128] / [kpt_n]
def interpolate(pos, inputs, nd=True):
    """Bilinearly sample ``inputs`` at fractional (row, col) positions.

    Args:
        pos: ``[kpt_n, 2]`` tensor of (row, col) sample coordinates.
        inputs: ``[H, W, C]`` feature map when ``nd`` is True, ``[H, W]`` map
            otherwise.
        nd: append a trailing axis to the weights so they broadcast over the
            channel dimension of ``inputs``.

    Returns:
        ``[kpt_n, C]`` (``nd=True``) or ``[kpt_n]`` (``nd=False``) values.

    Neighbour coordinates are clamped into the image while the weights are
    measured from the clamped top-left neighbour, so for out-of-range
    positions individual weights can leave [0, 1]; they always sum to 1.
    """
    height, width = inputs.shape[0], inputs.shape[1]
    rows, cols = pos[:, 0], pos[:, 1]

    # Integer neighbour coordinates, clamped to the valid pixel range.
    row_lo = torch.clamp(torch.floor(rows).int(), 0, height - 1)
    row_hi = torch.clamp(torch.ceil(rows).int(), 0, height - 1)
    col_lo = torch.clamp(torch.floor(cols).int(), 0, width - 1)
    col_hi = torch.clamp(torch.ceil(cols).int(), 0, width - 1)

    # Fractional offsets relative to the top-left neighbour.
    frac_r = rows - row_lo.float()
    frac_c = cols - col_lo.float()

    corners = [
        ((1 - frac_r) * (1 - frac_c), row_lo, col_lo),  # top-left
        ((1 - frac_r) * frac_c, row_lo, col_hi),  # top-right
        (frac_r * (1 - frac_c), row_hi, col_lo),  # bottom-left
        (frac_r * frac_c, row_hi, col_hi),  # bottom-right
    ]

    terms = []
    for weight, r_idx, c_idx in corners:
        if nd:
            weight = weight.unsqueeze(-1)
        samples = gather_nd(inputs, torch.stack([r_idx, c_idx], dim=-1))
        terms.append(weight * samples)

    # Accumulate left to right: ((tl + tr) + bl) + br.
    return sum(terms[1:], terms[0])


def edge_mask(inputs, n_channel, dilation=1, edge_thld=5):
    """Return a mask that is True where a pixel is *not* on an edge.

    For every channel, the second derivatives ``dii``, ``dij`` and ``djj`` are
    estimated with (optionally dilated) 3x3 finite-difference stencils and
    combined into a 2x2 Hessian. A pixel is kept when the Hessian determinant
    is positive and ``tr^2 / det <= (edge_thld + 1)^2 / edge_thld``. This is
    the SIFT-style principal-curvature ratio test, which rejects responses
    that are much more curved along one axis than the other.

    Args:
        inputs: ``[B, C, H, W]`` tensor; each channel is filtered independently.
        n_channel: unused; kept for signature compatibility.
        dilation: dilation (and matching zero padding) of the stencils, so the
            output keeps the ``H x W`` size.
        edge_thld: maximum allowed principal-curvature ratio ``r``.

    Returns:
        Bool tensor of shape ``[B, C, H, W]``.
    """
    b, c, h, w = inputs.size()

    # Stencils follow the input's device *and* dtype. With float32-only
    # weights, conv2d raised for float64/float16 maps.
    factory = dict(device=inputs.device, dtype=inputs.dtype)
    dii_filter = torch.tensor([[0, 1.0, 0], [0, -2.0, 0], [0, 1.0, 0]], **factory)
    dij_filter = 0.25 * torch.tensor(
        [[1.0, 0, -1.0], [0, 0.0, 0], [-1.0, 0, 1.0]], **factory
    )
    djj_filter = torch.tensor([[0, 0, 0], [1.0, -2.0, 1.0], [0, 0, 0]], **factory)

    # One single-channel plane per (batch, channel). Unlike view, reshape
    # also accepts non-contiguous inputs.
    planes = inputs.reshape(-1, 1, h, w)

    def _second_derivative(stencil):
        # Correlate every plane with the stencil; padding == dilation keeps H x W.
        response = F.conv2d(
            planes,
            stencil.view(1, 1, 3, 3),
            padding=dilation,
            dilation=dilation,
        )
        return response.view(b, c, h, w)

    dii = _second_derivative(dii_filter)
    dij = _second_derivative(dij_filter)
    djj = _second_derivative(djj_filter)

    det = dii * djj - dij * dij
    tr = dii + djj
    del dii, dij, djj

    threshold = (edge_thld + 1) ** 2 / edge_thld
    # det > 0 also discards the inf/nan ratios produced where det == 0.
    return torch.logical_and(tr * tr / det <= threshold, det > 0)


# input: score_map [batch_size, 1, H, W]
# output: indices [batch_size, n, 2] as (row, col), scores [batch_size, n], with n <= k
def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_size=5):
    """Select up to ``k`` keypoints per image from a detection score map.

    A pixel is a candidate when every enabled test passes:
      * its score is strictly greater than ``score_thld``;
      * (``nms_size > 0``) it equals the maximum of its ``nms_size x nms_size``
        neighbourhood;
      * (``eof_size > 0``) it lies at least ``eof_size`` pixels from every
        image border;
      * (``edge_thld > 0``) ``edge_mask`` reports it as not being on an edge.

    Args:
        score_map: ``[B, 1, H, W]`` score tensor.
        k: maximum number of keypoints kept per image.
        score_thld: minimum (exclusive) score.
        edge_thld: edge-ratio threshold passed to ``edge_mask``; 0 disables it.
        nms_size: non-maximum-suppression window; 0 disables it.
        eof_size: border width to exclude; 0 disables it.

    Returns:
        indices: long tensor ``[B, n, 2]`` of (row, col) positions, ordered by
            descending score.
        scores: tensor ``[B, n]`` of the matching scores.
        ``n`` is the smallest per-image count ``min(k, #candidates)``. Images
        with more candidates are truncated so the results can be stacked.
    """
    h = score_map.shape[2]
    w = score_map.shape[3]

    mask = score_map > score_thld
    if nms_size > 0:
        local_max = F.max_pool2d(
            score_map, kernel_size=nms_size, stride=1, padding=nms_size // 2
        )
        mask = torch.logical_and(torch.eq(score_map, local_max), mask)
    if eof_size > 0:
        # Mark only the interior. For maps no larger than 2 * eof_size, the
        # slice is empty and every pixel is excluded, instead of crashing on
        # a negative tensor size.
        eof_mask = torch.zeros((1, 1, h, w), dtype=torch.bool, device=score_map.device)
        eof_mask[..., eof_size : h - eof_size, eof_size : w - eof_size] = True
        mask = torch.logical_and(eof_mask, mask)
    if edge_thld > 0:
        non_edge_mask = edge_mask(score_map, 1, dilation=3, edge_thld=edge_thld)
        mask = torch.logical_and(non_edge_mask, mask)

    indices = []
    scores = []
    for img_mask, img_scores in zip(mask[:, 0], score_map[:, 0]):
        img_indices = torch.nonzero(img_mask)  # [n_i, 2], row-major order
        # Boolean selection returns scores in the same row-major order as nonzero().
        img_values = img_scores[img_mask]
        order = torch.sort(img_values, descending=True)[1][:k]
        indices.append(img_indices[order])
        scores.append(img_values[order])

    # Images may keep different numbers of points; truncate to the shortest
    # so they stack (a no-op when all counts agree).
    min_num = min(len(t) for t in indices)
    indices = torch.stack([t[:min_num] for t in indices], dim=0)
    scores = torch.stack([t[:min_num] for t in scores], dim=0)
    return indices, scores


# input: [batch_size, C, H, W]
# output: [batch_size, C, H, W], [batch_size, C, H, W]
def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1):
    """Compute spatial and channel-wise peakiness of a feature map.

    The features are first divided by ``moving_instance_max``. Then:
      * ``alpha = softplus(x - local_mean(x))``, where the local mean is taken
        over a (dilated) ``ksize x ksize`` window with reflect padding;
      * ``beta = softplus(x - mean_c(x))``, where the mean is over channels.

    Args:
        inputs: ``[B, C, H, W]`` feature map.
        moving_instance_max: scalar (or broadcastable tensor) normaliser.
        ksize: side length of the averaging window; must be odd.
        dilation: dilation of the averaging window.

    Returns:
        ``(alpha, beta)``, both of shape ``[B, C, H, W]``.
    """
    inputs = inputs / moving_instance_max
    channels = inputs.shape[1]

    # Symmetric padding that keeps H x W for a dilated window with span
    # dilation * (ksize - 1) + 1. This equals the former
    # ksize // 2 + (dilation - 1) whenever ksize == 3 or dilation == 1.
    pad_size = dilation * (ksize - 1) // 2
    # Depthwise mean filter in the input's dtype (float32-only weights made
    # conv2d raise for float64/float16 features).
    kernel = torch.ones(
        [channels, 1, ksize, ksize], device=inputs.device, dtype=inputs.dtype
    ) / (ksize * ksize)

    pad_inputs = F.pad(inputs, [pad_size] * 4, mode="reflect")
    avg_spatial_inputs = F.conv2d(
        pad_inputs, kernel, stride=1, dilation=dilation, padding=0, groups=channels
    )
    avg_channel_inputs = torch.mean(inputs, dim=1, keepdim=True)

    alpha = F.softplus(inputs - avg_spatial_inputs)
    beta = F.softplus(inputs - avg_channel_inputs)
    return alpha, beta


class DarkFeat(nn.Module):
    """DarkFeat keypoint detector and 128-d descriptor network.

    A fully convolutional backbone with two stride-2 stages (conv2, conv4)
    produces a dense descriptor map at roughly 1/4 of the input resolution.
    The detection score map combines peakiness scores of three intermediate
    feature maps, multiplied by a learned confidence map (``clf``).
    Keypoints are then selected with ``extract_kpts``. Pretrained weights are
    loaded from ``model_path`` inside ``__init__``.
    """

    # Inference settings read by __init__ and forward. Each instance works on
    # its own copy (self.config); only "model_path" is overridden in __init__.
    # NOTE(review): "kpt_refinement", "multi_scale", "multi_level",
    # "need_norm" and "use_peakiness" are not read anywhere in this class.
    default_config = {
        "model_path": "",
        "input_type": "raw-demosaic",
        "kpt_n": 5000,
        "kpt_refinement": True,
        "score_thld": 0.5,
        "edge_thld": 10,
        "multi_scale": False,
        "multi_level": True,
        "nms_size": 3,
        "eof_size": 5,
        "need_norm": True,
        "use_peakiness": True,
    }

    def __init__(
        self,
        model_path="",
        inchan=3,
        dilated=True,
        dilation=1,
        bn=True,
        bn_affine=False,
    ):
        """Build the network and load pretrained weights.

        Args:
            model_path: checkpoint path passed to ``torch.load``. Its state
                dict (minus BatchNorm running-stat entries) is loaded strictly.
            inchan: ignored. The input channel count is derived from
                ``default_config["input_type"]``: 3 for "rgb" or
                "raw-demosaic", otherwise 1.
            dilated: stored on the instance; not otherwise used in this class.
            dilation: base dilation, multiplied into every conv built by
                ``_add_conv``.
            bn: if False, ``_add_conv`` omits its BatchNorm layers. bn1 and
                bn3 are still created.
            bn_affine: passed as ``affine`` to every BatchNorm2d.
        """
        super(DarkFeat, self).__init__()
        # Overrides the ``inchan`` argument.
        # NOTE(review): forward comments the "raw" input type as 4-channel,
        # but this expression yields 1 for it. Confirm before using "raw".
        inchan = (
            3
            if self.default_config["input_type"] == "rgb"
            or self.default_config["input_type"] == "raw-demosaic"
            else 1
        )
        self.config = {**self.default_config}

        self.inchan = inchan
        # _add_conv reads curchan as the input width and then sets it to its
        # output width, so the creation order below defines the channel chain.
        self.curchan = inchan
        self.dilated = dilated
        self.dilation = dilation
        self.bn = bn
        self.bn_affine = bn_affine
        self.config["model_path"] = model_path

        dim = 128
        mchan = 4

        # conv1 and conv3 are built without BN. Their BN (bn1/bn3) is applied
        # separately in forward, because the pre-BN activations feed the
        # peakiness detector. conv2 and conv4 downsample by 2 each.
        self.conv0 = self._add_conv(8 * mchan)
        self.conv1 = self._add_conv(8 * mchan, bn=False)
        self.bn1 = self._make_bn(8 * mchan)
        self.conv2 = self._add_conv(16 * mchan, stride=2)
        self.conv3 = self._add_conv(16 * mchan, bn=False)
        self.bn3 = self._make_bn(16 * mchan)
        self.conv4 = self._add_conv(32 * mchan, stride=2)
        self.conv5 = self._add_conv(32 * mchan)
        # replace last 8x8 convolution with 3 3x3 convolutions
        self.conv6_0 = self._add_conv(32 * mchan)
        self.conv6_1 = self._add_conv(32 * mchan)
        self.conv6_2 = self._add_conv(dim, bn=False, relu=False)
        self.out_dim = dim

        # One non-trainable normaliser per detector level (x1, x3, x6_2 in
        # forward). Values come from the checkpoint; forward passes them to
        # peakiness_score.
        self.moving_avg_params = nn.ParameterList(
            [
                Parameter(torch.tensor(1.0), requires_grad=False),
                Parameter(torch.tensor(1.0), requires_grad=False),
                Parameter(torch.tensor(1.0), requires_grad=False),
            ]
        )
        # 1x1 conv giving 2 logits per location. forward uses softmax
        # channel 1 as a keypoint-confidence map.
        self.clf = nn.Conv2d(128, 2, kernel_size=1)

        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints (weights_only=True is an option if the file is a plain
        # state dict).
        state_dict = torch.load(self.config["model_path"], map_location="cpu")
        new_state_dict = {}

        # BatchNorm layers use track_running_stats=False (see _make_bn), so
        # they own no running-stat buffers. Drop those checkpoint keys so the
        # strict load_state_dict below does not fail on unexpected keys.
        for key in state_dict:
            if (
                "running_mean" not in key
                and "running_var" not in key
                and "num_batches_tracked" not in key
            ):
                new_state_dict[key] = state_dict[key]

        self.load_state_dict(new_state_dict)
        print("Loaded DarkFeat model")

    def _make_bn(self, outd):
        """Return a BatchNorm2d over ``outd`` channels that keeps no running
        statistics, so it normalises with the current batch's statistics."""
        return nn.BatchNorm2d(outd, affine=self.bn_affine, track_running_stats=False)

    def _add_conv(
        self,
        outd,
        k=3,
        stride=1,
        dilation=1,
        bn=True,
        relu=True,
        k_pool=1,
        pool_type="max",
        bias=False,
    ):
        """Build a conv(-BN)(-ReLU)(-pool) stage and advance ``self.curchan``.

        Input channels are taken from ``self.curchan``, which is then set to
        ``outd``. The effective dilation is ``self.dilation * dilation``, and
        the padding ``((k - 1) * d) // 2`` keeps the spatial size for odd
        ``k`` at stride 1. An unknown ``pool_type`` only prints a message and
        adds no pooling layer.
        """
        d = self.dilation * dilation
        conv_params = dict(
            padding=((k - 1) * d) // 2, dilation=d, stride=stride, bias=bias
        )

        ops = nn.ModuleList([])

        ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params))
        if bn and self.bn:
            ops.append(self._make_bn(outd))
        if relu:
            ops.append(nn.ReLU(inplace=True))
        self.curchan = outd

        if k_pool > 1:
            if pool_type == "avg":
                ops.append(torch.nn.AvgPool2d(kernel_size=k_pool))
            elif pool_type == "max":
                ops.append(torch.nn.MaxPool2d(kernel_size=k_pool))
            else:
                print(f"Error, unknown pooling type {pool_type}...")

        return nn.Sequential(*ops)

    def forward(self, input):
        """Compute keypoints, scores, descriptors for image

        Args:
            input: dict whose "image" entry is an ``[N, C, H, W]`` tensor.
                The ``squeeze(0)`` calls below assume ``N == 1``.

        Returns:
            dict with:
              "keypoints": ``[K, 2]`` positions in (x, y) = (col, row) order;
              "scores": ``[K]`` tensor of scores in descending order;
              "descriptors": ``[128, K]`` tensor of L2-normalised descriptors.
            ``K`` is at most ``config["kpt_n"]``.
        """
        data = input["image"]
        H, W = data.shape[2:]

        if self.config["input_type"] == "rgb":
            # 3-channel rgb
            RGB_mean = [0.485, 0.456, 0.406]
            RGB_std = [0.229, 0.224, 0.225]
            norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
            data = norm_RGB(data)

        elif self.config["input_type"] == "gray":
            # 1-channel; normalised with the mean/std of the whole tensor
            data = torch.mean(data, dim=1, keepdim=True)
            norm_gray0 = tvf.Normalize(mean=data.mean(), std=data.std())
            data = norm_gray0(data)

        elif self.config["input_type"] == "raw":
            # 4-channel
            pass
        elif self.config["input_type"] == "raw-demosaic":
            # 3-channel
            pass
        else:
            raise NotImplementedError()

        # x: [N, C, H, W]
        x0 = self.conv0(data)
        x1 = self.conv1(x0)
        x1_bn = self.bn1(x1)
        x2 = self.conv2(x1_bn)
        x3 = self.conv3(x2)
        x3_bn = self.bn3(x3)
        x4 = self.conv4(x3_bn)
        x5 = self.conv5(x4)
        x6_0 = self.conv6_0(x5)
        x6_1 = self.conv6_1(x6_0)
        x6_2 = self.conv6_2(x6_1)

        # Per-level weights 1:2:3, normalised to sum to 1.
        comb_weights = torch.tensor([1.0, 2.0, 3.0], device=data.device)
        comb_weights /= torch.sum(comb_weights)
        # Despite the name, these are the *dilations* for each level; the
        # window size passed to peakiness_score is fixed at 3.
        ksize = [3, 2, 1]
        det_score_maps = []

        # Detector levels: pre-BN x1, pre-BN x3, and the descriptor map x6_2.
        for idx, xx in enumerate([x1, x3, x6_2]):
            alpha, beta = peakiness_score(
                xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx]
            )
            score_vol = alpha * beta
            # Max over channels, upsampled to the input resolution.
            det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0]
            det_score_map = F.interpolate(
                det_score_map, size=data.shape[2:], mode="bilinear", align_corners=True
            )
            det_score_map = comb_weights[idx] * det_score_map
            det_score_maps.append(det_score_map)

        det_score_map = torch.sum(torch.stack(det_score_maps, dim=0), dim=0)

        desc = x6_2
        score_map = det_score_map
        # Confidence = softmax channel 1 of clf applied to squared descriptors.
        conf = F.softmax(self.clf((desc) ** 2), dim=1)[:, 1:2]
        score_map = score_map * F.interpolate(
            conf, size=score_map.shape[2:], mode="bilinear", align_corners=True
        )

        kpt_inds, kpt_score = extract_kpts(
            score_map,
            k=self.config["kpt_n"],
            score_thld=self.config["score_thld"],
            nms_size=self.config["nms_size"],
            eof_size=self.config["eof_size"],
            edge_thld=self.config["edge_thld"],
        )

        # kpt_inds are (row, col) at input resolution; / 4 maps them onto the
        # downsampled descriptor map for bilinear sampling.
        # NOTE: the trailing comma after .numpy() makes ``descs`` a 1-tuple;
        # it is unwrapped with descs[0] further down.
        descs = (
            F.normalize(
                interpolate(kpt_inds.squeeze(0) / 4, desc.squeeze(0).permute(1, 2, 0)),
                p=2,
                dim=-1,
            )
            .detach()
            .cpu()
            .numpy(),
        )
        # Swap (row, col) to (x, y). H and W were read from data before
        # normalisation, which only changes channels, so this scale is [1, 1].
        kpts = np.squeeze(
            torch.stack([kpt_inds[:, :, 1], kpt_inds[:, :, 0]], dim=-1).cpu(), axis=0
        ) * np.array([W / data.shape[3], H / data.shape[2]], dtype=np.float32)
        scores = np.squeeze(kpt_score.detach().cpu().numpy(), axis=0)

        # Order everything by descending score and cap at kpt_n.
        idxs = np.negative(scores).argsort()[0 : self.config["kpt_n"]]
        descs = descs[0][idxs]
        kpts = kpts[idxs]
        scores = scores[idxs]

        return {
            "keypoints": kpts,
            "scores": torch.from_numpy(scores),
            "descriptors": torch.from_numpy(descs.T),
        }