File size: 2,657 Bytes
c0283b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch

from ...geometry.gt_generation import (
    gt_line_matches_from_pose_depth,
    gt_matches_from_pose_depth,
)
from ..base_model import BaseModel


class DepthMatcher(BaseModel):
    """Ground-truth matcher that labels keypoint and/or line correspondences.

    Ground truth comes from ``gt_matches_from_pose_depth`` (points) and
    ``gt_line_matches_from_pose_depth`` (lines). Both receive the whole
    ``data`` dict, which contains the two views and the relative poses
    ``T_0to1`` / ``T_1to0``. The ``use_points`` and ``use_lines`` config
    flags select which feature types are labelled. This model defines no
    training loss.
    """

    default_conf = {
        # GT parameters for points
        "use_points": True,
        "th_positive": 3.0,  # forwarded as `pos_th`
        "th_negative": 5.0,  # forwarded as `neg_th`
        "th_epi": None,  # add some more epi outliers
        "th_consistency": None,  # check for projection consistency in px
        # GT parameters for lines
        "use_lines": False,
        "n_line_sampled_pts": 50,
        "line_perp_dist_th": 5,
        "overlap_th": 0.2,
        "min_visibility_th": 0.5,
    }

    required_data_keys = ["view0", "view1", "T_0to1", "T_1to0"]

    def _init(self, conf):
        """Extend the required input keys according to the enabled features.

        A new list is bound to the instance instead of extending in place.
        ``+=`` on the inherited list would mutate the class-level
        ``required_data_keys`` shared by all instances. Keys needed by one
        configuration (e.g. ``use_lines=True``) would then leak into every
        later instance and accumulate duplicates.
        """
        # TODO (iago): Is this just boilerplate code?
        extra_keys = []
        if self.conf.use_points:
            extra_keys += ["keypoints0", "keypoints1"]
        if self.conf.use_lines:
            extra_keys += [
                "lines0",
                "lines1",
                "valid_lines0",
                "valid_lines1",
            ]
        # Rebind rather than mutate, so the class attribute stays pristine.
        self.required_data_keys = [*self.required_data_keys, *extra_keys]

    # Under autocast, custom_fwd casts incoming floating-point CUDA tensors to
    # float32 and runs the forward with autocast disabled.
    # NOTE(review): `torch.cuda.amp.custom_fwd` is deprecated in recent PyTorch
    # releases in favour of
    # `torch.amp.custom_fwd(device_type="cuda", cast_inputs=...)`. Migrate once
    # the project's minimum supported torch version provides it.
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def _forward(self, data):
        """Compute ground-truth point and/or line matches for a view pair.

        Args:
            data: batch dict holding at least ``required_data_keys``. If it
                also contains ``depth_keypoints0``, then
                ``valid_depth_keypoints0``, ``depth_keypoints1`` and
                ``valid_depth_keypoints1`` must be present as well. All four
                are forwarded to ``gt_matches_from_pose_depth``; presumably
                they are precomputed keypoint depths (verify with the caller).

        Returns:
            dict: what ``gt_matches_from_pose_depth`` returns when
            ``use_points`` is enabled, otherwise an empty dict.

            When ``use_lines`` is enabled it is extended with
            ``line_matches0``, ``line_matches1`` and ``line_assignment``.
        """
        result = {}
        if self.conf.use_points:
            if "depth_keypoints0" in data:
                # Reuse depths supplied with the batch instead of letting the
                # helper derive them itself.
                keys = [
                    "depth_keypoints0",
                    "valid_depth_keypoints0",
                    "depth_keypoints1",
                    "valid_depth_keypoints1",
                ]
                kw = {k: data[k] for k in keys}
            else:
                kw = {}
            result = gt_matches_from_pose_depth(
                data["keypoints0"],
                data["keypoints1"],
                data,
                pos_th=self.conf.th_positive,
                neg_th=self.conf.th_negative,
                epi_th=self.conf.th_epi,
                cc_th=self.conf.th_consistency,
                **kw,
            )
        if self.conf.use_lines:
            line_assignment, line_m0, line_m1 = gt_line_matches_from_pose_depth(
                data["lines0"],
                data["lines1"],
                data["valid_lines0"],
                data["valid_lines1"],
                data,
                self.conf.n_line_sampled_pts,
                self.conf.line_perp_dist_th,
                self.conf.overlap_th,
                self.conf.min_visibility_th,
            )
            result["line_matches0"] = line_m0
            result["line_matches1"] = line_m1
            result["line_assignment"] = line_assignment
        return result

    def loss(self, pred, data):
        """Not supported: this model only produces ground-truth labels."""
        raise NotImplementedError