from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm

from ..datasets import get_dataset
from ..models.cache_loader import CacheLoader
from ..settings import EVAL_PATH
from ..utils.export_predictions import export_predictions
from .eval_pipeline import EvalPipeline, load_eval
from .io import get_eval_parser, load_model, parse_eval_args
from .utils import aggregate_pr_results, get_tp_fp_pts


def eval_dataset(loader, pred_file, suffix=""):
    """Collect per-item match statistics and aggregate them into PR results.

    Cached predictions are read from ``pred_file`` through a ``CacheLoader``.
    For every item yielded by ``loader`` the predicted matches, ground-truth
    matches and matching scores are reordered by decreasing score and handed
    to ``get_tp_fp_pts``; its outputs are accumulated and finally reduced by
    ``aggregate_pr_results``.

    Args:
        loader: Iterable of data items; each one is passed to the cache loader.
        pred_file: Path of the cached predictions file.
        suffix: String appended to every result key. With an empty suffix the
            ``matches0`` / ``gt_matches0`` / ``matching_scores0`` entries are
            read; with any other value the ``line_*`` entries are read.

    Returns:
        The value returned by ``aggregate_pr_results`` for the accumulated
        statistics.
    """
    # The prediction keys depend only on the suffix, so pick them once.
    if suffix == "":
        score_key = "matching_scores0"
        gt_key = "gt_matches0"
        match_key = "matches0"
    else:
        score_key = "line_matching_scores0"
        gt_key = "gt_line_matches0"
        match_key = "line_matches0"

    stats = defaultdict(list)
    stats["num_pos" + suffix] = 0
    cached_preds = CacheLoader({"path": str(pred_file), "collate": None}).eval()

    for sample in tqdm(loader):
        pred = cached_preds(sample)

        # Order everything by decreasing matching score.
        raw_scores = pred[score_key].numpy()
        order = np.argsort(raw_scores)[::-1]
        gt_sorted = pred[gt_key].numpy()[order]
        matches_sorted = pred[match_key].numpy()[order]
        scores_sorted = raw_scores[order]

        tp, fp, kept_scores, n_pos = get_tp_fp_pts(
            matches_sorted, gt_sorted, scores_sorted
        )
        stats["tp" + suffix].append(tp)
        stats["fp" + suffix].append(fp)
        stats["scores" + suffix].append(kept_scores)
        stats["num_pos" + suffix] += n_pos

    # Aggregate the results
    return aggregate_pr_results(stats, suffix=suffix)


class ETH3DPipeline(EvalPipeline):
    """Evaluation pipeline for the ETH3D dataset.

    Predictions are exported to ``predictions.h5`` inside the experiment
    directory and then scored with ``eval_dataset`` for point matches and,
    when ``eval.eval_lines`` is set, for line matches as well.
    """

    default_conf = {
        "data": {
            "name": "eth3d",
            "batch_size": 1,
            "train_batch_size": 1,
            "val_batch_size": 1,
            "test_batch_size": 1,
            "num_workers": 16,
        },
        "model": {
            "name": "gluefactory.models.two_view_pipeline",
            "ground_truth": {
                "name": "gluefactory.models.matchers.depth_matcher",
                "use_lines": False,
            },
            "run_gt_in_forward": True,
        },
        "eval": {"plot_methods": [], "plot_line_methods": [], "eval_lines": False},
    }

    # Prediction entries passed to export_predictions as `keys`.
    export_keys = [
        "gt_matches0",
        "matches0",
        "matching_scores0",
    ]

    # Prediction entries passed to export_predictions as `optional_keys`.
    optional_export_keys = [
        "gt_line_matches0",
        "line_matches0",
        "line_matching_scores0",
    ]

    def get_dataloader(self, data_conf=None):
        """Return the ``"test"`` data loader of the ETH3D dataset.

        Args:
            data_conf: Dataset configuration; when ``None`` the ``"data"``
                section of ``default_conf`` is used.
        """
        data_conf = data_conf if data_conf is not None else self.default_conf["data"]
        dataset = get_dataset("eth3d")(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export the predictions unless they are already cached.

        Args:
            experiment_dir: Directory (path-like supporting ``/``) in which
                ``predictions.h5`` is stored.
            model: Model to run; when ``None`` it is loaded from
                ``self.conf.model`` and ``self.conf.checkpoint``.
            overwrite: Re-export even if the prediction file already exists.

        Returns:
            Path of the prediction file.
        """
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Evaluate the cached predictions.

        Args:
            loader: Data loader whose items are matched to cached predictions.
            pred_file: Path of the cached predictions file.

        Returns:
            A 3-tuple ``(s, {}, r)`` where the first two entries are empty
            dicts and ``r`` holds the results of ``eval_dataset`` (point
            results, updated with the ``"_lines"`` results when
            ``eval.eval_lines`` is enabled).
        """
        eval_conf = self.conf.eval
        r = eval_dataset(loader, pred_file)
        if eval_conf.eval_lines:
            # eval_dataset has no `conf` parameter; passing one raised a
            # TypeError as soon as line evaluation was enabled.
            r.update(eval_dataset(loader, pred_file, suffix="_lines"))
        s = {}

        return s, {}, r


def plot_pr_curve(
    models_name, results, dst_file="eth3d_pr_curve.pdf", title=None, suffix=""
):
    """Plot precision-recall curves of several models, save and show the figure.

    Args:
        models_name: Sequence of model names; each must be a key of ``results``.
        results: Mapping from model name to a mapping that contains the keys
            ``"AP" + suffix``, ``"curve_recall" + suffix`` and
            ``"curve_precision" + suffix``.
        dst_file: Path the figure is saved to.
        title: Optional figure title.
        suffix: Suffix appended to the result keys (e.g. ``"_lines"``).
    """
    plt.figure()

    # Background guide curves of constant F-score: for recall x, the
    # precision giving F-score f is f * x / (2 * x - f).
    f_scores = np.linspace(0.2, 0.9, num=8)
    for f_score in f_scores:
        x = np.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        # Only the non-negative part of the curve is drawn.
        plt.plot(x[y >= 0], y[y >= 0], color=[0, 0.5, 0], alpha=0.3)
        plt.annotate(
            "f={0:0.1}".format(f_score),
            # Label placed just above the curve, using the sample at index 45
            # of the 50 default linspace samples.
            xy=(0.9, y[45] + 0.02),
            alpha=0.4,
            fontsize=14,
        )

    plt.rcParams.update({"font.size": 12})
    # plt.rc('legend', fontsize=10)
    plt.grid(True)
    plt.axis([0.0, 1.0, 0.0, 1.0])
    plt.xticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.xlabel("Recall", fontsize=18)
    plt.ylabel("Precision", fontsize=18)
    plt.yticks(np.arange(0, 1.05, step=0.1), fontsize=16)
    plt.ylim([0.3, 1.0])
    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]
    for idx, m in enumerate(models_name):
        # Wrap around the colour cycle so that models beyond its length are
        # still plotted instead of being silently dropped by zip().
        c = colors[idx % len(colors)]
        sAP_string = f'{m}: {results[m]["AP" + suffix]:.1f}'
        plt.plot(
            results[m]["curve_recall" + suffix],
            results[m]["curve_precision" + suffix],
            label=sAP_string,
            color=c,
        )

    plt.legend(fontsize=16, loc="lower right")
    if title:
        plt.title(title)

    plt.tight_layout(pad=0.5)
    print(f"Saving plot to: {dst_file}")
    plt.savefig(dst_file)
    plt.show()


if __name__ == "__main__":
    # The dataset name is this file's name without its extension.
    dataset_name = Path(__file__).stem
    args = get_eval_parser().parse_intermixed_args()

    default_conf = OmegaConf.create(ETH3DPipeline.default_conf)

    # Root output directory for this dataset's experiments.
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)

    # Directory of the current experiment.
    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = ETH3DPipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir,
        overwrite=args.overwrite,
        overwrite_eval=args.overwrite_eval,
    )

    # Print only the result entries whose key starts with "AP".
    for metric, value in r.items():
        if not metric.startswith("AP"):
            continue
        print(f"{metric}: {value:.2f}")

    if args.plot:
        # Load the stored evaluation of every method to compare against.
        point_methods = conf.eval.plot_methods
        results = {m: load_eval(output_dir / m)[1] for m in point_methods}
        plot_pr_curve(point_methods, results, dst_file="eth3d_pr_curve.pdf")

        if conf.eval.eval_lines:
            line_methods = conf.eval.plot_line_methods
            results.update({m: load_eval(output_dir / m)[1] for m in line_methods})
            plot_pr_curve(
                line_methods,
                results,
                dst_file="eth3d_pr_curve_lines.pdf",
                suffix="_lines",
            )