# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use


import os, pdb
from PIL import Image
import numpy as np
import torch

from tools import common
from tools.dataloader import norm_RGB
from nets.patchnet import *


def load_network(model_fn):
    checkpoint = torch.load(model_fn)
    print("\n>> Creating net = " + checkpoint["net"])
    net = eval(checkpoint["net"])
    nb_of_weights = common.model_size(net)
    print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")

    # initialization
    weights = checkpoint["state_dict"]
    net.load_state_dict({k.replace("module.", ""): v for k, v in weights.items()})
    return net.eval()


class NonMaxSuppression(torch.nn.Module):
    def __init__(self, rel_thr=0.7, rep_thr=0.7):
        nn.Module.__init__(self)
        self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.rel_thr = rel_thr
        self.rep_thr = rep_thr

    def forward(self, reliability, repeatability, **kw):
        assert len(reliability) == len(repeatability) == 1
        reliability, repeatability = reliability[0], repeatability[0]

        # local maxima
        maxima = repeatability == self.max_filter(repeatability)

        # remove low peaks
        maxima *= repeatability >= self.rep_thr
        maxima *= reliability >= self.rel_thr

        return maxima.nonzero().t()[2:4]


def extract_multiscale(
    net,
    img,
    detector,
    scale_f=2**0.25,
    min_scale=0.0,
    max_scale=1,
    min_size=256,
    max_size=1024,
    verbose=False,
):
    old_bm = torch.backends.cudnn.benchmark
    torch.backends.cudnn.benchmark = False  # speedup

    # extract keypoints at multiple scales
    B, three, H, W = img.shape
    assert B == 1 and three == 3, "should be a batch with a single RGB image"

    assert max_scale <= 1
    s = 1.0  # current scale factor

    X, Y, S, C, Q, D = [], [], [], [], [], []
    while s + 0.001 >= max(min_scale, min_size / max(H, W)):
        if s - 0.001 <= min(max_scale, max_size / max(H, W)):
            nh, nw = img.shape[2:]
            if verbose:
                print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")
            # extract descriptors
            with torch.no_grad():
                res = net(imgs=[img])

            # get output and reliability map
            descriptors = res["descriptors"][0]
            reliability = res["reliability"][0]
            repeatability = res["repeatability"][0]

            # normalize the reliability for nms
            # extract maxima and descs
            y, x = detector(**res)  # nms
            c = reliability[0, 0, y, x]
            q = repeatability[0, 0, y, x]
            d = descriptors[0, :, y, x].t()
            n = d.shape[0]

            # accumulate multiple scales
            X.append(x.float() * W / nw)
            Y.append(y.float() * H / nh)
            S.append((32 / s) * torch.ones(n, dtype=torch.float32, device=d.device))
            C.append(c)
            Q.append(q)
            D.append(d)
        s /= scale_f

        # down-scale the image for next iteration
        nh, nw = round(H * s), round(W * s)
        img = F.interpolate(img, (nh, nw), mode="bilinear", align_corners=False)

    # restore value
    torch.backends.cudnn.benchmark = old_bm

    Y = torch.cat(Y)
    X = torch.cat(X)
    S = torch.cat(S)  # scale
    scores = torch.cat(C) * torch.cat(Q)  # scores = reliability * repeatability
    XYS = torch.stack([X, Y, S], dim=-1)
    D = torch.cat(D)
    return XYS, D, scores


def extract_keypoints(args):
    iscuda = common.torch_set_gpu(args.gpu)

    # load the network...
    net = load_network(args.model)
    if iscuda:
        net = net.cuda()

    # create the non-maxima detector
    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
    )

    while args.images:
        img_path = args.images.pop(0)

        if img_path.endswith(".txt"):
            args.images = open(img_path).read().splitlines() + args.images
            continue

        print(f"\nExtracting features for {img_path}")
        img = Image.open(img_path).convert("RGB")
        W, H = img.size
        img = norm_RGB(img)[None]
        if iscuda:
            img = img.cuda()

        # extract keypoints/descriptors for a single image
        xys, desc, scores = extract_multiscale(
            net,
            img,
            detector,
            scale_f=args.scale_f,
            min_scale=args.min_scale,
            max_scale=args.max_scale,
            min_size=args.min_size,
            max_size=args.max_size,
            verbose=True,
        )

        xys = xys.cpu().numpy()
        desc = desc.cpu().numpy()
        scores = scores.cpu().numpy()
        idxs = scores.argsort()[-args.top_k or None :]

        outpath = img_path + "." + args.tag
        print(f"Saving {len(idxs)} keypoints to {outpath}")
        np.savez(
            open(outpath, "wb"),
            imsize=(W, H),
            keypoints=xys[idxs],
            descriptors=desc[idxs],
            scores=scores[idxs],
        )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Extract keypoints for a given image")
    parser.add_argument("--model", type=str, required=True, help="model path")

    parser.add_argument(
        "--images", type=str, required=True, nargs="+", help="images / list"
    )
    parser.add_argument("--tag", type=str, default="r2d2", help="output file tag")

    parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints")

    parser.add_argument("--scale-f", type=float, default=2**0.25)
    parser.add_argument("--min-size", type=int, default=256)
    parser.add_argument("--max-size", type=int, default=1024)
    parser.add_argument("--min-scale", type=float, default=0)
    parser.add_argument("--max-scale", type=float, default=1)

    parser.add_argument("--reliability-thr", type=float, default=0.7)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)

    parser.add_argument(
        "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU"
    )
    args = parser.parse_args()

    extract_keypoints(args)