import pdb
import os
import sys
import tqdm

import numpy as np
import torch

from PIL import Image
from matplotlib import pyplot as pl

# Interactive mode: figures are drawn/updated immediately and pyplot.show() does not block.
pl.ion()
from scipy.ndimage import uniform_filter

def smooth(arr, size=3):
    """Return `arr` box-filtered with scipy's `uniform_filter`.

    Args:
        arr: array to smooth (e.g. the repeatability map before display).
        size: width of the uniform filter window (default 3, as before).

    Returns:
        The filtered array, same shape as `arr`.
    """
    return uniform_filter(arr, size)


def transparent(img, alpha, cmap, **kw):
    """Colorize a 2-D array with `cmap` and overwrite its alpha channel.

    Args:
        img: 2-D array of values to colorize.
        alpha: value written into the alpha (last) channel of every pixel.
        cmap: colormap callable mapping normalized values to RGBA.
        **kw: forwarded to matplotlib's ``Normalize`` (e.g. ``vmin``/``vmax``);
            the normalizer clips, so out-of-range values saturate.

    Returns:
        The RGBA array produced by `cmap`, with its last channel set to `alpha`.
    """
    from matplotlib.colors import Normalize

    normalizer = Normalize(clip=True, **kw)
    rgba = cmap(normalizer(img))
    rgba[:, :, -1] = alpha
    return rgba


from tools import common
from tools.dataloader import norm_RGB
from nets.patchnet import *
from extract import NonMaxSuppression


if __name__ == "__main__":
    # Visualize detected keypoints plus the repeatability / reliability maps
    # predicted by a trained network, save the figure, then stop in pdb.
    import argparse

    parser = argparse.ArgumentParser("Visualize the patch detector and descriptor")

    parser.add_argument("--img", type=str, default="imgs/brooklyn.png")
    parser.add_argument("--resize", type=int, default=512)
    parser.add_argument("--out", type=str, default="viz.png")

    parser.add_argument("--checkpoint", type=str, required=True, help="network path")
    parser.add_argument("--net", type=str, default="", help="network command")

    parser.add_argument("--max-kpts", type=int, default=200)
    parser.add_argument("--reliability-thr", type=float, default=0.8)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)
    parser.add_argument(
        "--border", type=int, default=20, help="rm keypoints close to border"
    )

    parser.add_argument("--gpu", type=int, nargs="+", required=True, help="-1 for CPU")
    parser.add_argument("--dbg", type=str, nargs="+", default=(), help="debug options")

    args = parser.parse_args()
    args.dbg = set(args.dbg)

    iscuda = common.torch_set_gpu(args.gpu)
    device = torch.device("cuda" if iscuda else "cpu")

    # --- create network ---
    # map_location returns each storage unchanged, i.e. tensors are loaded on CPU.
    checkpoint = torch.load(args.checkpoint, lambda a, b: a)
    args.net = args.net or checkpoint["net"]
    print("\n>> Creating net = " + args.net)
    # SECURITY NOTE(review): `eval` executes arbitrary Python taken from --net or
    # from the checkpoint file; only load trusted checkpoints.
    net = eval(args.net)
    # Strip the "module." prefix (presumably added by nn.DataParallel at save time).
    net.load_state_dict(
        {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
    )
    # Inference only: switch dropout / batch-norm layers to evaluation behavior.
    net = net.eval()
    if iscuda:
        net = net.cuda()
    print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")

    # --- load image ---
    img = Image.open(args.img).convert("RGB")
    if args.resize:
        img.thumbnail((args.resize, args.resize))  # in place, keeps aspect ratio
    img = np.asarray(img)

    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
    )

    # --- run network and detector ---
    with torch.no_grad():
        print(">> computing features...")
        res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)])
        rela = res.get("reliability")
        repe = res.get("repeatability")
        # After transposing and swapping columns, kpts[:, 0] indexes image
        # columns (x) and kpts[:, 1] image rows (y), as the lookup below uses.
        kpts = detector(**res).T[:, [1, 0]]
        # Keep the `max_kpts` keypoints with the highest repeatability score
        # (argsort is ascending). NB: --max-kpts 0 keeps all of them, because
        # [-0:] is the full range.
        scores = repe[0][0, 0][kpts[:, 1], kpts[:, 0]]
        kpts = kpts[scores.argsort()[-args.max_kpts :]]

    # Keypoint coordinates shifted into the cropped frame, and the reliability map
    # as a numpy array (both were previously recomputed in each plotting branch).
    x, y = kpts[:, 0:2].cpu().numpy().T - args.border
    rela = rela[0][0, 0].cpu().numpy()

    # --- plotting ---
    pl.figure("viz")
    kw = dict(cmap=pl.cm.RdYlGn, vmax=1)
    # Drop `border` pixels on every side. `or None` makes border == 0 keep the full
    # extent (`-0 or 1` would have produced slice(0, 1), i.e. a single pixel).
    crop = (slice(args.border, -args.border or None),) * 2

    def show_image(pos, **imshow_kw):
        """Select subplot `pos`, draw the cropped input image and hide its ticks."""
        pl.subplot(pos)
        pl.imshow(img[crop], cmap=pl.cm.gray, **imshow_kw)
        pl.xticks(())
        pl.yticks(())

    def plot_keypoints():
        """Mark the kept keypoints on the current axes."""
        # scalex/scaley=0 keep the view limits set by imshow, so keypoints that
        # fall outside the cropped frame are simply not visible.
        pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0)

    if "reliability" in args.dbg:
        # image | keypoints only | raw reliability map
        show_image(131)

        show_image(132, alpha=0)  # invisible image: only sets the axes extent
        plot_keypoints()

        pl.subplot(133)
        pl.imshow(rela[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9)
        pl.xticks(())
        pl.yticks(())

    else:
        # image + keypoints | repeatability overlay | reliability overlay
        show_image(131)
        plot_keypoints()

        show_image(132)
        repe_map = repe[0][0, 0].cpu().numpy()
        pl.imshow(transparent(smooth(repe_map)[crop], 0.5, vmin=0, **kw))

        show_image(133)
        pl.imshow(transparent(rela[crop], 0.5, vmin=0.9, **kw))

    pl.gcf().set_size_inches(9, 2.73)
    pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, hspace=0.1)
    pl.savefig(args.out)
    # Intentionally stop in the debugger so the interactive figure and all
    # variables can be inspected before the script exits.
    pdb.set_trace()