from loguru import logger

import torch
from einops import repeat
from kornia.utils import create_meshgrid

from .geometry import warp_kpts


@torch.no_grad()
def mask_pts_at_padded_regions(grid_pt, mask):
    """Zero out grid points that fall in the zero-padded regions of MegaDepth images."""
    mask = repeat(mask, "n h w -> n (h w) c", c=2)
    grid_pt[~mask.bool()] = 0
    return grid_pt


@torch.no_grad()
def spvs_coarse(data, config):
    """
    Update:
        data (dict): {
            "conf_matrix_gt": [N, hw0, hw1],
            "spv_b_ids": [M]
            "spv_i_ids": [M]
            "spv_j_ids": [M]
            "spv_w_pt0_i": [N, hw0, 2], in original image resolution
            "spv_pt1_i": [N, hw1, 2], in original image resolution
        }

    NOTE:
        - for the ScanNet dataset, there are 3 kinds of resolution: {i, c, f}
        - for the MegaDepth dataset, there are 4 kinds of resolution: {i, i_resize, c, f}
    """
    # 1. misc
    device = data["image0"].device
    N, _, H0, W0 = data["image0"].shape
    _, _, H1, W1 = data["image1"].shape
    scale = config["MODEL"]["RESOLUTION"][0]
    # per-sample "scale0"/"scale1" account for image resizing during data loading
    scale0 = scale * data["scale0"][:, None] if "scale0" in data else scale
    scale1 = scale * data["scale1"][:, None] if "scale1" in data else scale
    h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
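
    # 2. create keypoint grids at the coarse resolution, then scale them
    # back to original image resolution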
    # each grid is [N, hw, 2] with the last dim in (x, y) order
    grid_pt0_c = (
        create_meshgrid(h0, w0, False, device).reshape(1, h0 * w0, 2).repeat(N, 1, 1)
    )
    grid_pt0_i = scale0 * grid_pt0_c
    grid_pt1_c = (
        create_meshgrid(h1, w1, False, device).reshape(1, h1 * w1, 2).repeat(N, 1, 1)
    )
    grid_pt1_i = scale1 * grid_pt1_c
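
    # mask padded regions to (0, 0); together with the row-0 invalidation
    # below (`correct_0to1[:, 0] = False`), padded points can never become
    # ground-truth matches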
    if "mask0" in data:
        grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data["mask0"])
        grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data["mask1"])
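
    # 3. warp grid points bidirectionally (image0 -> image1 and back) using
    # depth, relative pose, and camera intrinsics, at original image resolution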
    _, w_pt0_i = warp_kpts(
        grid_pt0_i,
        data["depth0"],
        data["depth1"],
        data["T_0to1"],
        data["K0"],
        data["K1"],
    )
    _, w_pt1_i = warp_kpts(
        grid_pt1_i,
        data["depth1"],
        data["depth0"],
        data["T_1to0"],
        data["K1"],
        data["K0"],
    )
    w_pt0_c = w_pt0_i / scale1  # warped points in image1's coarse-grid coordinates
    w_pt1_c = w_pt1_i / scale0  # warped points in image0's coarse-grid coordinates
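
    # 4. mutual nearest-neighbor check on the coarse grid:
    # round each warped point to its nearest coarse cell and flatten
    # (x, y) -> x + y * w in row-major order; e.g. with w1 = 80, a point
    # rounding to (x=12, y=3) gets flat index 12 + 3 * 80 = 252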
    w_pt0_c_round = w_pt0_c.round().long()
    nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
    w_pt1_c_round = w_pt1_c.round().long()
    nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
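
    # corner case: out-of-boundary warps are redirected to index 0, which the
    # row-0 invalidation below excludes from the ground truth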
    def out_bound_mask(pt, w, h):
        return (
            (pt[..., 0] < 0) | (pt[..., 0] >= w) | (pt[..., 1] < 0) | (pt[..., 1] >= h)
        )

    nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
    nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
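
    # a coarse match (i -> j) survives only if warping j back lands exactly
    # on i, i.e. the pair is mutually nearest on the coarse grid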
    loop_back = torch.stack(
        [nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0
    )
    correct_0to1 = loop_back == torch.arange(h0 * w0, device=device)[None].repeat(N, 1)
    correct_0to1[:, 0] = False  # ignore index 0 (masked / out-of-bound points end up there)
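
    # 5. build the ground-truth confidence matrix from the surviving matches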
    conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device)
    b_ids, i_ids = torch.where(correct_0to1 != 0)
    j_ids = nearest_index1[b_ids, i_ids]

    conf_matrix_gt[b_ids, i_ids, j_ids] = 1
    data.update({"conf_matrix_gt": conf_matrix_gt})
    if len(b_ids) == 0:
        logger.warning(f"No ground-truth coarse match found for: {data['pair_names']}")
        b_ids = torch.tensor([0], device=device)
        i_ids = torch.tensor([0], device=device)
        j_ids = torch.tensor([0], device=device)

    data.update({"spv_b_ids": b_ids, "spv_i_ids": i_ids, "spv_j_ids": j_ids})
    data.update({"spv_w_pt0_i": w_pt0_i, "spv_pt1_i": grid_pt1_i})


def compute_supervision_coarse(data, config):
    assert (
        len(set(data["dataset_name"])) == 1
    ), "Mixed-dataset training is not supported!"
    data_source = data["dataset_name"][0]
    if data_source.lower() in ["scannet", "megadepth"]:
        spvs_coarse(data, config)
    else:
        raise ValueError(f"Unknown data source: {data_source}")


@torch.no_grad()
def spvs_fine(data, config):
    """
    Update:
        data (dict): {
            "expec_f_gt": [M, 2]
        }
    """
    # 1. misc
    w_pt0_i, pt1_i = data["spv_w_pt0_i"], data["spv_pt1_i"]
    scale = config["MODEL"]["RESOLUTION"][1]
    radius = config["MODEL"]["FINE_WINDOW_SIZE"] // 2

    # 2. get the coarse-level match predictions
    b_ids, i_ids, j_ids = data["b_ids"], data["i_ids"], data["j_ids"]
    # 3. compute the ground-truth offset
    scale = scale * data["scale1"][b_ids] if "scale1" in data else scale
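
    # (w_pt0_i - pt1_i) is the offset in original-image pixels; dividing by
    # `scale` converts it to fine-grid units, and dividing by `radius`
    # normalizes it so offsets inside the fine window fall in [-1, 1]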
    expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius
    data.update({"expec_f_gt": expec_f_gt})


def compute_supervision_fine(data, config):
    data_source = data["dataset_name"][0]
    if data_source.lower() in ["scannet", "megadepth"]:
        spvs_fine(data, config)
    else:
        raise NotImplementedError(f"Unknown data source: {data_source}")
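

# Usage sketch (illustrative; `batch` contents are assumptions inferred from
# the keys accessed above): both entry points mutate `data` in place, e.g.
#
#   compute_supervision_coarse(batch, config)  # before the coarse-level loss
#   # ... run the matcher, which fills batch["b_ids"/"i_ids"/"j_ids"] ...
#   compute_supervision_fine(batch, config)    # before the fine-level loss
#
# `batch` must provide image0/1, depth0/1, K0/1, T_0to1, T_1to0, and, for
# MegaDepth, scale0/scale1 (resize factors) and mask0/mask1 (padding masks).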