import copy
import os
import cv2
import glob
import logging
import argparse
import numpy as np
from tqdm import tqdm
from alike import ALike, configs


class ImageLoader(object):
    """Indexable frame source backed by a camera, a video file or an image directory.

    ``filepath`` selects the source:

    * ``"cameraN"`` (N a decimal integer) -- OpenCV capture device N;
    * path to a regular file -- opened as a video with ``cv2.VideoCapture``;
    * path to a directory -- every ``*.png``, ``*.jpg`` and ``*.ppm`` file in it,
      sorted by path.

    Frames are returned as read by OpenCV (``VideoCapture.read`` / ``cv2.imread``).
    Past the end, camera/video sources return ``None`` while an image directory
    raises ``IndexError``. The latter ends Python's sequence-iteration protocol,
    so the loader can be iterated directly.

    Raises:
        IOError: if the camera/video cannot be opened or ``filepath`` is invalid.
    """

    def __init__(self, filepath: str):
        # Frame budget for a live camera, which has no natural end.
        self.N = 3000
        camera_id = filepath[len("camera"):]
        # Only treat "camera<digits>" as a device so that paths such as
        # "camera_clip.mp4" are still opened as regular files.
        if filepath.startswith("camera") and camera_id.isdecimal():
            camera = int(camera_id)
            self.cap = cv2.VideoCapture(camera)
            if not self.cap.isOpened():
                raise IOError(f"Can't open camera {camera}!")
            logging.info(f"Opened camera {camera}")
            self.mode = "camera"
        elif os.path.isfile(filepath):
            self.cap = cv2.VideoCapture(filepath)
            if not self.cap.isOpened():
                raise IOError(f"Can't open video {filepath}!")
            rate = self.cap.get(cv2.CAP_PROP_FPS)
            # Valid frame indices are 0..N inclusive (see __getitem__).
            self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
            # Some containers report an FPS of 0; don't divide by it.
            duration = self.N / rate if rate > 0 else float("nan")
            logging.info(f"Opened video {filepath}")
            logging.info(f"Frames: {self.N}, FPS: {rate}, Duration: {duration}s")
            self.mode = "video"
        elif os.path.isdir(filepath):
            self.images = (
                glob.glob(os.path.join(filepath, "*.png"))
                + glob.glob(os.path.join(filepath, "*.jpg"))
                + glob.glob(os.path.join(filepath, "*.ppm"))
            )
            self.images.sort()
            self.N = len(self.images)
            logging.info(f"Loading {self.N} images")
            self.mode = "images"
        else:
            # A single formatted message: passing two args to IOError would be
            # interpreted as (errno, strerror) and garble the text.
            raise IOError(
                f"Error filepath (camerax/path of images/path of videos): {filepath}"
            )

    def __getitem__(self, item):
        """Return frame ``item``.

        Returns ``None`` for camera/video once ``item`` exceeds ``N``. For an
        image directory, indexing past the last file raises ``IndexError``.

        Raises:
            IOError: if a camera/video frame cannot be read or an image file
                cannot be decoded.
        """
        if self.mode == "camera" or self.mode == "video":
            if item > self.N:
                return None
            # Position the decoder *before* reading so that item k yields frame k.
            # During sequential access it is already in place, so no seek is done.
            if (
                self.mode == "video"
                and int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) != item
            ):
                self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
            ret, img = self.cap.read()
            if not ret:
                raise IOError(f"Can't read image from {self.mode}")
            return img

        # mode == "images"
        filename = self.images[item]
        img = cv2.imread(filename)
        if img is None:
            raise IOError("Error reading image %s" % filename)
        return img

    def __len__(self):
        """Return ``N``: the frame budget (camera), last frame index (video) or image count."""
        return self.N


class SimpleTracker(object):
    """Match keypoints between consecutive frames and draw the resulting tracks.

    The previous frame's keypoints and descriptors are kept between calls to
    :meth:`update`. Matching is mutual nearest neighbour on descriptor dot-product
    similarity (see :meth:`mnn_mather`).
    """

    def __init__(self):
        self.pts_prev = None
        self.desc_prev = None

    @staticmethod
    def _to_pixel(pt):
        """Round an ``(x, y)`` coordinate to an integer pixel tuple for OpenCV drawing."""
        return (int(round(pt[0])), int(round(pt[1])))

    def update(self, img, pts, desc):
        """Draw matches between the stored frame and this one, then store this one.

        Args:
            img: image to draw on; a copy is annotated, ``img`` is left untouched.
            pts: keypoint array, one ``(x, y)`` row per keypoint.
            desc: descriptor array aligned row-for-row with ``pts``.

        Returns:
            ``(out, N_matches)``: the annotated copy of ``img`` and the number of
            matches against the previous frame (0 on the first call).
        """
        out = img.copy()
        N_matches = 0
        if self.pts_prev is None:
            # First frame: nothing to match against, just mark the keypoints.
            for pt in pts:
                # lineType=16 is cv2.LINE_AA (anti-aliased).
                cv2.circle(out, self._to_pixel(pt), 1, (0, 0, 255), -1, lineType=16)
        else:
            matches = self.mnn_mather(self.desc_prev, desc)
            N_matches = len(matches)
            mpts1 = self.pts_prev[matches[:, 0]]
            mpts2 = pts[matches[:, 1]]
            for pt1, pt2 in zip(mpts1, mpts2):
                p1, p2 = self._to_pixel(pt1), self._to_pixel(pt2)
                cv2.line(out, p1, p2, (0, 255, 0), lineType=16)
                cv2.circle(out, p2, 1, (0, 0, 255), -1, lineType=16)

        self.pts_prev = pts
        self.desc_prev = desc
        return out, N_matches

    def mnn_mather(self, desc1, desc2, threshold=0.9):
        """Mutual-nearest-neighbour matching on dot-product similarity.

        A pair ``(i, j)`` is returned when ``desc2[j]`` is the most similar row
        to ``desc1[i]`` and vice versa, and their similarity is at least
        ``threshold``.

        Args:
            desc1: ``(N1, D)`` descriptor array.
            desc2: ``(N2, D)`` descriptor array.
            threshold: minimum similarity for a pair to count as a match.

        Returns:
            Integer array of shape ``(M, 2)``; each row is ``(i, j)``. Empty
            (``(0, 2)``) when either input is empty or nothing matches.
        """
        desc1 = np.asarray(desc1)
        desc2 = np.asarray(desc2)
        if len(desc1) == 0 or len(desc2) == 0:
            # argmax over an empty axis raises, so short-circuit.
            return np.empty((0, 2), dtype=np.intp)

        sim = desc1 @ desc2.transpose()
        # Mask sub-threshold pairs with -inf so they can never win argmax and
        # can be told apart from real matches below.
        sim = np.where(sim >= threshold, sim, -np.inf)
        nn12 = np.argmax(sim, axis=1)
        nn21 = np.argmax(sim, axis=0)
        ids1 = np.arange(sim.shape[0])
        # Require mutual agreement AND a genuine above-threshold similarity:
        # rows with no candidate default to argmax 0 and must not match.
        mask = (ids1 == nn21[nn12]) & np.isfinite(sim[ids1, nn12])
        return np.stack([ids1[mask], nn12[mask]], axis=1)


def _parse_args():
    """Build the demo's command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="ALike Demo.")
    parser.add_argument(
        "input",
        type=str,
        help='Image directory or movie file or "camera0" (for webcam0).',
    )
    parser.add_argument(
        "--model",
        choices=["alike-t", "alike-s", "alike-n", "alike-l"],
        default="alike-t",
        help="The model configuration",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Running device (default: cuda)."
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=-1,
        help="Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)",
    )
    parser.add_argument(
        "--scores_th",
        type=float,
        default=0.2,
        help="Detector score threshold (default: 0.2).",
    )
    parser.add_argument(
        "--n_limit",
        type=int,
        default=5000,
        help="Maximum number of keypoints to be detected (default: 5000).",
    )
    parser.add_argument(
        "--no_display",
        action="store_true",
        help="Do not display images to screen. Useful if running remotely (default: False).",
    )
    parser.add_argument(
        "--no_sub_pixel",
        action="store_true",
        help="Do not detect sub-pixel keypoints (default: False).",
    )
    return parser.parse_args()


def main():
    """Run ALike on every input frame and show frame-to-frame keypoint matches."""
    args = _parse_args()

    logging.basicConfig(level=logging.INFO)

    image_loader = ImageLoader(args.input)
    model = ALike(
        **configs[args.model],
        device=args.device,
        top_k=args.top_k,
        scores_th=args.scores_th,
        n_limit=args.n_limit,
    )
    tracker = SimpleTracker()

    if not args.no_display:
        logging.info("Press 'q' to stop!")
        cv2.namedWindow(args.model)

    runtime = []
    progress_bar = tqdm(image_loader)
    try:
        for img in progress_bar:
            # Camera/video sources signal end-of-stream with None.
            if img is None:
                break

            # The model is fed RGB; OpenCV delivers BGR.
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
            kpts = pred["keypoints"]
            desc = pred["descriptors"]
            runtime.append(pred["time"])

            out, N_matches = tracker.update(img, kpts, desc)

            # Mean of per-frame FPS values over all frames so far.
            ave_fps = (1.0 / np.stack(runtime)).mean()
            status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
            progress_bar.set_description(status)

            if not args.no_display:
                cv2.setWindowTitle(args.model, args.model + ": " + status)
                cv2.imshow(args.model, out)
                if cv2.waitKey(1) == ord("q"):
                    break
    finally:
        # Release the capture device/file even if the loop raised or was aborted.
        progress_bar.close()
        cap = getattr(image_loader, "cap", None)
        if cap is not None:
            cap.release()

    logging.info("Finished!")
    if not args.no_display:
        logging.info("Press any key to exit!")
        cv2.waitKey()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()