Spaces:
Running
Running
File size: 2,657 Bytes
4dfb78b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import torch
from ...geometry.gt_generation import (
gt_line_matches_from_pose_depth,
gt_matches_from_pose_depth,
)
from ..base_model import BaseModel
class DepthMatcher(BaseModel):
    """Build ground-truth (GT) point and/or line matches for a view pair.

    The matching itself is delegated to ``gt_matches_from_pose_depth``
    (points) and ``gt_line_matches_from_pose_depth`` (lines). This class only
    picks the relevant entries out of ``data`` and forwards the configured
    thresholds to those helpers.
    """

    default_conf = {
        # GT parameters for points
        "use_points": True,
        "th_positive": 3.0,
        "th_negative": 5.0,
        "th_epi": None,  # add some more epi outliers
        "th_consistency": None,  # check for projection consistency in px
        # GT parameters for lines
        "use_lines": False,
        "n_line_sampled_pts": 50,
        "line_perp_dist_th": 5,
        "overlap_th": 0.2,
        "min_visibility_th": 0.5,
    }

    # Base set of input keys; `_init` adds the point/line keys per instance.
    # NOTE(review): presumably checked by BaseModel before `_forward` runs --
    # the base class is not visible from this file, confirm there.
    required_data_keys = ["view0", "view1", "T_0to1", "T_1to0"]

    def _init(self, conf):
        """Extend the required input keys according to the configuration.

        Args:
            conf: Configuration handed over by the base class. It is not read
                directly; the method reads ``self.conf`` instead.
        """
        # TODO (iago): Is this just boilerplate code?
        extra_keys = []
        if self.conf.use_points:
            extra_keys += ["keypoints0", "keypoints1"]
        if self.conf.use_lines:
            extra_keys += [
                "lines0",
                "lines1",
                "valid_lines0",
                "valid_lines1",
            ]
        # Rebind to a NEW list instead of using `+=` on the attribute: `+=`
        # extends the existing list object in place, and that object may be
        # the class-level list shared by every DepthMatcher instance.
        self.required_data_keys = [*self.required_data_keys, *extra_keys]

    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def _forward(self, data):
        """Compute the GT matches enabled in the configuration.

        Args:
            data: Mapping holding the entries listed in
                ``required_data_keys``. When it contains
                ``"depth_keypoints0"``, the four ``depth_keypoints*`` /
                ``valid_depth_keypoints*`` entries are forwarded to the point
                matcher as keyword arguments.

        Returns:
            With ``use_points`` enabled, the object returned by
            ``gt_matches_from_pose_depth``; otherwise a fresh ``dict``. With
            ``use_lines`` enabled, the keys ``"line_matches0"``,
            ``"line_matches1"`` and ``"line_assignment"`` are added to it.

        Raises:
            KeyError: If ``"depth_keypoints0"`` is present in ``data`` but one
                of the three companion depth entries is missing.
        """
        result = {}
        if self.conf.use_points:
            if "depth_keypoints0" in data:
                keys = [
                    "depth_keypoints0",
                    "valid_depth_keypoints0",
                    "depth_keypoints1",
                    "valid_depth_keypoints1",
                ]
                kw = {k: data[k] for k in keys}
            else:
                kw = {}
            result = gt_matches_from_pose_depth(
                data["keypoints0"],
                data["keypoints1"],
                data,
                pos_th=self.conf.th_positive,
                neg_th=self.conf.th_negative,
                epi_th=self.conf.th_epi,
                cc_th=self.conf.th_consistency,
                **kw,
            )
        if self.conf.use_lines:
            line_assignment, line_m0, line_m1 = gt_line_matches_from_pose_depth(
                data["lines0"],
                data["lines1"],
                data["valid_lines0"],
                data["valid_lines1"],
                data,
                self.conf.n_line_sampled_pts,
                self.conf.line_perp_dist_th,
                self.conf.overlap_th,
                self.conf.min_visibility_th,
            )
            # Line outputs are stored next to the point outputs (if any).
            result["line_matches0"] = line_m0
            result["line_matches1"] = line_m1
            result["line_assignment"] = line_assignment
        return result

    def loss(self, pred, data):
        """Not implemented for this model; always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|