#! /usr/bin/env python3
#
# %BANNER_BEGIN%
# ---------------------------------------------------------------------
# %COPYRIGHT_BEGIN%
#
#  Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
#
#  Unpublished Copyright (c) 2020
#  Magic Leap, Inc., All Rights Reserved.
#
# NOTICE:  All information contained herein is, and remains the property
# of COMPANY. The intellectual and technical concepts contained herein
# are proprietary to COMPANY and may be covered by U.S. and Foreign
# Patents, patents in process, and are protected by trade secret or
# copyright law.  Dissemination of this information or reproduction of
# this material is strictly forbidden unless prior written permission is
# obtained from COMPANY.  Access to the source code contained herein is
# hereby forbidden to anyone except current COMPANY employees, managers
# or contractors who have executed Confidentiality and Non-disclosure
# agreements explicitly covering such access.
#
# The copyright notice above does not evidence any actual or intended
# publication or disclosure  of  this source code, which includes
# information that is confidential and/or proprietary, and is a trade
# secret, of  COMPANY.   ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
# PUBLIC  PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE  OF THIS
# SOURCE CODE  WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
# INTERNATIONAL TREATIES.  THE RECEIPT OR POSSESSION OF  THIS SOURCE
# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
# USE, OR SELL ANYTHING THAT IT  MAY DESCRIBE, IN WHOLE OR IN PART.
#
# %COPYRIGHT_END%
# ----------------------------------------------------------------------
# %AUTHORS_BEGIN%
#
#  Originating Authors: Paul-Edouard Sarlin
#                       Daniel DeTone
#                       Tomasz Malisiewicz
#
# %AUTHORS_END%
# --------------------------------------------------------------------*/
# %BANNER_END%

from pathlib import Path
import argparse
import cv2
import matplotlib.cm as cm
import torch

from models.matching import Matching
from models.utils import (
    AverageTimer,
    VideoStreamer,
    make_matching_plot_fast,
    frame2tensor,
)

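# The demo only runs inference, so disable gradient tracking globally.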
torch.set_grad_enabled(False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="SuperGlue demo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input",
        type=str,
        default="0",
        help="ID of a USB webcam, URL of an IP camera, "
        "or path to an image directory or movie file",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory where to write output frames (If None, no output)",
    )

    parser.add_argument(
        "--image_glob",
        type=str,
        nargs="+",
        default=["*.png", "*.jpg", "*.jpeg"],
        help="Glob if a directory of images is specified",
    )
    parser.add_argument(
        "--skip",
        type=int,
        default=1,
        help="Images to skip if input is a movie or directory",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=1000000,
        help="Maximum length if input is a movie or directory",
    )
    parser.add_argument(
        "--resize",
        type=int,
        nargs="+",
        default=[640, 480],
        help="Resize the input image before running inference. If two numbers, "
        "resize to the exact dimensions, if one number, resize the max "
        "dimension, if -1, do not resize",
    )

    parser.add_argument(
        "--superglue",
        choices={"indoor", "outdoor"},
        default="indoor",
        help="SuperGlue weights",
    )
    parser.add_argument(
        "--max_keypoints",
        type=int,
        default=-1,
        help="Maximum number of keypoints detected by Superpoint"
        " ('-1' keeps all keypoints)",
    )
    parser.add_argument(
        "--keypoint_threshold",
        type=float,
        default=0.005,
        help="SuperPoint keypoint detector confidence threshold",
    )
    parser.add_argument(
        "--nms_radius",
        type=int,
        default=4,
        help="SuperPoint Non Maximum Suppression (NMS) radius" " (Must be positive)",
    )
    parser.add_argument(
        "--sinkhorn_iterations",
        type=int,
        default=20,
        help="Number of Sinkhorn iterations performed by SuperGlue",
    )
    parser.add_argument(
        "--match_threshold", type=float, default=0.2, help="SuperGlue match threshold"
    )

    parser.add_argument(
        "--show_keypoints", action="store_true", help="Show the detected keypoints"
    )
    parser.add_argument(
        "--no_display",
        action="store_true",
        help="Do not display images to screen. Useful if running remotely",
    )
    parser.add_argument(
        "--force_cpu", action="store_true", help="Force pytorch to run in CPU mode."
    )

    opt = parser.parse_args()
    print(opt)

    if len(opt.resize) == 2 and opt.resize[1] == -1:
        opt.resize = opt.resize[0:1]
    if len(opt.resize) == 2:
        print("Will resize to {}x{} (WxH)".format(opt.resize[0], opt.resize[1]))
    elif len(opt.resize) == 1 and opt.resize[0] > 0:
        print("Will resize max dimension to {}".format(opt.resize[0]))
    elif len(opt.resize) == 1:
        print("Will not resize images")
    else:
        raise ValueError("Cannot specify more than two integers for --resize")

    device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu"
    print('Running inference on device "{}"'.format(device))
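    # Group the SuperPoint and SuperGlue settings into the nested config consumed by Matching.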
    config = {
        "superpoint": {
            "nms_radius": opt.nms_radius,
            "keypoint_threshold": opt.keypoint_threshold,
            "max_keypoints": opt.max_keypoints,
        },
        "superglue": {
            "weights": opt.superglue,
            "sinkhorn_iterations": opt.sinkhorn_iterations,
            "match_threshold": opt.match_threshold,
        },
    }
    matching = Matching(config).eval().to(device)
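    # SuperPoint outputs that are kept for the anchor frame (re-keyed with a '0' suffix below).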
    keys = ["keypoints", "scores", "descriptors"]

    vs = VideoStreamer(opt.input, opt.resize, opt.skip, opt.image_glob, opt.max_length)
    frame, ret = vs.next_frame()
    assert ret, "Error when reading the first frame (try different --input?)"

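    # Extract SuperPoint features for the first frame; they serve as the anchor
    # ('image0') side of every match until a new anchor is chosen with 'n'.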
    frame_tensor = frame2tensor(frame, device)
    last_data = matching.superpoint({"image": frame_tensor})
    last_data = {k + "0": last_data[k] for k in keys}
    last_data["image0"] = frame_tensor
    last_frame = frame
    last_image_id = 0

    if opt.output_dir is not None:
        print("==> Will write outputs to {}".format(opt.output_dir))
        Path(opt.output_dir).mkdir(exist_ok=True)

    # Create a window to display the demo.
    if not opt.no_display:
        cv2.namedWindow("SuperGlue matches", cv2.WINDOW_NORMAL)
        cv2.resizeWindow("SuperGlue matches", 640 * 2, 480)
    else:
        print("Skipping visualization, will not show a GUI.")

    # Print the keyboard help menu.
    print(
        "==> Keyboard control:\n"
        "\tn: select the current frame as the anchor\n"
        "\te/r: increase/decrease the keypoint confidence threshold\n"
        "\td/f: increase/decrease the match filtering threshold\n"
        "\tk: toggle the visualization of keypoints\n"
        "\tq: quit"
    )

    timer = AverageTimer()

    while True:
        frame, ret = vs.next_frame()
        if not ret:
            print("Finished demo_superglue.py")
            break
        timer.update("data")
        stem0, stem1 = last_image_id, vs.i - 1

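        # Match the anchor features against the current frame. SuperPoint only
        # runs on the new image; the anchor's features are reused from last_data.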
        frame_tensor = frame2tensor(frame, device)
        pred = matching({**last_data, "image1": frame_tensor})
        kpts0 = last_data["keypoints0"][0].cpu().numpy()
        kpts1 = pred["keypoints1"][0].cpu().numpy()
        matches = pred["matches0"][0].cpu().numpy()
        confidence = pred["matching_scores0"][0].cpu().numpy()
        timer.update("forward")

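        # matches[i] is the index of the keypoint in the current frame matched to
        # anchor keypoint i, or -1 if unmatched. Keep valid matches and color them
        # by their matching confidence.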
        valid = matches > -1
        mkpts0 = kpts0[valid]
        mkpts1 = kpts1[matches[valid]]
        color = cm.jet(confidence[valid])
        text = [
            "SuperGlue",
            "Keypoints: {}:{}".format(len(kpts0), len(kpts1)),
            "Matches: {}".format(len(mkpts0)),
        ]
        k_thresh = matching.superpoint.config["keypoint_threshold"]
        m_thresh = matching.superglue.config["match_threshold"]
        small_text = [
            "Keypoint Threshold: {:.4f}".format(k_thresh),
            "Match Threshold: {:.2f}".format(m_thresh),
            "Image Pair: {:06}:{:06}".format(stem0, stem1),
        ]
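        # Render the anchor and current frames side by side with the matches drawn on top.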
        out = make_matching_plot_fast(
            last_frame,
            frame,
            kpts0,
            kpts1,
            mkpts0,
            mkpts1,
            color,
            text,
            path=None,
            show_keypoints=opt.show_keypoints,
            small_text=small_text,
        )

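        # Show the visualization and handle the keyboard controls listed above.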
        if not opt.no_display:
            cv2.imshow("SuperGlue matches", out)
            key = chr(cv2.waitKey(1) & 0xFF)
            if key == "q":
                vs.cleanup()
                print("Exiting (via q) demo_superglue.py")
                break
            elif key == "n":  # set the current frame as anchor
                last_data = {k + "0": pred[k + "1"] for k in keys}
                last_data["image0"] = frame_tensor
                last_frame = frame
                last_image_id = vs.i - 1
            elif key in ["e", "r"]:
                # Scale the keypoint threshold by 10% per keypress ('e' decreases, 'r' increases).
                d = 0.1 * (-1 if key == "e" else 1)
                matching.superpoint.config["keypoint_threshold"] = min(
                    max(
                        0.0001,
                        matching.superpoint.config["keypoint_threshold"] * (1 + d),
                    ),
                    1,
                )
                print(
                    "\nChanged the keypoint threshold to {:.4f}".format(
                        matching.superpoint.config["keypoint_threshold"]
                    )
                )
            elif key in ["d", "f"]:
                # Shift the match threshold by 0.05 per keypress ('d' decreases, 'f' increases).
                d = 0.05 * (-1 if key == "d" else 1)
                matching.superglue.config["match_threshold"] = min(
                    max(0.05, matching.superglue.config["match_threshold"] + d), 0.95
                )
                print(
                    "\nChanged the match threshold to {:.2f}".format(
                        matching.superglue.config["match_threshold"]
                    )
                )
            elif key == "k":
                opt.show_keypoints = not opt.show_keypoints

        timer.update("viz")
        timer.print()

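        # Optionally write the visualization for this image pair to disk.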
        if opt.output_dir is not None:
            stem = "matches_{:06}_{:06}".format(stem0, stem1)
            out_file = str(Path(opt.output_dir, stem + ".png"))
            print("\nWriting image to {}".format(out_file))
            cv2.imwrite(out_file, out)

    cv2.destroyAllWindows()
    vs.cleanup()