|
import warnings |
|
import numpy as np |
|
import cv2 |
|
import math |
|
import torch |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import InterpolationMode |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
import kornia |
|
|
|
def recover_pose(E, kpts0, kpts1, K0, K1, mask):
    """Pick the essential-matrix candidate with the most cheirality inliers.

    Keypoints are normalised with the intrinsics (inverse of the upper-left
    2x2 block after subtracting the principal point) and each candidate is
    passed to ``cv2.recoverPose`` with an identity camera matrix.

    Args:
        E: One or more 3x3 essential matrices stacked vertically, shape (3k, 3).
        kpts0: Pixel keypoints in image 0, shape (N, 2).
        kpts1: Pixel keypoints in image 1, shape (N, 2).
        K0: 3x3 intrinsics of camera 0.
        K1: 3x3 intrinsics of camera 1.
        mask: Inlier mask handed to ``cv2.recoverPose``.

    Returns:
        ``(R, t, inlier_mask)`` of the best candidate, or ``None`` when no
        candidate has a positive inlier count.
    """
    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])
    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T

    # Stays None when no candidate yields inliers, instead of leaving the
    # name unbound (which used to raise UnboundLocalError).
    ret = None
    best_num_inliers = 0
    # E stacks one 3x3 matrix per candidate; use an integer section count.
    for _E in np.split(E, len(E) // 3):
        n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            # NOTE(review): reads `mask` after recoverPose, relying on OpenCV
            # updating it in place; the returned mask (`_`) is ignored.
            ret = (R, t, mask.ravel() > 0)
    return ret
|
|
|
|
|
|
|
|
|
|
|
def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
    """Estimate the relative pose between two calibrated views.

    Keypoints are normalised with ``K0``/``K1``, an essential matrix is fitted
    with ``cv2.findEssentialMat`` (inlier threshold ``norm_thresh`` in
    normalised coordinates, confidence ``conf``) and every returned candidate
    is disambiguated with ``cv2.recoverPose``.

    Args:
        kpts0: Pixel keypoints in image 0, shape (N, 2).
        kpts1: Pixel keypoints in image 1, shape (N, 2).
        K0: 3x3 intrinsics of camera 0.
        K1: 3x3 intrinsics of camera 1.
        norm_thresh: Inlier threshold in normalised image coordinates.
        conf: RANSAC confidence.

    Returns:
        ``(R, t, inlier_mask)`` of the candidate with the most inliers, or
        ``None`` for fewer than 5 correspondences, when no essential matrix is
        found, or when no candidate has inliers.
    """
    if len(kpts0) < 5:
        return None

    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])
    kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T

    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
    )
    if E is None:
        return None

    ret = None
    best_num_inliers = 0
    # E may stack several 3x3 solutions; use an integer section count.
    for _E in np.split(E, len(E) // 3):
        n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            # NOTE(review): relies on recoverPose updating `mask` in place.
            ret = (R, t, mask.ravel() > 0)
    return ret
|
|
|
def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
    """Estimate relative pose via a fundamental matrix fitted on pixel keypoints.

    ``cv2.findFundamentalMat`` (USAC_ACCURATE) runs on the raw pixel
    coordinates, so ``norm_thresh`` is used as a pixel reprojection threshold.
    The fundamental matrix is lifted to an essential matrix with
    ``E = K1^T F K0``, and the pose is recovered on normalised keypoints.

    Args:
        kpts0: Pixel keypoints in image 0, shape (N, 2).
        kpts1: Pixel keypoints in image 1, shape (N, 2).
        K0: 3x3 intrinsics of camera 0.
        K1: 3x3 intrinsics of camera 1.
        norm_thresh: RANSAC threshold passed as ``ransacReprojThreshold``.
        conf: RANSAC confidence.

    Returns:
        ``(R, t, inlier_mask)`` of the best candidate, or ``None`` for fewer
        than 5 correspondences, when no fundamental matrix is found, or when no
        candidate has inliers.
    """
    if len(kpts0) < 5:
        return None

    # Named Fm so it does not shadow the module-level `F` (torch.nn.functional).
    Fm, mask = cv2.findFundamentalMat(
        kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf,
        method=cv2.USAC_ACCURATE, maxIters=10000,
    )
    # findFundamentalMat yields None on failure; check before the matrix product.
    if Fm is None:
        return None
    E = K1.T @ Fm @ K0

    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])
    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T

    ret = None
    best_num_inliers = 0
    for _E in np.split(E, len(E) // 3):
        n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            # NOTE(review): relies on recoverPose updating `mask` in place.
            ret = (R, t, mask.ravel() > 0)
    return ret
|
|
|
def unnormalize_coords(x_n, h, w):
    """Map coordinates from the [-1, 1] range to pixel units.

    The last axis holds (x, y); x is scaled by ``w`` and y by ``h`` using
    ``size * (v + 1) / 2``.
    """
    pixel_x = w * (x_n[..., 0] + 1) / 2
    pixel_y = h * (x_n[..., 1] + 1) / 2
    return torch.stack((pixel_x, pixel_y), dim=-1)
|
|
|
|
|
def rotate_intrinsic(K, n):
    """Left-multiply ``K`` by the ``n``-th power of a fixed quarter-turn matrix.

    The base matrix is ``[[0, 1, 0], [-1, 0, 0], [0, 0, 1]]``.
    """
    quarter_turn = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
    return np.linalg.matrix_power(quarter_turn, n) @ K
|
|
|
|
|
def rotate_pose_inplane(i_T_w, rot):
    """Left-multiply a 4x4 pose by an in-plane rotation about the z axis.

    ``rot`` indexes the angle table (0, 270, 180, 90) degrees; the rotation
    matrix is built in float32.
    """
    angle = np.deg2rad((0, 270, 180, 90)[rot])
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array(
        [
            [c, -s, 0.0, 0.0],
            [s, c, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
    return np.dot(rotation, i_T_w)
|
|
|
|
|
def scale_intrinsics(K, scales):
    """Divide the first row of ``K`` by ``scales[0]`` and the second by ``scales[1]``."""
    sx, sy = scales[0], scales[1]
    inverse_scale = np.diag([1.0 / sx, 1.0 / sy, 1.0])
    return np.dot(inverse_scale, K)
|
|
|
|
|
def to_homogeneous(points):
    """Append a coordinate equal to 1 along the last axis of ``points``.

    Works for any rank: (N, d) -> (N, d + 1), (B, N, d) -> (B, N, d + 1),
    (d,) -> (d + 1,). The dtype of ``points`` is preserved.
    """
    # `[..., :1]` (rather than `[:, :1]`) yields one ones-column per point
    # regardless of how many leading batch axes there are.
    return np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
|
|
|
|
|
def angle_error_mat(R1, R2):
    """Return the rotation angle of ``R1^T R2`` in degrees.

    Uses ``cos(theta) = (trace(R1^T R2) - 1) / 2``, clipped to [-1, 1] so that
    round-off cannot push ``arccos`` out of its domain.
    """
    relative = np.dot(R1.T, R2)
    cos_theta = np.clip((np.trace(relative) - 1) / 2, -1.0, 1.0)
    return np.rad2deg(np.abs(np.arccos(cos_theta)))
|
|
|
|
|
def angle_error_vec(v1, v2):
    """Return the angle between vectors ``v1`` and ``v2`` in degrees.

    The cosine is clipped to [-1, 1] before ``arccos`` to absorb round-off.
    """
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    cos_theta = np.clip(np.dot(v1, v2) / denom, -1.0, 1.0)
    return np.rad2deg(np.arccos(cos_theta))
|
|
|
|
|
def compute_pose_error(T_0to1, R, t):
    """Compare an estimated pose (R, t) against the ground-truth ``T_0to1``.

    Returns:
        ``(error_t, error_R)`` in degrees. The translation error is the angle
        between directions, folded with ``min(e, 180 - e)`` so that ``t`` and
        ``-t`` score the same.
    """
    rotation_gt = T_0to1[:3, :3]
    translation_gt = T_0to1[:3, 3]

    t_err = angle_error_vec(t.squeeze(), translation_gt)
    t_err = np.minimum(t_err, 180 - t_err)

    r_err = angle_error_mat(R, rotation_gt)
    return t_err, r_err
|
|
|
|
|
def pose_auc(errors, thresholds):
    """Area under the recall-vs-error curve for each threshold.

    Args:
        errors: Sequence (list, tuple or array) of per-sample errors.
        thresholds: Iterable of positive thresholds.

    Returns:
        List with one AUC per threshold, each normalised by that threshold.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    # np.asarray accepts tuples as well (the old `errors.copy()` did not).
    errors = np.sort(np.asarray(errors))
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.r_[0.0, errors]
    recall = np.r_[0.0, recall]

    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        # Extend the curve flat up to the threshold before integrating.
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(trapezoid(r, x=e) / t)
    return aucs
|
|
|
|
|
|
|
def get_depth_tuple_transform_ops_nearest_exact(resize=None):
    """Build a TupleCompose that optionally resizes with nearest-exact interpolation.

    When ``resize`` is falsy the pipeline is empty.
    """
    ops = [TupleResizeNearestExact(resize)] if resize else []
    return TupleCompose(ops)
|
|
|
def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
    """Build a TupleCompose that optionally resizes depth maps bilinearly.

    ``normalize`` and ``unscale`` are accepted for signature compatibility but
    are not used.
    """
    pipeline = []
    if not resize:
        return TupleCompose(pipeline)
    pipeline.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
    return TupleCompose(pipeline)
|
|
|
|
|
def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe=False, colorjiggle_params=None):
    """Build the image pipeline: optional resize, to-tensor in [0, 1], optional ImageNet normalisation.

    ``unscale``, ``clahe`` and ``colorjiggle_params`` are accepted for
    signature compatibility but are not used.
    """
    ops = [TupleResize(resize)] if resize else []
    ops.append(TupleToTensorScaled())
    if normalize:
        imagenet_norm = TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ops.append(imagenet_norm)
    return TupleCompose(ops)
|
|
|
class ToTensorScaled(object):
    """Convert an HWC image (PIL or array) to a float32 CHW tensor scaled to [0, 1].

    Inputs that are already ``torch.Tensor`` are returned unchanged.
    """

    def __call__(self, im):
        if isinstance(im, torch.Tensor):
            return im
        chw = np.array(im, dtype=np.float32).transpose((2, 0, 1))
        chw /= 255.0
        return torch.from_numpy(chw)

    def __repr__(self):
        return "ToTensorScaled(./255)"
|
|
|
|
|
class TupleToTensorScaled(object):
    """Apply ToTensorScaled to every image of a tuple and return a list."""

    def __init__(self):
        self.to_tensor = ToTensorScaled()

    def __call__(self, im_tuple):
        convert = self.to_tensor
        return list(map(convert, im_tuple))

    def __repr__(self):
        return "TupleToTensorScaled(./255)"
|
|
|
|
|
class ToTensorUnscaled(object):
    """Convert an HWC image (PIL or array) to a float32 CHW tensor without rescaling."""

    def __call__(self, im):
        hwc = np.array(im, dtype=np.float32)
        return torch.from_numpy(hwc.transpose((2, 0, 1)))

    def __repr__(self):
        return "ToTensorUnscaled()"
|
|
|
|
|
class TupleToTensorUnscaled(object):
    """Apply ToTensorUnscaled to every image of a tuple and return a list."""

    def __init__(self):
        self.to_tensor = ToTensorUnscaled()

    def __call__(self, im_tuple):
        convert = self.to_tensor
        return list(map(convert, im_tuple))

    def __repr__(self):
        return "TupleToTensorUnscaled()"
|
|
|
class TupleResizeNearestExact:
    """Resize every tensor of a tuple with ``F.interpolate(mode='nearest-exact')``."""

    def __init__(self, size):
        self.size = size

    def __call__(self, im_tuple):
        resized = []
        for tensor in im_tuple:
            resized.append(F.interpolate(tensor, size=self.size, mode='nearest-exact'))
        return resized

    def __repr__(self):
        return "TupleResizeNearestExact(size={})".format(self.size)
|
|
|
|
|
class TupleResize(object):
    """Resize every image of a tuple with one shared ``transforms.Resize``."""

    def __init__(self, size, mode=InterpolationMode.BICUBIC):
        self.size = size
        # transforms.Resize takes the interpolation mode as its second positional argument.
        self.resize = transforms.Resize(size, mode)

    def __call__(self, im_tuple):
        resize = self.resize
        return [resize(image) for image in im_tuple]

    def __repr__(self):
        return "TupleResize(size={})".format(self.size)
|
|
|
class Normalize:
    """Standardise a tensor over dims 1 and 2 (the spatial dims of a CHW image).

    Subtracts the per-slice mean and divides by the (unbiased) per-slice std.
    """

    def __call__(self, im):
        reduce_dims = (1, 2)
        mu = im.mean(dim=reduce_dims, keepdim=True)
        sigma = im.std(dim=reduce_dims, keepdim=True)
        centred = im - mu
        return centred / sigma
|
|
|
|
|
class TupleNormalize(object):
    """Normalise the first three channels of every image in a tuple.

    Only the first image's channel count is inspected; a warning is issued
    when it exceeds three.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std
        self.normalize = transforms.Normalize(mean=mean, std=std)

    def __call__(self, im_tuple):
        # Unpacking also requires the first image to be 3-D (C, H, W).
        c, _, _ = im_tuple[0].shape
        if c > 3:
            warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
        normalize = self.normalize
        return [normalize(image[:3]) for image in im_tuple]

    def __repr__(self):
        return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
|
|
|
|
|
class TupleCompose(object):
    """Chain tuple transforms, feeding each transform's output into the next."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, im_tuple):
        result = im_tuple
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        body = "".join("\n" + " {0}".format(t) for t in self.transforms)
        return self.__class__.__name__ + "(" + body + "\n)"
|
|
|
@torch.no_grad()
def cls_to_flow(cls, deterministic_sampling=True):
    """Turn per-pixel class scores over a res x res grid into flow coordinates.

    Args:
        cls: Scores of shape (B, C, H, W); C is expected to be a perfect square
            (res = round(sqrt(C))).
        deterministic_sampling: If True take the highest-scoring class,
            otherwise sample one class from ``softmax(cls)`` per pixel.

    Returns:
        Tensor of shape (B, H, W, 2) holding the (x, y) centre of the chosen
        grid cell in [-1, 1] coordinates.
    """
    B, C, H, W = cls.shape
    device = cls.device
    res = round(math.sqrt(C))
    # indexing="ij" matches meshgrid's historical default; passing it
    # explicitly avoids torch's deprecation warning.
    G = torch.meshgrid(
        *[torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device) for _ in range(2)],
        indexing="ij",
    )
    # Stack as (x, y): x varies along the second grid axis, y along the first.
    G = torch.stack([G[1], G[0]], dim=-1).reshape(C, 2)
    if deterministic_sampling:
        sampled_cls = cls.max(dim=1).indices
    else:
        probs = cls.permute(0, 2, 3, 1).reshape(B * H * W, C).softmax(dim=-1)
        sampled_cls = torch.multinomial(probs, 1).reshape(B, H, W)
    flow = G[sampled_cls]
    return flow
|
|
|
@torch.no_grad()
def cls_to_flow_refine(cls):
    """Refine the arg-max flow with a probability-weighted average over five cells.

    The cells are the mode and the flat indices mode-1, mode+1, mode-res and
    mode+res, clamped to [0, C-1].

    Args:
        cls: Scores of shape (B, C, H, W); C is expected to be a perfect square.

    Returns:
        Tensor of shape (B, H, W, 2) with (x, y) coordinates in [-1, 1].
    """
    B, C, H, W = cls.shape
    device = cls.device
    res = round(math.sqrt(C))
    # indexing="ij" matches meshgrid's historical default and avoids the warning.
    G = torch.meshgrid(
        *[torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device) for _ in range(2)],
        indexing="ij",
    )
    G = torch.stack([G[1], G[0]], dim=-1).reshape(C, 2)
    cls = cls.softmax(dim=1)
    mode = cls.max(dim=1).indices

    # NOTE(review): mode-1 / mode+1 at a row edge wrap into the neighbouring
    # row, and clamped indices repeat the mode; confirm this is intended.
    index = torch.stack((mode - 1, mode, mode + 1, mode - res, mode + res), dim=1).clamp(0, C - 1).long()
    neighbours = torch.gather(cls, dim=1, index=index)[..., None]  # (B, 5, H, W, 1)
    # G[index] is (B, 5, H, W, 2); weight each cell centre by its probability.
    flow = (neighbours * G[index]).sum(dim=1)
    tot_prob = neighbours.sum(dim=1)
    flow = flow / tot_prob
    return flow
|
|
|
|
|
def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode='bilinear', relative_depth_error_threshold=0.05, H=None, W=None):
    """Warp a regular H x W grid of image 1 into image 2 using depth and pose.

    Args:
        depth1: Depth of image 1, shape (B, H1, W1).
        depth2: Depth of image 2.
        T_1to2: Relative pose passed to ``warp_kpts``.
        K1: Intrinsics of image 1.
        K2: Intrinsics of image 2.
        depth_interpolation_mode: Forwarded to ``warp_kpts``.
        relative_depth_error_threshold: Forwarded to ``warp_kpts``.
        H: Grid height; when ``H`` is None both H and W come from ``depth1.shape``.
        W: Grid width.

    Returns:
        ``(x2, prob)``: warped coordinates of shape (B, H, W, 2) as produced by
        ``warp_kpts``, and a float validity mask of shape (B, H, W).
    """
    if H is None:
        B, H, W = depth1.shape
    else:
        B = depth1.shape[0]
    with torch.no_grad():
        # Cell-centre grid in (-1, 1); indexing="ij" matches meshgrid's
        # historical default and avoids torch's deprecation warning.
        x1_n = torch.meshgrid(
            *[
                torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=depth1.device)
                for n in (B, H, W)
            ],
            indexing="ij",
        )
        # (x, y) order: x from the W axis, y from the H axis.
        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
        mask, x2 = warp_kpts(
            x1_n.double(),
            depth1.double(),
            depth2.double(),
            T_1to2.double(),
            K1.double(),
            K2.double(),
            depth_interpolation_mode=depth_interpolation_mode,
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
        prob = mask.float().reshape(B, H, W)
    x2 = x2.reshape(B, H, W, 2)
    return x2, prob
|
|
|
@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask=False, return_relative_depth_error=False, depth_interpolation_mode="bilinear", relative_depth_error_threshold=0.05):
    """Warp kpts0 from I0 to I1 with depth, K and Rt.

    Also checks covisibility and depth consistency. A point is depth-consistent
    if its relative depth error is below ``relative_depth_error_threshold``
    (default 0.05). When ``smooth_mask`` is truthy, the soft score
    ``exp(-relative_depth_error / smooth_mask)`` is used instead.
    # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here

    Args:
        kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
        depth0 (torch.Tensor): [N, H, W],
        depth1 (torch.Tensor): [N, H, W],
        T_0to1 (torch.Tensor): [N, 3, 4],
        K0 (torch.Tensor): [N, 3, 3],
        K1 (torch.Tensor): [N, 3, 3],
        smooth_mask: falsy for a hard threshold, else the temperature of the soft score.
        return_relative_depth_error: return the relative depth error instead of the mask.
        depth_interpolation_mode: "bilinear", "nearest", "nearest-exact" (treated as
            grid_sample's "nearest") or "combined". "combined" runs bilinear and
            nearest and uses the nearest warp where only it is valid.
        relative_depth_error_threshold: hard consistency threshold.
    Returns:
        calculable_mask (torch.Tensor): [N, L] (relative depth error if requested)
        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>, normalized to (-1, 1)
    Raises:
        NotImplementedError: "combined" mode with ``smooth_mask`` or
            ``return_relative_depth_error``.
    """
    (
        n,
        h,
        w,
    ) = depth0.shape
    if depth_interpolation_mode == "combined":
        if smooth_mask:
            raise NotImplementedError("Combined bilinear and NN warp not implemented")
        if return_relative_depth_error:
            # The merge below combines boolean masks; relative errors are floats.
            raise NotImplementedError("Combined bilinear and NN warp cannot return relative depth error")
        valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                                                  smooth_mask=smooth_mask,
                                                  return_relative_depth_error=return_relative_depth_error,
                                                  depth_interpolation_mode="bilinear",
                                                  relative_depth_error_threshold=relative_depth_error_threshold)
        valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                                                smooth_mask=smooth_mask,
                                                return_relative_depth_error=return_relative_depth_error,
                                                depth_interpolation_mode="nearest-exact",
                                                relative_depth_error_threshold=relative_depth_error_threshold)
        # Fill holes of the bilinear warp with nearest-neighbour results.
        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
        warp = warp_bilinear.clone()
        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
        valid = valid_bilinear | valid_nearest
        return valid, warp

    # F.grid_sample accepts only "bilinear", "nearest" and "bicubic"; map the
    # F.interpolate-style "nearest-exact" (used by the "combined" branch) to "nearest".
    grid_sample_mode = "nearest" if depth_interpolation_mode == "nearest-exact" else depth_interpolation_mode

    # Sample source depth at each keypoint: grid (n, L, 1, 2) -> output (n, 1, L, 1).
    kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode=grid_sample_mode, align_corners=False)[
        :, 0, :, 0
    ]
    # Normalized (-1, 1) -> pixel coordinates of image 0.
    kpts0 = torch.stack(
        (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
    )
    # Points with zero depth cannot be unprojected.
    nonzero_mask = kpts0_depth != 0

    # Unproject: homogeneous pixel scaled by depth, then K0^-1 -> camera-0 coords (n, 3, L).
    kpts0_h = (
        torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
        * kpts0_depth[..., None]
    )
    kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1)
    kpts0_cam = kpts0_n

    # Rigid transform into camera 1; the z row is the depth predicted for image 1.
    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]
    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

    # Project with K1; the 1e-4 avoids division by zero depth.
    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)
    w_kpts0 = w_kpts0_h[:, :, :2] / (
        w_kpts0_h[:, :, [2]] + 1e-4
    )

    # Covisible if the projection lies in (0, w-1) x (0, h-1) of image 1.
    h, w = depth1.shape[1:3]
    covisible_mask = (
        (w_kpts0[:, :, 0] > 0)
        * (w_kpts0[:, :, 0] < w - 1)
        * (w_kpts0[:, :, 1] > 0)
        * (w_kpts0[:, :, 1] < h - 1)
    )
    # Back to normalized coordinates (used for sampling and as the return value).
    w_kpts0 = torch.stack(
        (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
    )

    # Depth observed in image 1 at the warped location.
    w_kpts0_depth = F.grid_sample(
        depth1[:, None], w_kpts0[:, :, None], mode=grid_sample_mode, align_corners=False
    )[:, 0, :, 0]

    relative_depth_error = (
        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
    ).abs()
    if not smooth_mask:
        consistent_mask = relative_depth_error < relative_depth_error_threshold
    else:
        consistent_mask = (-relative_depth_error / smooth_mask).exp()
    valid_mask = nonzero_mask * covisible_mask * consistent_mask
    if return_relative_depth_error:
        return relative_depth_error, w_kpts0
    else:
        return valid_mask, w_kpts0
|
|
|
# Per-channel RGB mean/std (the same values get_tuple_transform_ops normalises
# with); tensor_to_pil uses them to undo that normalisation.
imagenet_mean = torch.tensor((0.485, 0.456, 0.406))
imagenet_std = torch.tensor((0.229, 0.224, 0.225))
|
|
|
|
|
def numpy_to_pil(x: np.ndarray):
    """Convert an image array (or torch tensor) to a PIL Image.

    Args:
        x: Assumed to be of shape (h,w,c). Tensors are detached and moved to
            CPU first. If ``x.max() <= 1.01`` values are treated as [0, 1] and
            scaled by 255.

    Returns:
        ``PIL.Image`` built from ``x`` cast to uint8. The input is not modified.
    """
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    if x.max() <= 1.01:
        # Out-of-place: `x *= 255` modified the caller's array, and for a CPU
        # tensor the tensor itself (`.numpy()` shares its memory). It also
        # failed on boolean arrays.
        x = x * 255
    x = x.astype(np.uint8)
    return Image.fromarray(x)
|
|
|
|
|
def tensor_to_pil(x, unnormalize=False):
    """Convert a CHW image tensor to a PIL Image.

    With ``unnormalize`` the ImageNet mean/std normalisation is undone first.
    Values are clipped to [0, 1] before conversion.
    """
    if unnormalize:
        std = imagenet_std[:, None, None].to(x.device)
        mean = imagenet_mean[:, None, None].to(x.device)
        x = x * std + mean
    hwc = x.detach().permute(1, 2, 0).cpu().numpy()
    return numpy_to_pil(np.clip(hwc, 0.0, 1.0))
|
|
|
|
|
def to_cuda(batch):
    """Move every ``torch.Tensor`` value of ``batch`` to CUDA, in place.

    Non-tensor values are left untouched; the same dict is returned.
    """
    tensor_keys = [key for key, value in batch.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        batch[key] = batch[key].cuda()
    return batch
|
|
|
|
|
def to_cpu(batch):
    """Move every ``torch.Tensor`` value of ``batch`` to the CPU, in place.

    Non-tensor values are left untouched; the same dict is returned.
    """
    for key in list(batch):
        value = batch[key]
        if isinstance(value, torch.Tensor):
            batch[key] = value.cpu()
    return batch
|
|
|
|
|
def get_pose(calib):
    """Unpack a calibration dict into ``(K, R, T^T, h, w)``.

    ``calib["imsize"][0]`` is read as ``(w, h)``.
    """
    width, height = np.array(calib["imsize"])[0]
    intrinsics = np.array(calib["K"])
    rotation = np.array(calib["R"])
    translation = np.array(calib["T"]).T
    return intrinsics, rotation, translation, height, width
|
|
|
|
|
def compute_relative_pose(R1, t1, R2, t2):
    """Return ``(R2 R1^T, t2 - R2 R1^T t1)``.

    Presumably poses are world-to-camera (x_i = R_i X + t_i), so the result maps
    camera-1 coordinates to camera 2. Confirm against callers.
    """
    R_rel = R2 @ R1.T
    t_rel = t2 - R_rel @ t1
    return R_rel, t_rel
|
|
|
@torch.no_grad()
def reset_opt(opt):
    """Zero the 'exp_avg', 'exp_avg_sq' and 'exp_avg_diff' buffers of trainable params.

    Parameters with ``requires_grad=False`` are skipped; other state entries
    are left as they are.
    """
    trainable = (p for group in opt.param_groups for p in group['params'] if p.requires_grad)
    for p in trainable:
        state = opt.state[p]
        for buffer_name in ('exp_avg', 'exp_avg_sq', 'exp_avg_diff'):
            state[buffer_name] = torch.zeros_like(p)
|
|
|
|
|
def flow_to_pixel_coords(flow, h1, w1):
    """Map normalized [-1, 1] coordinates (last axis = x, y) to pixels of a w1 x h1 image."""
    x_pix = w1 * (flow[..., 0] + 1) / 2
    y_pix = h1 * (flow[..., 1] + 1) / 2
    return torch.stack((x_pix, y_pix), dim=-1)
|
|
|
def flow_to_normalized_coords(flow, h1, w1):
    """Map pixel coordinates (last axis = x, y) of a w1 x h1 image to the [-1, 1] range."""
    x_norm = 2 * flow[..., 0] / w1 - 1
    y_norm = 2 * flow[..., 1] / h1 - 1
    return torch.stack((x_norm, y_norm), dim=-1)
|
|
|
|
|
def warp_to_pixel_coords(warp, h1, w1, h2, w2):
    """Convert a (..., 4) normalized warp (x1, y1, x2, y2) to pixel coordinates.

    The first pair is scaled to image 1 (w1 x h1), the second to image 2 (w2 x h2).
    """
    def _to_pixels(coords, h, w):
        # Normalized (-1, 1) -> pixel coordinates for one image.
        return torch.stack((w * (coords[..., 0] + 1) / 2, h * (coords[..., 1] + 1) / 2), axis=-1)

    pixels1 = _to_pixels(warp[..., :2], h1, w1)
    pixels2 = _to_pixels(warp[..., 2:], h2, w2)
    return torch.cat((pixels1, pixels2), dim=-1)
|
|
|
|
|
|
|
def signed_point_line_distance(point, line, eps: float = 1e-9):
    r"""Signed distance from points to lines ``a x + b y + c = 0``.

    Args:
        point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`; only the
            first two coordinates are used.
        line: line coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`.
        eps: added to the normaliser ``sqrt(a^2 + b^2)`` to avoid division by zero.

    Returns:
        Signed distances with shape :math:`(*, N)`.

    Raises:
        ValueError: if the trailing dimensions are not as described above.
    """
    if point.shape[-1] not in (2, 3):
        raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
    if line.shape[-1] != 3:
        raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")

    a, b, c = line[..., 0], line[..., 1], line[..., 2]
    signed = a * point[..., 0] + b * point[..., 1] + c
    scale = line[..., :2].norm(dim=-1)
    return signed / (scale + eps)
|
|
|
|
|
def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
    r"""Signed one-sided epipolar distance of correspondences under ``Fm``.

    For each correspondence, the epipolar line of ``pts1`` in the right image
    (``Fm @ x1``) is computed and the signed distance of ``pts2`` to that line
    is returned. Only this one direction is measured.

    Args:
        pts1: left-image points :math:`(*, N, 2 or 3)`; 2-D points are made homogeneous.
        pts2: right-image points :math:`(*, N, 2 or 3)`.
        Fm: fundamental matrices :math:`(*, 3, 3)` (named Fm to avoid clashing
            with torch.nn.functional).

    Returns:
        Signed distances with shape :math:`(*, N)`.

    Raises:
        ValueError: if ``Fm`` is not a :math:`(*, 3, 3)` tensor.
    """
    import kornia

    if len(Fm.shape) < 3 or Fm.shape[-2:] != (3, 3):
        raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")

    homogeneous_pts1 = pts1
    if pts1.shape[-1] == 2:
        homogeneous_pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)

    # Row-vector form of l2 = Fm @ x1 for every point.
    epilines = homogeneous_pts1 @ Fm.transpose(dim0=-2, dim1=-1)
    return signed_point_line_distance(pts2, epilines)
|
|
|
def get_grid(b, h, w, device):
    """Return a (b, h, w, 2) grid of cell-centre coordinates in (-1, 1).

    The last axis holds (x, y): x varies along ``w`` and y along ``h``.
    """
    # indexing="ij" matches meshgrid's historical default; passing it
    # explicitly avoids torch's deprecation warning.
    grid = torch.meshgrid(
        *[
            torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
            for n in (b, h, w)
        ],
        indexing="ij",
    )
    grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
    return grid
|
|