Realcat
add: GIM (https://github.com/xuelunshen/gim)
c0283b3
raw
history blame
5.29 kB
import torch
from .utils import skew_symmetric, to_homogeneous
from .wrappers import Camera, Pose
def T_to_E(T: Pose):
"""Convert batched poses (..., 4, 4) to batched essential matrices."""
return skew_symmetric(T.t) @ T.R
def T_to_F(cam0: Camera, cam1: Camera, T_0to1: Pose):
return E_to_F(cam0, cam1, T_to_E(T_0to1))
def E_to_F(cam0: Camera, cam1: Camera, E: torch.Tensor):
assert cam0._data.shape[-1] == 6, "only pinhole cameras supported"
assert cam1._data.shape[-1] == 6, "only pinhole cameras supported"
K0 = cam0.calibration_matrix()
K1 = cam1.calibration_matrix()
return K1.inverse().transpose(-1, -2) @ E @ K0.inverse()
def F_to_E(cam0: Camera, cam1: Camera, F: torch.Tensor):
assert cam0._data.shape[-1] == 6, "only pinhole cameras supported"
assert cam1._data.shape[-1] == 6, "only pinhole cameras supported"
K0 = cam0.calibration_matrix()
K1 = cam1.calibration_matrix()
return K1.transpose(-1, -2) @ F @ K0
def sym_epipolar_distance(p0, p1, E, squared=True):
"""Compute batched symmetric epipolar distances.
Args:
p0, p1: batched tensors of N 2D points of size (..., N, 2).
E: essential matrices from camera 0 to camera 1, size (..., 3, 3).
Returns:
The symmetric epipolar distance of each point-pair: (..., N).
"""
assert p0.shape[-2] == p1.shape[-2]
if p0.shape[-2] == 0:
return torch.zeros(p0.shape[:-1]).to(p0)
if p0.shape[-1] != 3:
p0 = to_homogeneous(p0)
if p1.shape[-1] != 3:
p1 = to_homogeneous(p1)
p1_E_p0 = torch.einsum("...ni,...ij,...nj->...n", p1, E, p0)
E_p0 = torch.einsum("...ij,...nj->...ni", E, p0)
Et_p1 = torch.einsum("...ij,...ni->...nj", E, p1)
d0 = (E_p0[..., 0] ** 2 + E_p0[..., 1] ** 2).clamp(min=1e-6)
d1 = (Et_p1[..., 0] ** 2 + Et_p1[..., 1] ** 2).clamp(min=1e-6)
if squared:
d = p1_E_p0**2 * (1 / d0 + 1 / d1)
else:
d = p1_E_p0.abs() * (1 / d0.sqrt() + 1 / d1.sqrt()) / 2
return d
def sym_epipolar_distance_all(p0, p1, E, eps=1e-15):
if p0.shape[-1] != 3:
p0 = to_homogeneous(p0)
if p1.shape[-1] != 3:
p1 = to_homogeneous(p1)
p1_E_p0 = torch.einsum("...mi,...ij,...nj->...nm", p1, E, p0).abs()
E_p0 = torch.einsum("...ij,...nj->...ni", E, p0)
Et_p1 = torch.einsum("...ij,...mi->...mj", E, p1)
d0 = p1_E_p0 / (E_p0[..., None, 0] ** 2 + E_p0[..., None, 1] ** 2 + eps).sqrt()
d1 = (
p1_E_p0
/ (Et_p1[..., None, :, 0] ** 2 + Et_p1[..., None, :, 1] ** 2 + eps).sqrt()
)
return (d0 + d1) / 2
def generalized_epi_dist(
kpts0, kpts1, cam0: Camera, cam1: Camera, T_0to1: Pose, all=True, essential=True
):
if essential:
E = T_to_E(T_0to1)
p0 = cam0.image2cam(kpts0)
p1 = cam1.image2cam(kpts1)
if all:
return sym_epipolar_distance_all(p0, p1, E, agg="max")
else:
return sym_epipolar_distance(p0, p1, E, squared=False)
else:
assert cam0._data.shape[-1] == 6
assert cam1._data.shape[-1] == 6
K0, K1 = cam0.calibration_matrix(), cam1.calibration_matrix()
F = K1.inverse().transpose(-1, -2) @ T_to_E(T_0to1) @ K0.inverse()
if all:
return sym_epipolar_distance_all(kpts0, kpts1, F)
else:
return sym_epipolar_distance(kpts0, kpts1, F, squared=False)
def decompose_essential_matrix(E):
# decompose matrix by its singular values
U, _, V = torch.svd(E)
Vt = V.transpose(-2, -1)
mask = torch.ones_like(E)
mask[..., -1:] *= -1.0 # fill last column with negative values
maskt = mask.transpose(-2, -1)
# avoid singularities
U = torch.where((torch.det(U) < 0.0)[..., None, None], U * mask, U)
Vt = torch.where((torch.det(Vt) < 0.0)[..., None, None], Vt * maskt, Vt)
W = skew_symmetric(E.new_tensor([[0, 0, 1]]))
W[..., 2, 2] += 1.0
# reconstruct rotations and retrieve translation vector
U_W_Vt = U @ W @ Vt
U_Wt_Vt = U @ W.transpose(-2, -1) @ Vt
# return values
R1 = U_W_Vt
R2 = U_Wt_Vt
T = U[..., -1]
return R1, R2, T
# pose errors
# TODO: test for batched data
def angle_error_mat(R1, R2):
cos = (torch.trace(torch.einsum("...ij, ...jk -> ...ik", R1.T, R2)) - 1) / 2
cos = torch.clip(cos, -1.0, 1.0) # numerical errors can make it out of bounds
return torch.rad2deg(torch.abs(torch.arccos(cos)))
def angle_error_vec(v1, v2, eps=1e-10):
n = torch.clip(v1.norm(dim=-1) * v2.norm(dim=-1), min=eps)
v1v2 = (v1 * v2).sum(dim=-1) # dot product in the last dimension
return torch.rad2deg(torch.arccos(torch.clip(v1v2 / n, -1.0, 1.0)))
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0, eps=1e-10):
if isinstance(T_0to1, torch.Tensor):
R_gt, t_gt = T_0to1[:3, :3], T_0to1[:3, 3]
else:
R_gt, t_gt = T_0to1.R, T_0to1.t
R_gt, t_gt = torch.squeeze(R_gt), torch.squeeze(t_gt)
# angle error between 2 vectors
t_err = angle_error_vec(t, t_gt, eps)
t_err = torch.minimum(t_err, 180 - t_err) # handle E ambiguity
if t_gt.norm() < ignore_gt_t_thr: # pure rotation is challenging
t_err = 0
# angle error between 2 rotation matrices
r_err = angle_error_mat(R, R_gt)
return t_err, r_err