import os
import sys
import glob
import math
import re
import numpy as np
import h5py
from tqdm import trange
from torch.multiprocessing import Pool
from .base_dumper import BaseDumper

# Make the repository root importable so that `utils` resolves when this
# module is loaded from within the package.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
sys.path.insert(0, ROOT_DIR)

from utils import transformations, data_utils


class gl3d_train(BaseDumper):
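    """Dump ground-truth training pairs for the GL3D dataset.

    For every image pair that survives rotation-angle and overlap filtering,
    the relative pose (dR, dt), intrinsics, image sizes, correspondence
    labels, and image/feature paths are written to a per-sequence info.h5py,
    alongside a pair_num.txt recording the number of dumped pairs.
    """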
    def get_seqs(self):
        data_dir = os.path.join(self.config["rawdata_dir"], "data")
        seq_train = np.loadtxt(
            os.path.join(
                self.config["rawdata_dir"], "list", "comb", "imageset_train.txt"
            ),
            dtype=str,
        )
        seq_valid = np.loadtxt(
            os.path.join(
                self.config["rawdata_dir"], "list", "comb", "imageset_test.txt"
            ),
            dtype=str,
        )

        # filter the sequence lists, dropping excluded sequences
        self.seq_list, self.train_list, self.valid_list = [], [], []
        for seq in seq_train:
            if seq not in self.config["exclude_seq"]:
                self.train_list.append(seq)
        for seq in seq_valid:
            if seq not in self.config["exclude_seq"]:
                self.valid_list.append(seq)
        seq_list = []
        if self.config["dump_train"]:
            seq_list.append(self.train_list)
        if self.config["dump_valid"]:
            seq_list.append(self.valid_list)
        self.seq_list = np.concatenate(seq_list, axis=0)

        for seq in self.seq_list:
            dump_dir = os.path.join(self.config["feature_dump_dir"], seq)
            cur_img_seq = glob.glob(
                os.path.join(data_dir, seq, "undist_images", "*.jpg")
            )
            fea_suffix = "_{}_{}.hdf5".format(
                self.config["extractor"]["name"],
                self.config["extractor"]["num_kpt"],
            )
            cur_dump_seq = [
                os.path.join(dump_dir, os.path.basename(path) + fea_suffix)
                for path in cur_img_seq
            ]
            self.img_seq += cur_img_seq
            self.dump_seq += cur_dump_seq

    def format_dump_folder(self):
        # exist_ok guards against races when several processes create folders
        os.makedirs(self.config["feature_dump_dir"], exist_ok=True)
        for seq in self.seq_list:
            os.makedirs(
                os.path.join(self.config["feature_dump_dir"], seq), exist_ok=True
            )
        os.makedirs(self.config["dataset_dump_dir"], exist_ok=True)

    def load_geom(self, seq):
        # load geometry file
        geom_file = os.path.join(
            self.config["rawdata_dir"], "data", seq, "geolabel", "cameras.txt"
        )
        basename_list = np.loadtxt(
            os.path.join(self.config["rawdata_dir"], "data", seq, "basenames.txt"),
            dtype=str,
        )
        geom_dict = []  # despite the name, a list indexed by image order in basenames.txt
        cameras = np.loadtxt(geom_file)
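        # Assumed cameras.txt row layout (inferred from the indexing below):
        # image index, fx, fy, cx, cy, skew, T (3 values), R (9 values,
        # row-major), with the original image width/height in the last two columns.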
        camera_index = 0
        for base_index in range(len(basename_list)):
            # images without a calibrated camera entry get a None placeholder
            if camera_index >= len(cameras) or base_index < cameras[camera_index][0]:
                geom_dict.append(None)
                continue
            cur_geom = {}
            ori_img_size = [cameras[camera_index][-2], cameras[camera_index][-1]]
            scale_factor = [1000.0 / ori_img_size[0], 1000.0 / ori_img_size[1]]
            K = np.asarray(
                [
                    [
                        cameras[camera_index][1],
                        cameras[camera_index][5],
                        cameras[camera_index][3],
                    ],
                    [0, cameras[camera_index][2], cameras[camera_index][4]],
                    [0, 0, 1],
                ]
            )
            # Rescale the calibration to match the earlier resizing to 1000x1000
            S = np.asarray(
                [[scale_factor[0], 0, 0], [0, scale_factor[1], 0], [0, 0, 1]]
            )
            K = np.dot(S, K)
            cur_geom["K"] = K
            cur_geom["R"] = cameras[camera_index][9:18].reshape([3, 3])
            cur_geom["T"] = cameras[camera_index][6:9]
            cur_geom["size"] = np.asarray([1000, 1000])
            geom_dict.append(cur_geom)
            camera_index += 1
        return geom_dict

    def load_depth(self, file_path):
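        """Load a PFM depth map.

        PFM layout: a header line ("PF" for 3-channel, "Pf" for grayscale),
        a dimensions line "width height", and a scale line whose sign encodes
        endianness (negative = little-endian), followed by raw float32 rows
        stored bottom-up.
        """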
        with open(file_path, "rb") as fin:
            header = fin.readline().decode("UTF-8").rstrip()
            if header == "PF":
                color = True
            elif header == "Pf":
                color = False
            else:
                raise Exception("Not a PFM file.")
            dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8"))
            if dim_match:
                width, height = map(int, dim_match.groups())
            else:
                raise Exception("Malformed PFM header.")
            # the sign of the scale line encodes endianness
            scale = float(fin.readline().decode("UTF-8").rstrip())
            data_type = "<f" if scale < 0 else ">f"  # little- vs big-endian
            # np.fromstring is deprecated; frombuffer reads the raw floats directly
            data = np.frombuffer(fin.read(), data_type)
            shape = (height, width, 3) if color else (height, width)
            data = np.reshape(data, shape).copy()
            # PFM stores rows bottom-up, so flip vertically
            data = np.flip(data, 0)
        return data

    def dump_info(self, seq, info):
        pair_type = [
            "dR",
            "dt",
            "K1",
            "K2",
            "size1",
            "size2",
            "corr",
            "incorr1",
            "incorr2",
        ]
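        # info.h5py layout: one HDF5 group per field, holding one dataset per
        # pair keyed by the pair index as a string.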
        num_pairs = len(info["dR"])
        os.mkdir(os.path.join(self.config["dataset_dump_dir"], seq))
        with h5py.File(
            os.path.join(self.config["dataset_dump_dir"], seq, "info.h5py"), "w"
        ) as f:
            for type in pair_type:
                dg = f.create_group(type)
                for idx in range(num_pairs):
                    data_item = np.asarray(info[type][idx])
                    dg.create_dataset(
                        str(idx), data_item.shape, data_item.dtype, data=data_item
                    )
            for type in ["img_path1", "img_path2"]:
                dg = f.create_group(type)
                for idx in range(num_pairs):
                    dg.create_dataset(
                        str(idx),
                        [1],
                        h5py.string_dtype(encoding="ascii"),
                        data=info[type][idx].encode("ascii"),
                    )

        with open(
            os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt"), "w"
        ) as f:
            f.write(str(info["pair_num"]))

    def format_seq(self, index):
        seq = self.seq_list[index]

        # skip sequences whose info has already been dumped
        if os.path.exists(
            os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt")
        ):
            return

        seq_dir = os.path.join(self.config["rawdata_dir"], "data", seq)
        basename_list = np.loadtxt(os.path.join(seq_dir, "basenames.txt"), dtype=str)
        # common_track.txt rows: [image_index1, image_index2, overlap_score, ...]
        common_track = np.loadtxt(
            os.path.join(seq_dir, "geolabel", "common_track.txt"), dtype=float
        )
        pair_list = common_track[:, :2].astype(int)
        overlap_score = common_track[:, 2]
        geom_dict = self.load_geom(seq)

        angle_list = []
        # filter pairs by relative rotation angle and overlap score
        for cur_pair in pair_list:
            pair_index1, pair_index2 = cur_pair[0], cur_pair[1]
            geo1, geo2 = geom_dict[pair_index1], geom_dict[pair_index2]
            # relative rotation; for a unit quaternion q the angle is 2*acos(q_w)
            dR = np.dot(geo2["R"], geo1["R"].T)
            q = transformations.quaternion_from_matrix(dR)
            # clamp to guard acos against numerical noise outside [-1, 1]
            angle_list.append(math.acos(np.clip(q[0], -1.0, 1.0)) * 2 * 180 / math.pi)
        angle_list = np.asarray(angle_list)
        mask_survive = np.logical_and(
            np.logical_and(
                angle_list > self.config["angle_th"][0],
                angle_list < self.config["angle_th"][1],
            ),
            np.logical_and(
                overlap_score > self.config["overlap_th"][0],
                overlap_score < self.config["overlap_th"][1],
            ),
        )
        pair_list = pair_list[mask_survive]
        if len(pair_list) < 100:
            print("warning: only", len(pair_list), "pairs survive filtering in", seq)
        # sample pairs
        shuffled_pair_list = np.random.permutation(pair_list)
        sample_target = min(self.config["pairs_per_seq"], len(shuffled_pair_list))
        sample_number = 0

        info = {
            "dR": [],
            "dt": [],
            "K1": [],
            "K2": [],
            "img_path1": [],
            "img_path2": [],
            "fea_path1": [],
            "fea_path2": [],
            "size1": [],
            "size2": [],
            "corr": [],
            "incorr1": [],
            "incorr2": [],
            "pair_num": [],
        }
        for cur_pair in shuffled_pair_list:
            pair_index1, pair_index2 = cur_pair[0], cur_pair[1]
            geo1, geo2 = geom_dict[pair_index1], geom_dict[pair_index2]
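            # Relative pose: x2 = dR @ x1 + dt under the x_cam = R @ x_world + T
            # convention, so dR = R2 @ R1^T and dt = t2 - dR @ t1.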
            dR = np.dot(geo2["R"], geo1["R"].T)
            t1, t2 = geo1["T"].reshape([3, 1]), geo2["T"].reshape([3, 1])
            dt = t2 - np.dot(dR, t1)
            K1, K2 = geo1["K"], geo2["K"]
            size1, size2 = geo1["size"], geo2["size"]

            basename1, basename2 = (
                basename_list[pair_index1],
                basename_list[pair_index2],
            )
            img_path1 = os.path.join(seq, "undist_images", basename1 + ".jpg")
            img_path2 = os.path.join(seq, "undist_images", basename2 + ".jpg")
            fea_suffix = "_{}_{}.hdf5".format(
                self.config["extractor"]["name"],
                self.config["extractor"]["num_kpt"],
            )
            fea_path1 = os.path.join(seq, basename1 + ".jpg" + fea_suffix)
            fea_path2 = os.path.join(seq, basename2 + ".jpg" + fea_suffix)

            with h5py.File(
                os.path.join(self.config["feature_dump_dir"], fea_path1), "r"
            ) as fea1, h5py.File(
                os.path.join(self.config["feature_dump_dir"], fea_path2), "r"
            ) as fea2:
                desc1, desc2 = fea1["descriptors"][()], fea2["descriptors"][()]
                kpt1, kpt2 = fea1["keypoints"][()], fea2["keypoints"][()]
                depth_dir = os.path.join(
                    self.config["rawdata_dir"], "data", seq, "depths"
                )
                depth1 = self.load_depth(os.path.join(depth_dir, basename1 + ".pfm"))
                depth2 = self.load_depth(os.path.join(depth_dir, basename2 + ".pfm"))
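                # make_corr (see utils.data_utils) derives ground-truth labels from
                # depth and relative pose: index pairs of true correspondences plus
                # indices of keypoints in each image with no valid counterpart.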
                corr_index, incorr_index1, incorr_index2 = data_utils.make_corr(
                    kpt1[:, :2],
                    kpt2[:, :2],
                    desc1,
                    desc2,
                    depth1,
                    depth2,
                    K1,
                    K2,
                    dR,
                    dt,
                    size1,
                    size2,
                    self.config["corr_th"],
                    self.config["incorr_th"],
                    self.config["check_desc"],
                )

            if (
                len(corr_index) > self.config["min_corr"]
                and len(incorr_index1) > self.config["min_incorr"]
                and len(incorr_index2) > self.config["min_incorr"]
            ):
                info["corr"].append(corr_index)
                info["incorr1"].append(incorr_index1)
                info["incorr2"].append(incorr_index2)
                info["dR"].append(dR)
                info["dt"].append(dt)
                info["K1"].append(K1)
                info["K2"].append(K2)
                info["img_path1"].append(img_path1)
                info["img_path2"].append(img_path2)
                info["fea_path1"].append(fea_path1)
                info["fea_path2"].append(fea_path2)
                info["size1"].append(size1)
                info["size2"].append(size2)
                sample_number += 1
            if sample_number == sample_target:
                break
        info["pair_num"] = sample_number
        # dump info
        self.dump_info(seq, info)

    def collect_meta(self):
        print("collecting meta info...")
        dump_path, seq_list = [], []
        if self.config["dump_train"]:
            dump_path.append(os.path.join(self.config["dataset_dump_dir"], "train"))
            seq_list.append(self.train_list)
        if self.config["dump_valid"]:
            dump_path.append(os.path.join(self.config["dataset_dump_dir"], "valid"))
            seq_list.append(self.valid_list)
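        # For each split, write pair_num.txt: a leading "total <N>" row followed
        # by one "<seq> <count>" row per sequence.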
        for pth, seqs in zip(dump_path, seq_list):
            os.makedirs(pth, exist_ok=True)
            pair_num_list, total_pair = [], 0
            for seq in seqs:
                pair_num = np.loadtxt(
                    os.path.join(self.config["dataset_dump_dir"], seq, "pair_num.txt"),
                    dtype=int,
                )
                pair_num_list.append(str(pair_num))
                total_pair += pair_num
            pair_num_list = np.stack(
                [np.asarray(seqs, dtype=str), np.asarray(pair_num_list, dtype=str)],
                axis=1,
            )
            pair_num_list = np.concatenate(
                [np.asarray([["total", str(total_pair)]]), pair_num_list], axis=0
            )
            np.savetxt(os.path.join(pth, "pair_num.txt"), pair_num_list, fmt="%s")

    def format_dump_data(self):
        print("Formatting data...")
        iteration_num = math.ceil(len(self.seq_list) / self.config["num_process"])
        pool = Pool(self.config["num_process"])
        for index in trange(iteration_num):
            indices = range(
                index * self.config["num_process"],
                min((index + 1) * self.config["num_process"], len(self.seq_list)),
            )
            pool.map(self.format_seq, indices)
        pool.close()
        pool.join()

        self.collect_meta()
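
# Minimal usage sketch (hypothetical driver; the real entry point and config
# schema live elsewhere in the repo):
#
#   config = ...  # dict with rawdata_dir, feature_dump_dir, dataset_dump_dir,
#                 # extractor, angle_th, overlap_th, pairs_per_seq, num_process, ...
#   dumper = gl3d_train(config)
#   dumper.get_seqs()
#   dumper.format_dump_folder()
#   dumper.format_dump_data()  # writes per-sequence info.h5py, then collects meta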