File size: 2,040 Bytes
404d2af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b973ee
 
 
404d2af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b973ee
 
 
 
404d2af
 
8b973ee
 
 
404d2af
 
 
 
 
 
 
 
 
8b973ee
404d2af
 
 
8b973ee
404d2af
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import pprint
from abc import ABCMeta, abstractmethod
import torch
from itertools import chain

from src.utils.plotting import make_matching_figure, error_colormap
from src.utils.metrics import aggregate_metrics


def flatten_list(x):
    """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``.

    Args:
        x: An iterable of iterables.

    Returns:
        A new list containing the items of every inner iterable, in order.
    """
    # chain.from_iterable avoids unpacking `x` as *args (the chain(*x) form
    # forces `x` to be fully materialized up front) and accepts any iterable.
    return list(chain.from_iterable(x))


class Viz(metaclass=ABCMeta):
    """Abstract base class for matcher visualization / evaluation drivers.

    Subclasses implement :meth:`match_and_draw` and are expected to set
    ``self.name`` (shown in the figure caption) and to append per-sample
    results to ``self.eval_stats`` / ``self.time_stats`` -- presumably done
    by concrete subclasses; confirm against their implementations.
    """

    def __init__(self):
        super().__init__()
        # Prefer the first CUDA device when one is available.
        self.device = torch.device(
            "cuda:{}".format(0) if torch.cuda.is_available() else "cpu"
        )
        # NOTE(review): this disables autograd process-wide, not just for
        # this object -- intentional for inference-only use.
        torch.set_grad_enabled(False)

        # Per-sample evaluation metrics (MegaDepth / ScanNet) and timings.
        self.eval_stats = []
        self.time_stats = []

    def draw_matches(self, mkpts0, mkpts1, img0, img1, conf, path=None, **kwargs):
        """Render a match visualization figure.

        Args:
            mkpts0, mkpts1: Matched keypoint coordinates in each image.
            img0, img1: The two images being matched.
            conf: Per-match confidence used to color the matches.
            path: If given, the figure is saved there instead of returned.
            **kwargs: Optional ``conf_thr`` (colormap threshold, default
                5e-4) and ``R_errs`` / ``t_errs`` (scalar pose errors in
                degrees, added to the caption when ``R_errs`` is present).

        Returns:
            The matplotlib figure when ``path`` is None; otherwise None
            (the figure is written to ``path``).
        """
        # Colormap confidence threshold; callers may override via kwargs.
        thr = kwargs.get("conf_thr", 5e-4)
        color = error_colormap(conf, thr, alpha=0.1)

        text = [
            f"{self.name}",
            f"#Matches: {len(mkpts0)}",
        ]
        if "R_errs" in kwargs:
            # Rotation and translation errors are both reported in degrees.
            text.append(
                f"$\\Delta$R:{kwargs['R_errs']:.2f}°,  $\\Delta$t:{kwargs['t_errs']:.2f}°",
            )

        if path:
            make_matching_figure(
                img0, img1, mkpts0, mkpts1, color, text=text, path=path, dpi=150
            )
        else:
            return make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text)

    @abstractmethod
    def match_and_draw(self, data_dict, **kwargs):
        """Run the matcher on ``data_dict`` and draw/record the results."""

    def compute_eval_metrics(self, epi_err_thr=5e-4):
        """Aggregate and pretty-print the metrics gathered in ``eval_stats``.

        Args:
            epi_err_thr: Epipolar-error threshold forwarded to
                ``aggregate_metrics``.

        Returns:
            The aggregated metrics dict, or None when no stats were
            collected. (Previously always returned None; callers ignoring
            the return value are unaffected.)
        """
        # Guard: the original indexed _metrics[0] and raised IndexError
        # when evaluation had produced no stats yet.
        if not self.eval_stats:
            print("\nNo evaluation stats collected.")
            return None

        # Each entry is {"metrics": {name: [values...]}}; merge the
        # per-sample lists into one flat list per metric name.
        _metrics = [o["metrics"] for o in self.eval_stats]
        metrics = {k: flatten_list([_me[k] for _me in _metrics]) for k in _metrics[0]}

        val_metrics_4tb = aggregate_metrics(metrics, epi_err_thr)
        print("\n" + pprint.pformat(val_metrics_4tb))
        return val_metrics_4tb

    def measure_time(self):
        """Return the mean recorded per-sample time, or 0 when none exist."""
        if not self.time_stats:
            return 0
        return sum(self.time_stats) / len(self.time_stats)