import os
from torch.multiprocessing import Process, Manager, set_start_method, Pool
import functools
import argparse
import yaml
import numpy as np
import sys
import cv2
from tqdm import trange

# Start every child process with "spawn"; force=True replaces any start method
# that was already configured instead of raising RuntimeError.
# NOTE(review): presumably chosen because torch/CUDA state is not fork-safe — confirm.
set_start_method("spawn", force=True)


# Repository root = parent of this file's directory. It is put first on
# sys.path so the project-local ``components`` / ``utils`` imports resolve.
_THIS_DIR = os.path.dirname(__file__)
ROOT_DIR = os.path.abspath(os.path.join(_THIS_DIR, ".."))
sys.path.insert(0, ROOT_DIR)

from components import load_component
from utils import evaluation_utils, metrics

def _build_arg_parser():
    """Create the command-line parser for this evaluation script."""
    cli = argparse.ArgumentParser(description="dump eval data.")
    # YAML file holding the "reader", "matcher" and "evaluator" sections.
    cli.add_argument(
        "--config_path", type=str, default="configs/eval/scannet_eval_sgm.yaml"
    )
    # Worker-pool sizes; each stage also uses its value as the batch size
    # handed to a single pool.map call.
    cli.add_argument("--num_process_match", type=int, default=4)
    cli.add_argument("--num_process_eval", type=int, default=4)
    # Optional directory for per-pair match visualizations (disabled if None).
    cli.add_argument("--vis_folder", type=str, default=None)
    return cli


# Parsed at import time: with the "spawn" start method, child processes
# re-import this module and re-parse the same argv, so the handlers can read
# the module-level ``args``.
parser = _build_arg_parser()
args = parser.parse_args()


def feed_match(info, matcher):
    """Run ``matcher`` on one image pair and return its correspondences.

    Args:
        info: per-pair record; reads ``"x1"``, ``"x2"``, ``"desc1"``,
            ``"desc2"`` and the image arrays ``"img1"`` / ``"img2"`` (only
            their shapes are used).
        matcher: object whose ``run(dict)`` returns a 2-sequence.

    Returns:
        ``[corr1, corr2]`` as produced by ``matcher.run``.
    """
    # shape[:2] is (rows, cols); np.flip reverses it to (cols, rows), i.e.
    # (width, height) assuming row-major (H, W[, C]) image arrays.
    test_data = dict(
        x1=info["x1"],
        x2=info["x2"],
        desc1=info["desc1"],
        desc2=info["desc2"],
        size1=np.flip(np.asarray(info["img1"].shape[:2])),
        size2=np.flip(np.asarray(info["img2"].shape[:2])),
    )
    first_corr, second_corr = matcher.run(test_data)
    return [first_corr, second_corr]


def reader_handler(config, read_que):
    """Pipeline stage 1: queue every item of the configured reader.

    Puts ``reader.run(0)`` .. ``reader.run(len(reader) - 1)`` onto
    ``read_que`` in order, followed by the ``"over"`` end marker.

    The end marker is queued even if loading or reading fails, so the matcher
    stage is not left blocking on ``read_que.get()`` forever; the exception
    still propagates (and is reported by this process).

    Args:
        config: reader section of the YAML config; ``config["name"]`` selects
            the component passed to ``load_component``.
        read_que: queue consumed by ``match_handler``.
    """
    try:
        reader = load_component("reader", config["name"], config)
        for index in range(len(reader)):
            read_que.put(reader.run(index))
    finally:
        read_que.put("over")


def _match_and_forward(pool, match_func, batch, match_que):
    """Match ``batch`` in parallel and push each item downstream, in order.

    ``pool.map`` preserves input order, so each result is zipped back onto the
    item it came from; every item gains ``"corr1"`` / ``"corr2"`` keys.
    """
    results = pool.map(match_func, batch)
    for cur_item, cur_result in zip(batch, results):
        cur_item["corr1"], cur_item["corr2"] = cur_result[0], cur_result[1]
        match_que.put(cur_item)


def match_handler(config, read_que, match_que):
    """Pipeline stage 2: match items from ``read_que`` and forward them.

    Items are buffered into batches of ``args.num_process_match`` and matched
    with a worker pool of that size. A trailing partial batch is flushed once
    the ``"over"`` end marker arrives.

    The end marker is always forwarded to ``match_que``, even if this stage
    fails, so the evaluator is not left blocking on ``get()`` forever. Any
    exception still propagates after the pool is shut down.

    Args:
        config: matcher section of the YAML config; ``config["name"]`` selects
            the component passed to ``load_component``.
        read_que: queue filled by ``reader_handler``.
        match_que: queue consumed by ``evaluate_handler``.
    """
    pool = None
    try:
        matcher = load_component("matcher", config["name"], config)
        # The matcher is bound into the task callable, so it must be picklable:
        # it is sent to the pool workers together with the tasks.
        match_func = functools.partial(feed_match, matcher=matcher)
        pool = Pool(args.num_process_match)
        cache = []
        while True:
            item = read_que.get()
            # Items are dict-like; only the end marker is a str. The isinstance
            # check keeps the comparison from touching the item's contents.
            if isinstance(item, str) and item == "over":
                break
            cache.append(item)
            if len(cache) == args.num_process_match:
                _match_and_forward(pool, match_func, cache, match_que)
                cache = []
        # Flush the trailing partial batch.
        if cache:
            _match_and_forward(pool, match_func, cache, match_que)
    finally:
        if pool is not None:
            pool.close()
            pool.join()
        match_que.put("over")


def _evaluate_and_record(pool, evaluator, batch):
    """Run ``evaluator.run`` on ``batch`` in parallel; record results in order."""
    for cur_res in pool.map(evaluator.run, batch):
        evaluator.res_inqueue(cur_res)


def _dump_visualization(item, inlier_th, vis_folder):
    """Write a match visualization for one pair to ``<vis_folder>/<index>.png``.

    Correspondences are normalized with each image's intrinsics (``"K1"`` /
    ``"K2"``). They are then checked against ``item["e"]`` with
    ``metrics.compute_epi_inlier`` at threshold ``inlier_th``, and the
    resulting mask is passed to ``evaluation_utils.draw_match``.
    """
    corr1_norm = evaluation_utils.normalize_intrinsic(item["corr1"], item["K1"])
    corr2_norm = evaluation_utils.normalize_intrinsic(item["corr2"], item["K2"])
    inlier_mask = metrics.compute_epi_inlier(
        corr1_norm, corr2_norm, item["e"], inlier_th
    )
    display = evaluation_utils.draw_match(
        item["img1"], item["img2"], item["corr1"], item["corr2"], inlier_mask
    )
    cv2.imwrite(os.path.join(vis_folder, str(item["index"]) + ".png"), display)


def evaluate_handler(config, match_que):
    """Pipeline stage 3: evaluate matched pairs from ``match_que``.

    Reads up to ``config["num_pair"]`` items, stopping early at the ``"over"``
    end marker. Items are evaluated in batches of ``args.num_process_eval``
    with a worker pool. When ``args.vis_folder`` is set, each item is also
    visualized (using ``config["inlier_th"]``). Finally ``evaluator.parse()``
    is called.

    Every received pair is evaluated, including a trailing partial batch left
    when the loop stops because ``num_pair`` items were read. If the end
    marker was not seen, the queue is then drained up to it, so upstream
    stages putting onto a bounded queue can finish instead of blocking forever.

    Args:
        config: evaluator section of the YAML config.
        match_que: queue filled by ``match_handler``.
    """
    evaluator = load_component("evaluator", config["name"], config)
    pool = Pool(args.num_process_eval)
    cache = []
    seen_end = False
    try:
        for _ in trange(config["num_pair"]):
            item = match_que.get()
            if isinstance(item, str) and item == "over":
                seen_end = True
                break
            cache.append(item)
            if len(cache) == args.num_process_eval:
                _evaluate_and_record(pool, evaluator, cache)
                cache = []
            if args.vis_folder is not None:
                _dump_visualization(item, config["inlier_th"], args.vis_folder)
        # The loop can end by exhausting num_pair without ever seeing "over";
        # the items still buffered must be evaluated in either case.
        if cache:
            _evaluate_and_record(pool, evaluator, cache)
    finally:
        pool.close()
        pool.join()
    evaluator.parse()
    if not seen_end:
        # Consume any remaining items up to the end marker so upstream puts
        # onto a full queue cannot block forever.
        while True:
            leftover = match_que.get()
            if isinstance(leftover, str) and leftover == "over":
                break


if __name__ == "__main__":
    # safe_load: PyYAML >= 6 rejects yaml.load() without an explicit Loader,
    # and safe_load never constructs arbitrary Python objects from the file.
    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)
    if args.vis_folder is not None:
        # makedirs also creates missing parent directories.
        os.makedirs(args.vis_folder, exist_ok=True)

    # A single manager server process hosts both queues, and it is shut down
    # when the block exits. Bounded queues provide back-pressure between stages.
    with Manager() as manager:
        read_que = manager.Queue(maxsize=100)
        match_que = manager.Queue(maxsize=100)

        processes = [
            Process(
                target=reader_handler,
                name="reader",
                args=(config["reader"], read_que),
            ),
            Process(
                target=match_handler,
                name="matcher",
                args=(config["matcher"], read_que, match_que),
            ),
            Process(
                target=evaluate_handler,
                name="evaluator",
                args=(config["evaluator"], match_que),
            ),
        ]
        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()
        failed = [proc.name for proc in processes if proc.exitcode != 0]

    # Surface stage failures through the exit status instead of exiting 0.
    if failed:
        sys.exit(f"pipeline stage(s) failed: {', '.join(failed)}")