import kornia
from kornia.feature.laf import (
    laf_from_center_scale_ori,
    extract_patches_from_pyramid,
)
import numpy as np
import torch
import pycolmap

from ..utils.base_model import BaseModel


# Small numerical tolerance used to avoid division by zero / sqrt of zero and
# to tolerate rounding in input range checks.
EPS = 1e-6


def sift_to_rootsift(x):
    """Convert SIFT descriptors to RootSIFT along the last axis.

    Every descriptor (last axis) is L1-normalized, element-wise square-rooted
    after clipping its entries to at least ``EPS``, and then L2-normalized.
    ``EPS`` is added to both norms so all-zero descriptors do not divide by
    zero.

    Args:
        x: array of descriptors, one descriptor per row of the last axis.

    Returns:
        Array of the same shape holding the RootSIFT descriptors.
    """
    l1_norm = np.linalg.norm(x, ord=1, axis=-1, keepdims=True)
    rooted = np.sqrt(np.clip(x / (l1_norm + EPS), EPS, None))
    l2_norm = np.linalg.norm(rooted, axis=-1, keepdims=True)
    return rooted / (l2_norm + EPS)


class DoG(BaseModel):
    """DoG (SIFT) keypoints from pycolmap with SIFT/RootSIFT or learned descriptors.

    Keypoints are detected with ``pycolmap.Sift``. Descriptors are either the
    ones returned by pycolmap (``"sift"`` / ``"rootsift"``) or computed by a
    kornia patch descriptor (``"sosnet"`` / ``"hardnet"``) on patches cropped
    around each keypoint.
    """

    default_conf = {
        # Extra fields for pycolmap.SiftExtractionOptions; "normalization" is
        # filled in automatically from "descriptor".
        "options": {
            "first_octave": 0,
            "peak_threshold": 0.01,
        },
        # One of "sift", "rootsift", "sosnet", "hardnet".
        "descriptor": "rootsift",
        # Maximum number of keypoints to keep; -1 (or any negative value, or
        # None) keeps all of them.
        "max_keypoints": -1,
        # Patch size (PS) passed to kornia's extract_patches_from_pyramid.
        "patch_size": 32,
        # Measurement-region multiplier: the LAF scale is scale * mr_size / 2.
        "mr_size": 12,
    }
    required_inputs = ["image"]
    detection_noise = 1.0
    # Maximum number of patches passed to the learned descriptor at once.
    max_batch_size = 1024

    def _init(self, conf):
        """Create the optional learned descriptor and placeholder state.

        Raises:
            ValueError: if ``conf["descriptor"]`` is not supported.
        """
        if conf["descriptor"] == "sosnet":
            self.describe = kornia.feature.SOSNet(pretrained=True)
        elif conf["descriptor"] == "hardnet":
            self.describe = kornia.feature.HardNet(pretrained=True)
        elif conf["descriptor"] not in ["sift", "rootsift"]:
            raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')

        # Lazily instantiated on the first image. It is built once, so moving
        # the module after the first forward pass does not move it.
        self.sift = None
        # Empty parameter whose .device follows the module; it decides whether
        # the pycolmap extractor is created on CUDA.
        self.dummy_param = torch.nn.Parameter(torch.empty(0))
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        """Move the module and remember the target device in ``self.device``.

        The device is taken from the ``device`` keyword or, failing that, from
        the first positional ``torch.device`` / ``str`` argument. Only calls
        routed through this method update ``self.device``.
        """
        device = kwargs.get("device")
        if device is None:
            match = [a for a in args if isinstance(a, (torch.device, str))]
            if len(match) > 0:
                device = match[0]
        if device is not None:
            self.device = torch.device(device)
        return super().to(*args, **kwargs)

    def _build_sift(self):
        """Create the pycolmap SIFT extractor on CUDA if available and requested."""
        device = self.dummy_param.device
        use_gpu = pycolmap.has_cuda and device.type == "cuda"
        options = {**self.conf["options"]}
        if self.conf["descriptor"] == "rootsift":
            options["normalization"] = pycolmap.Normalization.L1_ROOT
        else:
            options["normalization"] = pycolmap.Normalization.L2
        return pycolmap.Sift(
            options=pycolmap.SiftExtractionOptions(options),
            device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"),
        )

    def _select_keypoints(self, scales):
        """Return indices of the keypoints to keep, or None to keep all of them.

        Args:
            scales: numpy array with the SIFT scale of each keypoint.
        """
        max_keypoints = self.conf["max_keypoints"]
        if max_keypoints is None or max_keypoints < 0:
            return None
        # pycolmap's Sift.extract returns no detection scores (they are all
        # zero below), so ranking by score would keep an arbitrary subset.
        # Keep the largest-scale keypoints instead, like COLMAP's own
        # max_num_features option; the stable sort makes ties deterministic.
        order = np.argsort(-scales, kind="stable")
        return order[:max_keypoints]

    def _describe_patches(self, image, keypoints, scales, oris):
        """Describe keypoints with the kornia patch descriptor (SOSNet/HardNet).

        Args:
            image: the input image tensor (same one given to ``_forward``).
            keypoints: numpy array whose first two columns are x, y.
            scales: numpy array of SIFT scales.
            oris: numpy array of orientations in degrees.

        Returns:
            Tensor of shape (N, 128) on ``image.device``.
        """
        if len(keypoints) == 0:
            # Nothing to crop; skip the patch-extraction call entirely.
            return image.new_zeros((0, 128))
        # NOTE(review): the +0.5 presumably converts pycolmap's keypoint
        # convention to kornia's pixel-center convention -- confirm.
        center = keypoints[:, :2] + 0.5
        laf_scale = scales * self.conf["mr_size"] / 2
        laf_ori = -oris
        lafs = laf_from_center_scale_ori(
            torch.from_numpy(center)[None],
            torch.from_numpy(laf_scale)[None, :, None, None],
            torch.from_numpy(laf_ori)[None, :, None],
        ).to(image.device)
        patches = extract_patches_from_pyramid(
            image, lafs, PS=self.conf["patch_size"]
        )[0]
        descriptors = patches.new_zeros((len(patches), 128))
        # Describe in chunks to bound the memory used per network call.
        for start in range(0, len(patches), self.max_batch_size):
            end = min(len(patches), start + self.max_batch_size)
            descriptors[start:end] = self.describe(patches[start:end])
        return descriptors

    def _forward(self, data):
        """Detect DoG keypoints with pycolmap and describe them.

        Only the first image of the batch is processed. It must have a single
        channel and values in [0, 1] (up to ``EPS``). Presumably the image is
        (B, 1, H, W); confirm with the caller.

        Returns:
            Dict of tensors, each with a leading batch dimension of 1:
            ``keypoints`` (N x 2, x and y), ``scales`` (N),
            ``oris`` (N, degrees), ``scores`` (N, all zeros) and
            ``descriptors`` (transposed, D x N).

        Raises:
            ValueError: if the configured descriptor is not supported.
        """
        image = data["image"]
        assert image.shape[1] == 1
        image_np = image.cpu().numpy()[0, 0]
        assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS

        if self.sift is None:
            self.sift = self._build_sift()
        # keypoints columns: x, y, scale, orientation (radians).
        keypoints, descriptors = self.sift.extract(image_np)

        # Truncate before describing so the learned descriptor only runs on
        # the keypoints that are actually returned.
        keep = self._select_keypoints(keypoints[:, 2])
        if keep is not None:
            keypoints = keypoints[keep]
            descriptors = descriptors[keep]
        scales = keypoints[:, 2]
        oris = np.rad2deg(keypoints[:, 3])

        descriptor = self.conf["descriptor"]
        if descriptor in ("sift", "rootsift"):
            # We still renormalize because COLMAP does not normalize well,
            # maybe due to numerical errors.
            # NOTE(review): for "rootsift" pycolmap is already configured with
            # L1_ROOT normalization (see _build_sift), so this applies the
            # RootSIFT mapping on top of it -- confirm this is intended.
            if descriptor == "rootsift":
                descriptors = sift_to_rootsift(descriptors)
            descriptors = torch.from_numpy(descriptors)
        elif descriptor in ("sosnet", "hardnet"):
            descriptors = self._describe_patches(image, keypoints, scales, oris)
        else:
            raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}')

        keypoints = torch.from_numpy(keypoints[:, :2])  # keep only x, y
        scales = torch.from_numpy(scales)
        oris = torch.from_numpy(oris)
        # TODO: Sift.extract currently returns no scores; check that the
        # scores from PyCOLMAP are 100% correct once exposed, follow
        # https://github.com/mihaidusmanu/pycolmap/issues/8
        scores = keypoints.new_zeros(len(keypoints))

        return {
            "keypoints": keypoints[None],
            "scales": scales[None],
            "oris": oris[None],
            "scores": scores[None],
            "descriptors": descriptors.T[None],
        }