import kornia
import torch

from .utils import get_image_coords
from .wrappers import Camera


def sample_fmap(pts, fmap):
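    """Sample a feature map at sub-pixel locations.

    `pts` is assumed to hold (x, y) pixel coordinates of shape (B, N, 2) and
    `fmap` a feature map of shape (B, C, H, W); the result has shape (B, N, C).
    Bilinear samples that produce NaN fall back to nearest-neighbour values.
    """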
    h, w = fmap.shape[-2:]
    grid_sample = torch.nn.functional.grid_sample
    # Normalize pixel coordinates to the [-1, 1] range expected by grid_sample.
    pts = (pts / pts.new_tensor([[w, h]]) * 2 - 1)[:, None]
    # @TODO: bilinear interpolation blends neighbouring values and may still be
    # a source of noise; use nearest-neighbour where it produces NaNs.
    interp_lin = grid_sample(fmap, pts, align_corners=False, mode="bilinear")
    interp_nn = grid_sample(fmap, pts, align_corners=False, mode="nearest")
    return torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)[:, :, 0].permute(
        0, 2, 1
    )


def sample_depth(pts, depth_):
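    """Sample a depth map at the given pixel locations.

    Zero (invalid) depth values are mapped to NaN before interpolation so they
    are never blended into valid samples; the returned mask marks finite,
    positive depths.
    """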
    depth = torch.where(depth_ > 0, depth_, depth_.new_tensor(float("nan")))
    depth = depth[:, None]
    interp = sample_fmap(pts, depth).squeeze(-1)
    valid = (~torch.isnan(interp)) & (interp > 0)
    return interp, valid


def sample_normals_from_depth(pts, depth, K):
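    """Sample surface normals derived from a depth map.

    Normals are computed with kornia from `depth` and the intrinsics `K`,
    zeroed where the depth is invalid, and then sampled at `pts`. The validity
    mask is evaluated per normal component.
    """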
    depth = depth[:, None]
    normals = kornia.geometry.depth.depth_to_normals(depth, K)
    normals = torch.where(depth > 0, normals, 0.0)
    interp = sample_fmap(pts, normals)
    valid = (~torch.isnan(interp)) & (interp > 0)
    return interp, valid


def project(
    kpi,
    di,
    depthj,
    camera_i,
    camera_j,
    T_itoj,
    validi,
    ccth=None,
    sample_depth_fun=sample_depth,
    sample_depth_kwargs=None,
):
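    """Project keypoints `kpi` with depths `di` from camera i into camera j.

    If a target depth map `depthj` and a cycle-consistency threshold `ccth`
    (on the squared pixel distance) are given, points are additionally
    re-projected back into image i and kept only if the round trip lands close
    to the original keypoint, which acts as an occlusion check. Returns the
    projected keypoints and a validity mask.
    """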
    if sample_depth_kwargs is None:
        sample_depth_kwargs = {}

    # Lift the keypoints of image i to 3D, move them into camera j and project.
    kpi_3d_i = camera_i.image2cam(kpi)
    kpi_3d_i = kpi_3d_i * di[..., None]
    kpi_3d_j = T_itoj.transform(kpi_3d_i)
    kpi_j, validj = camera_j.cam2image(kpi_3d_j)
    # di_j = kpi_3d_j[..., -1]
    validi = validi & validj
    if depthj is None or ccth is None:
        return kpi_j, validi
    else:
        # Cycle consistency: sample the depth of image j at the projected
        # locations, re-project into image i and keep points whose round trip
        # lands within `ccth` (squared pixels) of the original keypoint.
        dj, validj = sample_depth_fun(kpi_j, depthj, **sample_depth_kwargs)
        kpi_j_3d_j = camera_j.image2cam(kpi_j) * dj[..., None]
        kpi_j_i, validj_i = camera_i.cam2image(T_itoj.inv().transform(kpi_j_3d_j))
        consistent = ((kpi - kpi_j_i) ** 2).sum(-1) < ccth
        visible = validi & consistent & validj_i & validj
        # visible = validi
        return kpi_j, visible


def dense_warp_consistency(
    depthi: torch.Tensor,
    depthj: torch.Tensor,
    T_itoj: torch.Tensor,
    camerai: Camera,
    cameraj: Camera,
    **kwargs,
):
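    """Warp every pixel of `depthi` into image j and check cycle consistency.

    Returns the dense warp field (per-pixel coordinates in image j) and a
    per-pixel validity mask, both reshaped to the spatial size of `depthi`.
    Extra keyword arguments (e.g. `ccth`) are forwarded to `project`.
    """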
    kpi = get_image_coords(depthi).flatten(-3, -2)
    di = depthi.flatten(-2)
    validi = di > 0
    kpir, validir = project(kpi, di, depthj, camerai, cameraj, T_itoj, validi, **kwargs)
    # Both outputs index pixels of image i, so reshape to depthi's spatial size.
    return kpir.unflatten(-2, depthi.shape[-2:]), validir.unflatten(
        -1, depthi.shape[-2:]
    )
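

# Minimal sanity check for the depth-sampling helper above. This is a sketch,
# not part of the original module: it assumes the file is executed as a module
# inside its package (e.g. `python -m <package>.<this module>`) so that the
# relative imports resolve, and it only exercises `sample_depth`, which needs
# nothing beyond torch.
if __name__ == "__main__":
    depth = torch.zeros(1, 8, 8)
    depth[0, 2:6, 2:6] = 2.0  # a valid 4x4 patch of constant depth
    pts = torch.tensor([[[4.0, 4.0], [0.5, 0.5]]])  # one valid, one invalid point
    d, valid = sample_depth(pts, depth)
    # Expected: ~2.0 / True for the first point, NaN / False for the second.
    print(d, valid)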