import argparse
import os
from os.path import join

import cv2
import torch
from matplotlib import pyplot as plt

from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
from .drawing import (
    plot_images,
    plot_lines,
    plot_color_line_matches,
    plot_keypoints,
    plot_matches,
)
from .models.two_view_pipeline import TwoViewPipeline


def _read_grayscale(path):
    """Load the image at *path* as a single-channel grayscale array.

    Raises:
        OSError: if OpenCV cannot read the file. ``cv2.imread`` reports
            failure by returning ``None`` rather than raising, so the check
            is done here to give a clear error message.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise OSError(
            f"Could not read image '{path}' (missing file or unsupported format)"
        )
    return image


def main():
    """Run the GlueStick demo on a pair of images.

    Parses the command-line options, runs the two-view pipeline on the two
    input images and produces four figures, each saved as a PNG in the
    current working directory:

    * ``detected_lines.png``  - all detected line segments
    * ``detected_points.png`` - all detected keypoints
    * ``line_matches.png``    - matched line segments
    * ``point_matches.png``   - matched keypoints

    Unless ``--skip-imshow`` is given, the figures are also shown on screen.

    Raises:
        OSError: if either input image cannot be read.
    """
    # Parse input parameters
    parser = argparse.ArgumentParser(
        prog="GlueStick Demo",
        description="Demo app to show the point and line matches obtained by GlueStick",
    )
    parser.add_argument("-img1", default=join("resources", "img1.jpg"))
    parser.add_argument("-img2", default=join("resources", "img2.jpg"))
    parser.add_argument("--max_pts", type=int, default=1000)
    parser.add_argument("--max_lines", type=int, default=300)
    parser.add_argument("--skip-imshow", default=False, action="store_true")
    args = parser.parse_args()

    # Read the inputs first so that a bad path fails immediately, before the
    # model is constructed.
    gray0 = _read_grayscale(args.img1)
    gray1 = _read_grayscale(args.img2)

    # Evaluation config
    conf = {
        "name": "two_view_pipeline",
        "use_lines": True,
        "extractor": {
            "name": "wireframe",
            "sp_params": {
                "force_num_keypoints": False,
                "max_num_keypoints": args.max_pts,
            },
            "wireframe_params": {
                "merge_points": True,
                "merge_line_endpoints": True,
            },
            "max_n_lines": args.max_lines,
        },
        "matcher": {
            "name": "gluestick",
            "weights": str(
                GLUESTICK_ROOT / "resources" / "weights" / "checkpoint_GlueStick_MD.tar"
            ),
            "trainable": False,
        },
        "ground_truth": {
            "from_pose_depth": False,
        },
    }

    device = "cuda" if torch.cuda.is_available() else "cpu"

    pipeline_model = TwoViewPipeline(conf).to(device).eval()

    # Convert each image to a tensor on the target device; the `[None]`
    # indexing prepends a leading dimension of size 1.
    torch_gray0 = numpy_image_to_torch(gray0).to(device)[None]
    torch_gray1 = numpy_image_to_torch(gray1).to(device)[None]
    pred = pipeline_model({"image0": torch_gray0, "image1": torch_gray1})

    # NOTE(review): batch_to_np presumably strips the batch dimension and
    # converts the outputs to NumPy arrays - confirm against its definition.
    pred = batch_to_np(pred)
    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    m0 = pred["matches0"]

    line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
    line_matches = pred["line_matches0"]

    # An entry of -1 marks an element of image 0 without a match; every other
    # entry is used as an index into the corresponding array of image 1.
    valid_points = m0 != -1
    matched_kps0 = kp0[valid_points]
    matched_kps1 = kp1[m0[valid_points]]

    valid_lines = line_matches != -1
    matched_lines0 = line_seg0[valid_lines]
    matched_lines1 = line_seg1[line_matches[valid_lines]]

    # Plot the matches
    img0 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR)
    img1 = cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)

    plot_images(
        [img0, img1],
        ["Image 1 - detected lines", "Image 2 - detected lines"],
        dpi=200,
        pad=2.0,
    )
    plot_lines([line_seg0, line_seg1], ps=4, lw=2)
    plt.gcf().canvas.manager.set_window_title("Detected Lines")
    plt.savefig("detected_lines.png")

    plot_images(
        [img0, img1],
        ["Image 1 - detected points", "Image 2 - detected points"],
        dpi=200,
        pad=2.0,
    )
    plot_keypoints([kp0, kp1], colors="c")
    plt.gcf().canvas.manager.set_window_title("Detected Points")
    plt.savefig("detected_points.png")

    plot_images(
        [img0, img1],
        ["Image 1 - line matches", "Image 2 - line matches"],
        dpi=200,
        pad=2.0,
    )
    plot_color_line_matches([matched_lines0, matched_lines1], lw=2)
    plt.gcf().canvas.manager.set_window_title("Line Matches")
    plt.savefig("line_matches.png")

    plot_images(
        [img0, img1],
        ["Image 1 - point matches", "Image 2 - point matches"],
        dpi=200,
        pad=2.0,
    )
    plot_matches(matched_kps0, matched_kps1, "green", lw=1, ps=0)
    plt.gcf().canvas.manager.set_window_title("Point Matches")
    # Use a dedicated file name: saving to "detected_points.png" here would
    # overwrite the keypoint figure written above.
    plt.savefig("point_matches.png")
    if not args.skip_imshow:
        plt.show()


# Script entry point: run the demo only when this module is executed directly,
# not when it is imported. Because the module uses package-relative imports
# (`from .drawing import ...`), it has to be executed as part of its package
# (e.g. with `python -m <package>.<module>`) rather than as a standalone file.
if __name__ == "__main__":
    main()