"""Evaluation pipeline for relative pose estimation on the MegaDepth-1500 image pairs."""

import logging
import zipfile
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

from ..datasets import get_dataset
from ..models.cache_loader import CacheLoader
from ..settings import DATA_PATH, EVAL_PATH
from ..utils.export_predictions import export_predictions
from ..visualization.viz2d import plot_cumulative
from .eval_pipeline import EvalPipeline
from .io import get_eval_parser, load_model, parse_eval_args
from .utils import eval_matches_epipolar, eval_poses, eval_relative_pose_robust

logger = logging.getLogger(__name__)


class MegaDepth1500Pipeline(EvalPipeline):
    default_conf = {
        "data": {
            "name": "image_pairs",
            "pairs": "megadepth1500/pairs_calibrated.txt",
            "root": "megadepth1500/images/",
            "extra_data": "relative_pose",
            "preprocessing": {
                "side": "long",
            },
        },
        "model": {
            "ground_truth": {
                "name": None,  # remove gt matches
            }
        },
        "eval": {
            "estimator": "poselib",
            "ransac_th": 1.0,  # -1 runs a set of thresholds and selects the best
        },
    }

    export_keys = [
        "keypoints0",
        "keypoints1",
        "keypoint_scores0",
        "keypoint_scores1",
        "matches0",
        "matches1",
        "matching_scores0",
        "matching_scores1",
    ]
    optional_export_keys = []

    def _init(self, conf):
        # download and extract the dataset once; it is then cached under DATA_PATH
        if not (DATA_PATH / "megadepth1500").exists():
            logger.info("Downloading the MegaDepth-1500 dataset.")
            url = "https://cvg-data.inf.ethz.ch/megadepth/megadepth1500.zip"
            zip_path = DATA_PATH / url.rsplit("/", 1)[-1]
            zip_path.parent.mkdir(exist_ok=True, parents=True)
            torch.hub.download_url_to_file(url, zip_path)
            with zipfile.ZipFile(zip_path) as fid:
                fid.extractall(DATA_PATH)
            zip_path.unlink()

    @classmethod
    def get_dataloader(cls, data_conf=None):
        """Return a data loader with samples for each eval datapoint."""
        data_conf = data_conf if data_conf else cls.default_conf["data"]
        dataset = get_dataset(data_conf["name"])(data_conf)
        return dataset.get_data_loader("test")

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint."""
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
            )
        return pred_file

    def run_eval(self, loader, pred_file):
        """Run the eval on cached predictions."""
        conf = self.conf.eval
        results = defaultdict(list)
        # a fixed positive threshold is evaluated directly; a non-positive value
        # sweeps a default set of thresholds; an iterable is used as-is
        test_thresholds = (
            ([conf.ransac_th] if conf.ransac_th > 0 else [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
            if not isinstance(conf.ransac_th, Iterable)
            else conf.ransac_th
        )
        pose_results = defaultdict(lambda: defaultdict(list))
        # serve cached predictions from the export file instead of re-running the model
        cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
        for data in tqdm(loader):
            pred = cache_loader(data)
            # add custom evaluations here
            results_i = eval_matches_epipolar(data, pred)
            for th in test_thresholds:
                pose_results_i = eval_relative_pose_robust(
                    data,
                    pred,
                    {"estimator": conf.estimator, "ransac_th": th},
                )
                for k, v in pose_results_i.items():
                    pose_results[th][k].append(v)

            # we also store the names for later reference
            results_i["names"] = data["name"][0]
            if "scene" in data:
                results_i["scenes"] = data["scene"][0]

            for k, v in results_i.items():
                results[k].append(v)

        # summarize results as a dict[str, float]
        # you can also add your custom evaluations here
        summaries = {}
        for k, v in results.items():
            arr = np.array(v)
            if not np.issubdtype(arr.dtype, np.number):
                continue
            summaries[f"m{k}"] = round(np.mean(arr), 3)

        best_pose_results, best_th = eval_poses(
            pose_results, auc_ths=[5, 10, 20], key="rel_pose_error"
        )
        results = {**results, **pose_results[best_th]}

        summaries = {
            **summaries,
            **best_pose_results,
        }

        figures = {
            "pose_recall": plot_cumulative(
                {self.conf.eval.estimator: results["rel_pose_error"]},
                [0, 30],
                unit="°",
                title="Pose ",
            )
        }

        return summaries, figures, results


if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    dataset_name = Path(__file__).stem
    parser = get_eval_parser()
    args = parser.parse_intermixed_args()

    default_conf = OmegaConf.create(MegaDepth1500Pipeline.default_conf)

    # assemble output paths
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(
        dataset_name,
        args,
        "configs/",
        default_conf,
    )

    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = MegaDepth1500Pipeline(conf)
    s, f, r = pipeline.run(
        experiment_dir,
        overwrite=args.overwrite,
        overwrite_eval=args.overwrite_eval,
    )
    pprint(s)

    if args.plot:
        for name, fig in f.items():
            fig.canvas.manager.set_window_title(name)
        plt.show()
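
# Example invocation (illustrative): assuming this script lives at
# <package>.eval.megadepth1500 inside an importable package, the evaluation can
# be launched as a module. The `--overwrite`, `--overwrite_eval`, and `--plot`
# flags mirror the `args` attributes read above; the full set of options is
# defined by `get_eval_parser` and may differ:
#
#   python -m <package>.eval.megadepth1500 --overwrite --plot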