import numpy as np
import torch
from scipy.optimize import linear_sum_assignment

from .depth import project, sample_depth
from .epipolar import T_to_E, sym_epipolar_distance_all
from .homography import warp_points_torch

IGNORE_FEATURE = -2
UNMATCHED_FEATURE = -1


@torch.no_grad()
def gt_matches_from_pose_depth(
    kp0, kp1, data, pos_th=3, neg_th=5, epi_th=None, cc_th=None, **kw
):
    if kp0.shape[1] == 0 or kp1.shape[1] == 0:
        b_size, n_kp0 = kp0.shape[:2]
        n_kp1 = kp1.shape[1]
        assignment = torch.zeros(
            b_size, n_kp0, n_kp1, dtype=torch.bool, device=kp0.device
        )
        m0 = -torch.ones_like(kp0[:, :, 0]).long()
        m1 = -torch.ones_like(kp1[:, :, 0]).long()
        return assignment, m0, m1
    camera0, camera1 = data["view0"]["camera"], data["view1"]["camera"]
    T_0to1, T_1to0 = data["T_0to1"], data["T_1to0"]

    depth0 = data["view0"].get("depth")
    depth1 = data["view1"].get("depth")
    if "depth_keypoints0" in kw and "depth_keypoints1" in kw:
        d0, valid0 = kw["depth_keypoints0"], kw["valid_depth_keypoints0"]
        d1, valid1 = kw["depth_keypoints1"], kw["valid_depth_keypoints1"]
    else:
        assert depth0 is not None
        assert depth1 is not None
        d0, valid0 = sample_depth(kp0, depth0)
        d1, valid1 = sample_depth(kp1, depth1)

    kp0_1, visible0 = project(
        kp0, d0, depth1, camera0, camera1, T_0to1, valid0, ccth=cc_th
    )
    kp1_0, visible1 = project(
        kp1, d1, depth0, camera1, camera0, T_1to0, valid1, ccth=cc_th
    )
    mask_visible = visible0.unsqueeze(-1) & visible1.unsqueeze(-2)

    # Build a distance matrix of size [... x M x N]
    dist0 = torch.sum((kp0_1.unsqueeze(-2) - kp1.unsqueeze(-3)) ** 2, -1)
    dist1 = torch.sum((kp0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1)
    dist = torch.max(dist0, dist1)
    inf = dist.new_tensor(float("inf"))
    dist = torch.where(mask_visible, dist, inf)

    min0 = dist.min(-1).indices
    min1 = dist.min(-2).indices

    ismin0 = torch.zeros(dist.shape, dtype=torch.bool, device=dist.device)
    ismin1 = ismin0.clone()
    ismin0.scatter_(-1, min0.unsqueeze(-1), value=1)
    ismin1.scatter_(-2, min1.unsqueeze(-2), value=1)
    positive = ismin0 & ismin1 & (dist < pos_th**2)

    negative0 = (dist0.min(-1).values > neg_th**2) & valid0
    negative1 = (dist1.min(-2).values > neg_th**2) & valid1

    # Pack the indices of positive matches
    # if -1: unmatched point
    # if -2: ignore point
    unmatched = min0.new_tensor(UNMATCHED_FEATURE)
    ignore = min0.new_tensor(IGNORE_FEATURE)
    m0 = torch.where(positive.any(-1), min0, ignore)
    m1 = torch.where(positive.any(-2), min1, ignore)
    m0 = torch.where(negative0, unmatched, m0)
    m1 = torch.where(negative1, unmatched, m1)

    F = (
        camera1.calibration_matrix().inverse().transpose(-1, -2)
        @ T_to_E(T_0to1)
        @ camera0.calibration_matrix().inverse()
    )
    epi_dist = sym_epipolar_distance_all(kp0, kp1, F)

    # Add some more unmatched points using epipolar geometry
    if epi_th is not None:
        mask_ignore = (m0.unsqueeze(-1) == ignore) & (m1.unsqueeze(-2) == ignore)
        epi_dist = torch.where(mask_ignore, epi_dist, inf)
        exclude0 = epi_dist.min(-1).values > neg_th
        exclude1 = epi_dist.min(-2).values > neg_th
        m0 = torch.where((~valid0) & exclude0, ignore.new_tensor(-1), m0)
        m1 = torch.where((~valid1) & exclude1, ignore.new_tensor(-1), m1)

    return {
        "assignment": positive,
        "reward": (dist < pos_th**2).float() - (epi_dist > neg_th).float(),
        "matches0": m0,
        "matches1": m1,
        "matching_scores0": (m0 > -1).float(),
        "matching_scores1": (m1 > -1).float(),
        "depth_keypoints0": d0,
        "depth_keypoints1": d1,
        "proj_0to1": kp0_1,
        "proj_1to0": kp1_0,
        "visible0": visible0,
        "visible1": visible1,
    }
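

# Usage sketch (hypothetical shapes; assumes a batch dict `data` as produced by
# the two-view datasets, with "view0"/"view1" entries holding "camera" and
# "depth", plus the relative poses "T_0to1"/"T_1to0"):
#
#   kp0 = torch.rand(2, 512, 2) * 640  # [B, M, 2] keypoints in view 0 (pixels)
#   kp1 = torch.rand(2, 512, 2) * 640  # [B, N, 2] keypoints in view 1 (pixels)
#   gt = gt_matches_from_pose_depth(kp0, kp1, data, pos_th=3, neg_th=5)
#   gt["matches0"]    # [B, M] index into kp1, or -1 (unmatched) / -2 (ignore)
#   gt["assignment"]  # [B, M, N] boolean matrix of positive pairs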


@torch.no_grad()
def gt_matches_from_homography(kp0, kp1, H, pos_th=3, neg_th=6, **kw):
    if kp0.shape[1] == 0 or kp1.shape[1] == 0:
        b_size, n_kp0 = kp0.shape[:2]
        n_kp1 = kp1.shape[1]
        assignment = torch.zeros(
            b_size, n_kp0, n_kp1, dtype=torch.bool, device=kp0.device
        )
        m0 = -torch.ones_like(kp0[:, :, 0]).long()
        m1 = -torch.ones_like(kp1[:, :, 0]).long()
        return assignment, m0, m1
    kp0_1 = warp_points_torch(kp0, H, inverse=False)
    kp1_0 = warp_points_torch(kp1, H, inverse=True)

    # Build a distance matrix of size [... x M x N]
    dist0 = torch.sum((kp0_1.unsqueeze(-2) - kp1.unsqueeze(-3)) ** 2, -1)
    dist1 = torch.sum((kp0.unsqueeze(-2) - kp1_0.unsqueeze(-3)) ** 2, -1)
    dist = torch.max(dist0, dist1)

    reward = (dist < pos_th**2).float() - (dist > neg_th**2).float()

    min0 = dist.min(-1).indices
    min1 = dist.min(-2).indices

    ismin0 = torch.zeros(dist.shape, dtype=torch.bool, device=dist.device)
    ismin1 = ismin0.clone()
    ismin0.scatter_(-1, min0.unsqueeze(-1), value=1)
    ismin1.scatter_(-2, min1.unsqueeze(-2), value=1)
    positive = ismin0 & ismin1 & (dist < pos_th**2)

    negative0 = dist0.min(-1).values > neg_th**2
    negative1 = dist1.min(-2).values > neg_th**2

    # Pack the indices of positive matches
    # if -1: unmatched point
    # if -2: ignore point
    unmatched = min0.new_tensor(UNMATCHED_FEATURE)
    ignore = min0.new_tensor(IGNORE_FEATURE)
    m0 = torch.where(positive.any(-1), min0, ignore)
    m1 = torch.where(positive.any(-2), min1, ignore)
    m0 = torch.where(negative0, unmatched, m0)
    m1 = torch.where(negative1, unmatched, m1)

    return {
        "assignment": positive,
        "reward": reward,
        "matches0": m0,
        "matches1": m1,
        "matching_scores0": (m0 > -1).float(),
        "matching_scores1": (m1 > -1).float(),
        "proj_0to1": kp0_1,
        "proj_1to0": kp1_0,
    }


def sample_pts(lines, npts):
    """Sample npts equally spaced points on each (x1, y1, x2, y2) segment."""
    dir_vec = (lines[..., 2:4] - lines[..., :2]) / (npts - 1)
    pts = lines[..., :2, np.newaxis] + dir_vec[..., np.newaxis].expand(
        dir_vec.shape + (npts,)
    ) * torch.arange(npts).to(lines)
    pts = torch.transpose(pts, -1, -2)
    return pts


def torch_perp_dist(segs2d, points_2d):
    # Check batch size and segment format
    assert segs2d.shape[0] == points_2d.shape[0]
    assert segs2d.shape[-1] == 4

    dir_vec = segs2d[..., 2:] - segs2d[..., :2]
    sizes = torch.norm(dir_vec, dim=-1).half()
    norm_dir = dir_vec / torch.unsqueeze(sizes, dim=-1)
    # centered: [batch, nsegs0, nsegs1, n_sampled_pts, 2], points expressed
    # relative to the second endpoint of each segment
    centered = points_2d[:, None] - segs2d[..., None, None, 2:]

    # 2D rotation aligning each segment's direction with the x-axis
    R = torch.cat(
        [
            norm_dir[..., 0, None],
            norm_dir[..., 1, None],
            -norm_dir[..., 1, None],
            norm_dir[..., 0, None],
        ],
        dim=2,
    ).reshape((len(segs2d), -1, 2, 2))

    # Try to reduce the memory consumption by using float16
    if centered.is_cuda:
        centered, R = centered.half(), R.half()

    # R: [batch, nsegs0, 2, 2], centered: [batch, nsegs0, nsegs1, n_sampled_pts, 2]
    # -> rotated: [batch, nsegs0, nsegs1, n_sampled_pts, 2]
    rotated = torch.einsum("bdji,bdepi->bdepj", R, centered)

    # A point overlaps a segment if its projection falls between the endpoints
    overlapping = (rotated[..., 0] <= 0) & (
        torch.abs(rotated[..., 0]) <= sizes[..., None, None]
    )

    return torch.abs(rotated[..., 1]), overlapping
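

# Shape sketch for the two helpers above (hypothetical values):
#
#   lines = torch.tensor([[[0.0, 0.0, 100.0, 0.0]]])  # [B=1, 1 seg, x1y1x2y2]
#   pts = sample_pts(lines, npts=5)  # [1, 1, 5, 2], endpoints included
#   dists, overlap = torch_perp_dist(lines, pts)  # both [1, 1, 1, 5]:
#   # perpendicular distance of every sampled point to every segment, and
#   # whether the point projects onto the segment's extent.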


@torch.no_grad()
def gt_line_matches_from_pose_depth(
    pred_lines0,
    pred_lines1,
    valid_lines0,
    valid_lines1,
    data,
    npts=50,
    dist_th=5,
    overlap_th=0.2,
    min_visibility_th=0.5,
):
    """Compute ground truth line matches and label the remaining lines as:
    - UNMATCHED: if the reprojection is outside the image or far away from
      any other line.
    - IGNORE: if a line does not have enough valid depth pixels along it
      or is labeled as invalid."""
    lines0 = pred_lines0.clone()
    lines1 = pred_lines1.clone()

    if pred_lines0.shape[1] == 0 or pred_lines1.shape[1] == 0:
        bsize, nlines0, nlines1 = (
            pred_lines0.shape[0],
            pred_lines0.shape[1],
            pred_lines1.shape[1],
        )
        positive = torch.zeros(
            (bsize, nlines0, nlines1), dtype=torch.bool, device=pred_lines0.device
        )
        m0 = torch.full(
            (bsize, nlines0), -1, dtype=torch.long, device=pred_lines0.device
        )
        m1 = torch.full(
            (bsize, nlines1), -1, dtype=torch.long, device=pred_lines0.device
        )
        return positive, m0, m1

    # Accept segments as [b, n, 2, 2] endpoint pairs or [b, n, k, 2] polylines
    # (keep the two extreme points)
    if lines0.shape[-2:] == (2, 2):
        lines0 = torch.flatten(lines0, -2)
    elif lines0.dim() == 4:
        lines0 = torch.cat([lines0[:, :, 0], lines0[:, :, -1]], dim=2)
    if lines1.shape[-2:] == (2, 2):
        lines1 = torch.flatten(lines1, -2)
    elif lines1.dim() == 4:
        lines1 = torch.cat([lines1[:, :, 0], lines1[:, :, -1]], dim=2)
    b_size, n_lines0, _ = lines0.shape
    b_size, n_lines1, _ = lines1.shape

    # Clip the segments to the image borders
    h0, w0 = data["view0"]["depth"][0].shape
    h1, w1 = data["view1"]["depth"][0].shape
    lines0 = torch.min(
        torch.max(lines0, torch.zeros_like(lines0)),
        lines0.new_tensor([w0 - 1, h0 - 1, w0 - 1, h0 - 1], dtype=torch.float),
    )
    lines1 = torch.min(
        torch.max(lines1, torch.zeros_like(lines1)),
        lines1.new_tensor([w1 - 1, h1 - 1, w1 - 1, h1 - 1], dtype=torch.float),
    )

    # Sample points along each line
    pts0 = sample_pts(lines0, npts).reshape(b_size, n_lines0 * npts, 2)
    pts1 = sample_pts(lines1, npts).reshape(b_size, n_lines1 * npts, 2)

    # Sample depth and valid points
    d0, valid0_pts0 = sample_depth(pts0, data["view0"]["depth"])
    d1, valid1_pts1 = sample_depth(pts1, data["view1"]["depth"])

    # Reproject to the other view
    pts0_1, visible0 = project(
        pts0,
        d0,
        data["view1"]["depth"],
        data["view0"]["camera"],
        data["view1"]["camera"],
        data["T_0to1"],
        valid0_pts0,
    )
    pts1_0, visible1 = project(
        pts1,
        d1,
        data["view0"]["depth"],
        data["view1"]["camera"],
        data["view0"]["camera"],
        data["T_1to0"],
        valid1_pts1,
    )
    h0, w0 = data["view0"]["image"].shape[-2:]
    h1, w1 = data["view1"]["image"].shape[-2:]

    # A line with less than min_visibility_th of its sampled points inside
    # the image is considered OUTSIDE
    pts_out_of0 = (pts1_0 < 0).any(-1) | (
        pts1_0 >= torch.tensor([w0, h0]).to(pts1_0)
    ).any(-1)
    pts_out_of0 = pts_out_of0.reshape(b_size, n_lines1, npts).float()
    out_of0 = pts_out_of0.mean(dim=-1) >= (1 - min_visibility_th)
    pts_out_of1 = (pts0_1 < 0).any(-1) | (
        pts0_1 >= torch.tensor([w1, h1]).to(pts0_1)
    ).any(-1)
    pts_out_of1 = pts_out_of1.reshape(b_size, n_lines0, npts).float()
    out_of1 = pts_out_of1.mean(dim=-1) >= (1 - min_visibility_th)

    # visible0 is [bs, nl0 * npts]
    pts0_1 = pts0_1.reshape(b_size, n_lines0, npts, 2)
    pts1_0 = pts1_0.reshape(b_size, n_lines1, npts, 2)

    perp_dists0, overlapping0 = torch_perp_dist(lines0, pts1_0)
    close_points0 = (perp_dists0 < dist_th) & overlapping0  # [bs, nl0, nl1, npts]
    del perp_dists0, overlapping0
    close_points0 = close_points0 * visible1.reshape(b_size, 1, n_lines1, npts)

    perp_dists1, overlapping1 = torch_perp_dist(lines1, pts0_1)
    close_points1 = (perp_dists1 < dist_th) & overlapping1  # [bs, nl1, nl0, npts]
    del perp_dists1, overlapping1
    close_points1 = close_points1 * visible0.reshape(b_size, 1, n_lines0, npts)
    torch.cuda.empty_cache()

    # For each segment detected in 0, count how many sampled points from the
    # reprojected segments of 1 are close
    num_close_pts0 = close_points0.sum(dim=-1)  # [bs, nl0, nl1]
    # For each segment detected in 1, count how many sampled points from the
    # reprojected segments of 0 are close
    num_close_pts1 = close_points1.sum(dim=-1)
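    # A pair (i, j) is a candidate match only if both directions agree: enough
    # sampled points of line i in view 0 reproject close to line j in view 1
    # and vice versa, measured below against the number of points with valid
    # depth, scaled by overlap_th.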
    num_close_pts1_t = num_close_pts1.transpose(-1, -2)  # [bs, nl0, nl1]
    num_close_pts = num_close_pts0 * num_close_pts1_t

    mask_close = (
        num_close_pts1_t
        > visible0.reshape(b_size, n_lines0, npts).float().sum(-1)[:, :, None]
        * overlap_th
    ) & (
        num_close_pts0
        > visible1.reshape(b_size, n_lines1, npts).float().sum(-1)[:, None]
        * overlap_th
    )

    # Define the unmatched lines
    unmatched0 = torch.all(~mask_close, dim=2) | out_of1
    unmatched1 = torch.all(~mask_close, dim=1) | out_of0

    # Define the lines to ignore
    ignore0 = (
        valid0_pts0.reshape(b_size, n_lines0, npts).float().mean(dim=-1)
        < min_visibility_th
    ) | ~valid_lines0
    ignore1 = (
        valid1_pts1.reshape(b_size, n_lines1, npts).float().mean(dim=-1)
        < min_visibility_th
    ) | ~valid_lines1

    cost = -num_close_pts.clone()
    # High cost for unmatched and non-valid lines
    cost[unmatched0] = 1e6
    cost[ignore0] = 1e6
    # TODO: Is it reasonable to forbid matching a segment just because it has
    # no GT depth?
    cost = cost.transpose(1, 2)
    cost[unmatched1] = 1e6
    cost[ignore1] = 1e6
    cost = cost.transpose(1, 2)

    # One-to-one assignment maximizing the total number of close points
    assignation = np.array(
        [linear_sum_assignment(C) for C in cost.detach().cpu().numpy()]
    )
    assignation = torch.tensor(assignation).to(num_close_pts)

    # Set ignore and unmatched labels
    unmatched = assignation.new_tensor(UNMATCHED_FEATURE)
    ignore = assignation.new_tensor(IGNORE_FEATURE)

    positive = num_close_pts.new_zeros(num_close_pts.shape, dtype=torch.bool)
    all_in_batch = (
        torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten()
    )
    positive[
        all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()
    ] = True

    m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long)
    m0.scatter_(-1, assignation[:, 0], assignation[:, 1])
    m1 = assignation.new_full((b_size, n_lines1), unmatched, dtype=torch.long)
    m1.scatter_(-1, assignation[:, 1], assignation[:, 0])

    positive = positive & mask_close
    # Remove pairs that should be ignored or unmatched
    positive[unmatched0] = False
    positive[ignore0] = False
    positive = positive.transpose(1, 2)
    positive[unmatched1] = False
    positive[ignore1] = False
    positive = positive.transpose(1, 2)

    m0[~positive.any(-1)] = unmatched
    m0[unmatched0] = unmatched
    m0[ignore0] = ignore
    m1[~positive.any(-2)] = unmatched
    m1[unmatched1] = unmatched
    m1[ignore1] = ignore

    if num_close_pts.numel() == 0:
        no_matches = torch.zeros(positive.shape[0], 0).to(positive)
        return positive, no_matches, no_matches

    return positive, m0, m1
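

# Usage sketch (hypothetical shapes; `data` is the same two-view batch dict as
# for gt_matches_from_pose_depth, and additionally needs "image" and "depth"
# for both views):
#
#   lines0 = torch.rand(1, 120, 2, 2) * 480  # [B, L0, 2 endpoints, xy]
#   lines1 = torch.rand(1, 150, 2, 2) * 480
#   valid0 = torch.ones(1, 120, dtype=torch.bool)
#   valid1 = torch.ones(1, 150, dtype=torch.bool)
#   positive, m0, m1 = gt_line_matches_from_pose_depth(
#       lines0, lines1, valid0, valid1, data
#   )
#   # m0: [B, L0] index into lines1, or -1 (unmatched) / -2 (ignore)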


@torch.no_grad()
def gt_line_matches_from_homography(
    pred_lines0,
    pred_lines1,
    valid_lines0,
    valid_lines1,
    shape0,
    shape1,
    H,
    npts=50,
    dist_th=5,
    overlap_th=0.2,
    min_visibility_th=0.2,
):
    """Compute ground truth line matches and label the remaining lines as:
    - UNMATCHED: if the reprojection is outside the image or far away from
      any other line.
    - IGNORE: if a line is labeled as invalid."""
    h0, w0 = shape0[-2:]
    h1, w1 = shape1[-2:]

    lines0 = pred_lines0.clone()
    lines1 = pred_lines1.clone()
    # Accept segments as [b, n, 2, 2] endpoint pairs or [b, n, k, 2] polylines
    if lines0.shape[-2:] == (2, 2):
        lines0 = torch.flatten(lines0, -2)
    elif lines0.dim() == 4:
        lines0 = torch.cat([lines0[:, :, 0], lines0[:, :, -1]], dim=2)
    if lines1.shape[-2:] == (2, 2):
        lines1 = torch.flatten(lines1, -2)
    elif lines1.dim() == 4:
        lines1 = torch.cat([lines1[:, :, 0], lines1[:, :, -1]], dim=2)
    b_size, n_lines0, _ = lines0.shape
    b_size, n_lines1, _ = lines1.shape

    # Clip the segments to the image borders
    lines0 = torch.min(
        torch.max(lines0, torch.zeros_like(lines0)),
        lines0.new_tensor([w0 - 1, h0 - 1, w0 - 1, h0 - 1], dtype=torch.float),
    )
    lines1 = torch.min(
        torch.max(lines1, torch.zeros_like(lines1)),
        lines1.new_tensor([w1 - 1, h1 - 1, w1 - 1, h1 - 1], dtype=torch.float),
    )

    # Sample points along each line
    pts0 = sample_pts(lines0, npts).reshape(b_size, n_lines0 * npts, 2)
    pts1 = sample_pts(lines1, npts).reshape(b_size, n_lines1 * npts, 2)

    # Project the points to the other image
    pts0_1 = warp_points_torch(pts0, H, inverse=False)
    pts1_0 = warp_points_torch(pts1, H, inverse=True)
    pts0_1 = pts0_1.reshape(b_size, n_lines0, npts, 2)
    pts1_0 = pts1_0.reshape(b_size, n_lines1, npts, 2)

    # A line with less than min_visibility_th of its sampled points inside
    # the image is considered OUTSIDE
    pts_out_of0 = (pts1_0 < 0).any(-1) | (
        pts1_0 >= torch.tensor([w0, h0]).to(pts1_0)
    ).any(-1)
    pts_out_of0 = pts_out_of0.reshape(b_size, n_lines1, npts).float()
    out_of0 = pts_out_of0.mean(dim=-1) >= (1 - min_visibility_th)
    pts_out_of1 = (pts0_1 < 0).any(-1) | (
        pts0_1 >= torch.tensor([w1, h1]).to(pts0_1)
    ).any(-1)
    pts_out_of1 = pts_out_of1.reshape(b_size, n_lines0, npts).float()
    out_of1 = pts_out_of1.mean(dim=-1) >= (1 - min_visibility_th)

    perp_dists0, overlapping0 = torch_perp_dist(lines0, pts1_0)
    close_points0 = (perp_dists0 < dist_th) & overlapping0  # [bs, nl0, nl1, npts]
    del perp_dists0, overlapping0

    perp_dists1, overlapping1 = torch_perp_dist(lines1, pts0_1)
    close_points1 = (perp_dists1 < dist_th) & overlapping1  # [bs, nl1, nl0, npts]
    del perp_dists1, overlapping1
    torch.cuda.empty_cache()

    # For each segment detected in 0, count how many sampled points from the
    # reprojected segments of 1 are close
    num_close_pts0 = close_points0.sum(dim=-1)  # [bs, nl0, nl1]
    # For each segment detected in 1, count how many sampled points from the
    # reprojected segments of 0 are close
    num_close_pts1 = close_points1.sum(dim=-1)
    num_close_pts1_t = num_close_pts1.transpose(-1, -2)  # [bs, nl0, nl1]
    num_close_pts = num_close_pts0 * num_close_pts1_t

    mask_close = (
        (num_close_pts1_t > npts * overlap_th)
        & (num_close_pts0 > npts * overlap_th)
        & ~out_of0.unsqueeze(1)
        & ~out_of1.unsqueeze(-1)
    )

    # Define the unmatched lines
    unmatched0 = torch.all(~mask_close, dim=2) | out_of1
    unmatched1 = torch.all(~mask_close, dim=1) | out_of0

    # Define the lines to ignore
    ignore0 = ~valid_lines0
    ignore1 = ~valid_lines1

    cost = -num_close_pts.clone()
    # High cost for unmatched and non-valid lines
    cost[unmatched0] = 1e6
    cost[ignore0] = 1e6
    cost = cost.transpose(1, 2)
    cost[unmatched1] = 1e6
    cost[ignore1] = 1e6
    cost = cost.transpose(1, 2)
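    # Solve a minimum-cost one-to-one assignment per batch element (this
    # maximizes the total number of close points). The solver runs on CPU via
    # NumPy and may still pick pairs carrying the 1e6 cost; those are filtered
    # out again below.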
    assignation = np.array(
        [linear_sum_assignment(C) for C in cost.detach().cpu().numpy()]
    )
    assignation = torch.tensor(assignation).to(num_close_pts)

    # Set unmatched and ignore labels
    unmatched = assignation.new_tensor(UNMATCHED_FEATURE)
    ignore = assignation.new_tensor(IGNORE_FEATURE)

    positive = num_close_pts.new_zeros(num_close_pts.shape, dtype=torch.bool)
    # Batched equivalent of:
    # for b in range(b_size):
    #     positive[b][assignation[b, 0], assignation[b, 1]] = True
    positive[
        torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten(),
        assignation[:, 0].flatten(),
        assignation[:, 1].flatten(),
    ] = True

    m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long)
    m0.scatter_(-1, assignation[:, 0], assignation[:, 1])
    m1 = assignation.new_full((b_size, n_lines1), unmatched, dtype=torch.long)
    m1.scatter_(-1, assignation[:, 1], assignation[:, 0])

    positive = positive & mask_close
    # Remove pairs that should be ignored or unmatched
    positive[unmatched0] = False
    positive[ignore0] = False
    positive = positive.transpose(1, 2)
    positive[unmatched1] = False
    positive[ignore1] = False
    positive = positive.transpose(1, 2)

    m0[~positive.any(-1)] = unmatched
    m0[unmatched0] = unmatched
    m0[ignore0] = ignore
    m1[~positive.any(-2)] = unmatched
    m1[unmatched1] = unmatched
    m1[ignore1] = ignore

    if num_close_pts.numel() == 0:
        no_matches = torch.zeros(positive.shape[0], 0).to(positive)
        return positive, no_matches, no_matches

    return positive, m0, m1
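

if __name__ == "__main__":
    # Minimal smoke test for the homography-based keypoint ground truth: a
    # sketch with random keypoints and an identity homography (hypothetical
    # shapes). Run it with `python -m <package>.<this_module>`, since this
    # file uses relative imports.
    B, M, N = 2, 64, 80
    kp0 = torch.rand(B, M, 2) * 640  # random keypoints in a 640x640 image
    kp1 = torch.rand(B, N, 2) * 640
    H = torch.eye(3)[None].repeat(B, 1, 1)  # identity warp: views coincide
    out = gt_matches_from_homography(kp0, kp1, H)
    assert out["assignment"].shape == (B, M, N)
    assert out["matches0"].shape == (B, M) and out["matches1"].shape == (B, N)
    print("positive pairs:", out["assignment"].sum().item())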