|
import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from scipy.spatial.transform import Rotation
|
|
|
def tum_to_pose_matrix(pose):
    """Convert a TUM-format pose (tx, ty, tz, qx, qy, qz, qw) to a 4x4 matrix."""
    assert pose.shape == (7,)
    # scipy expects quaternions in (x, y, z, w) order, matching TUM.
    pose_xyzw = pose[[3, 4, 5, 6]]
    r = Rotation.from_quat(pose_xyzw)
    return np.vstack([np.hstack([r.as_matrix(), pose[:3].reshape(-1, 1)]), [0, 0, 0, 1]])
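# Illustrative usage sketch (values assumed, not from the original repo): the
# identity quaternion (0, 0, 0, 1) with translation (1, 2, 3) should produce a
# 4x4 matrix whose rotation block is the identity.
#
#   pose = np.array([1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0])
#   T = tum_to_pose_matrix(pose)
#   assert T.shape == (4, 4)
#   assert np.allclose(T[:3, :3], np.eye(3))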
|
|
|
def depth_regularization_si_weighted(depth_pred, depth_init, pixel_wise_weight=None, pixel_wise_weight_scale=1, pixel_wise_weight_bias=1, eps=1e-6, pixel_weight_normalize=False):
    """Scale-invariant log-depth regularization against an initial depth map,
    optionally weighted per pixel."""
    depth_pred = torch.clamp(depth_pred, min=eps)
    depth_init = torch.clamp(depth_init, min=eps)
    log_d_pred = torch.log(depth_pred)
    log_d_init = torch.log(depth_init)
    B, _, H, W = depth_pred.shape
    # Per-image log-scale that best aligns the prediction with the initialization.
    scale = torch.sum(log_d_init - log_d_pred, dim=[1, 2, 3], keepdim=True) / (H * W)
    if pixel_wise_weight is not None:
        if pixel_weight_normalize:
            norm = torch.max(pixel_wise_weight.detach().view(B, -1), dim=1, keepdim=False)[0]
            pixel_wise_weight = pixel_wise_weight / (norm[:, None, None, None] + eps)
        pixel_wise_weight = pixel_wise_weight * pixel_wise_weight_scale + pixel_wise_weight_bias
    else:
        pixel_wise_weight = 1
    si_loss = torch.sum(pixel_wise_weight * (log_d_pred - log_d_init + scale) ** 2,
                        dim=[1, 2, 3]) / (H * W)
    return si_loss.mean()
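# Illustrative check (sketch; shapes assumed): the loss is invariant to a
# global rescaling of the prediction, so a uniformly scaled copy of the
# initial depth should yield a numerically zero loss.
#
#   depth_init = torch.rand(2, 1, 8, 8) + 0.5
#   loss = depth_regularization_si_weighted(depth_init * 3.0, depth_init)
#   assert loss.item() < 1e-9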
|
|
|
class WarpImage(torch.nn.Module):
    """Backward-warp frame 1 into frame 2 using the flow from 2 to 1."""

    def __init__(self):
        super(WarpImage, self).__init__()
        self.base_coord = None

    def init_grid(self, shape, device):
        H, W = shape
        hh, ww = torch.meshgrid(torch.arange(
            H).float(), torch.arange(W).float())
        coord = torch.zeros([1, H, W, 2])
        coord[0, ..., 0] = ww
        coord[0, ..., 1] = hh
        self.base_coord = coord.to(device)
        self.W = W
        self.H = H

    def warp_image(self, base_coord, img_1, flow_2_1):
        B, C, H, W = flow_2_1.shape
        sample_grids = base_coord + flow_2_1.permute([0, 2, 3, 1])
        # Normalize pixel coordinates to [-1, 1] for grid_sample.
        sample_grids[..., 0] /= (W - 1) / 2
        sample_grids[..., 1] /= (H - 1) / 2
        sample_grids -= 1
        warped_image_2_from_1 = F.grid_sample(
            img_1, sample_grids, align_corners=True)
        return warped_image_2_from_1

    def forward(self, img_1, flow_2_1):
        B, _, H, W = flow_2_1.shape
        if self.base_coord is None:
            self.init_grid([H, W], device=flow_2_1.device)
        base_coord = self.base_coord.expand([B, -1, -1, -1])
        return self.warp_image(base_coord, img_1, flow_2_1)
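# Illustrative sanity check (sketch; shapes assumed): warping with an
# all-zero flow field is an identity operation up to interpolation error.
#
#   warper = WarpImage()
#   img = torch.rand(1, 3, 16, 16)
#   flow = torch.zeros(1, 2, 16, 16)
#   assert torch.allclose(warper(img, flow), img, atol=1e-5)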
|
|
|
|
|
class CameraIntrinsics(nn.Module):
    """Learnable pinhole intrinsics with a single shared focal length."""

    def __init__(self, init_focal_length=0.45, pixel_size=None):
        super().__init__()
        self.focal_length = nn.Parameter(torch.tensor(init_focal_length))
        self.pixel_size_buffer = pixel_size

    def register_shape(self, orig_shape, opt_shape) -> None:
        self.orig_shape = orig_shape
        self.opt_shape = opt_shape
        H_orig, W_orig = orig_shape
        H_opt, W_opt = opt_shape
        if self.pixel_size_buffer is None:
            # Default pixel size derived from the original image diagonal.
            pixel_size = 0.433 / (H_orig ** 2 + W_orig ** 2) ** 0.5
        else:
            pixel_size = self.pixel_size_buffer
        self.register_buffer("pixel_size", torch.tensor(pixel_size))
        # Principal point at the image center of the optimization resolution.
        intrinsics_mat_buffer = torch.zeros(3, 3)
        intrinsics_mat_buffer[0, -1] = (W_opt - 1) / 2
        intrinsics_mat_buffer[1, -1] = (H_opt - 1) / 2
        intrinsics_mat_buffer[2, -1] = 1
        self.register_buffer("intrinsics_mat", intrinsics_mat_buffer)
        self.register_buffer("scale_H", torch.tensor(
            H_opt / (H_orig * pixel_size)))
        self.register_buffer("scale_W", torch.tensor(
            W_opt / (W_orig * pixel_size)))

    def get_K_and_inv(self, with_batch_dim=True) -> torch.Tensor:
        intrinsics_mat = self.intrinsics_mat.clone()
        intrinsics_mat[0, 0] = self.focal_length * self.scale_W
        intrinsics_mat[1, 1] = self.focal_length * self.scale_H
        inv_intrinsics_mat = torch.linalg.inv(intrinsics_mat)
        if with_batch_dim:
            return intrinsics_mat[None, ...], inv_intrinsics_mat[None, ...]
        else:
            return intrinsics_mat, inv_intrinsics_mat
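# Illustrative usage sketch (resolutions assumed): register the original and
# optimization resolutions, then fetch K and its inverse.
#
#   intrinsics = CameraIntrinsics(init_focal_length=0.45)
#   intrinsics.register_shape(orig_shape=(480, 640), opt_shape=(240, 320))
#   K, K_inv = intrinsics.get_K_and_inv()  # each is 1x3x3
#   assert torch.allclose(K.matmul(K_inv), torch.eye(3)[None], atol=1e-5)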
|
|
|
|
|
@torch.jit.script
def hat(v: torch.Tensor) -> torch.Tensor:
    """
    Compute the Hat operator [1] of a batch of 3D vectors.

    Args:
        v: Batch of vectors of shape `(minibatch, 3)`.

    Returns:
        Batch of skew-symmetric matrices of shape
        `(minibatch, 3, 3)` where each matrix is of the form:
            `[    0  -v_z   v_y ]
             [  v_z     0  -v_x ]
             [ -v_y   v_x     0 ]`

    Raises:
        ValueError if `v` is of incorrect shape.

    [1] https://en.wikipedia.org/wiki/Hat_operator
    """
    N, dim = v.shape
    if dim != 3:
        raise ValueError("Input vectors have to be 3-dimensional.")
    h = torch.zeros((N, 3, 3), dtype=v.dtype, device=v.device)
    x, y, z = v.unbind(1)
    h[:, 0, 1] = -z
    h[:, 0, 2] = y
    h[:, 1, 0] = z
    h[:, 1, 2] = -x
    h[:, 2, 0] = -y
    h[:, 2, 1] = x
    return h
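# Illustrative check (sketch): hat(v) @ u equals the cross product v x u.
#
#   v = torch.tensor([[1.0, 2.0, 3.0]])
#   u = torch.tensor([[4.0, 5.0, 6.0]])
#   assert torch.allclose(hat(v).matmul(u[..., None])[..., 0],
#                         torch.cross(v, u, dim=1))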
|
|
|
|
|
@torch.jit.script
def get_relative_transform(src_R, src_t, tgt_R, tgt_t):
    # Relative pose mapping source-camera coordinates into the target frame.
    tgt_R_inv = tgt_R.permute([0, 2, 1])
    relative_R = torch.matmul(tgt_R_inv, src_R)
    relative_t = torch.matmul(tgt_R_inv, src_t - tgt_t)
    return relative_R, relative_t
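# Illustrative check (sketch): when source and target poses coincide, the
# relative transform is the identity.
#
#   R = torch.eye(3)[None]
#   t = torch.rand(1, 3, 1)
#   rel_R, rel_t = get_relative_transform(R, t, R, t)
#   assert torch.allclose(rel_R, torch.eye(3)[None])
#   assert torch.allclose(rel_t, torch.zeros(1, 3, 1))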
|
|
|
|
|
def reproject_depth(src_R, src_t, src_disp, tgt_R, tgt_t, tgt_disp, K_src, K_inv_src, K_trg, K_inv_trg, coord, eps=1e-6):
    """
    Reproject each camera's depth map into the other camera.

    input:
        src_R, src_t: rotation matrix and translation vector of the source camera
        tgt_R, tgt_t: rotation matrix and translation vector of the target camera
        src_disp: disparity map of the source camera
        tgt_disp: disparity map of the target camera
        K_src, K_inv_src: intrinsics matrix of the source camera and its inverse
        K_trg, K_inv_trg: intrinsics matrix of the target camera and its inverse
        coord: homogeneous pixel coordinate grid of shape [1, 3, H*W]
    output:
        tgt_depth_from_src: source depth map reprojected to the target camera, values are ready for warping.
        src_depth_from_tgt: target depth map reprojected to the source camera, values are ready for warping.
    """
    B, _, H, W = src_disp.shape
    src_depth = 1 / (src_disp + eps)
    tgt_depth = 1 / (tgt_disp + eps)

    # Lift source pixels to world space, then project into the target camera.
    src_depth_flat = src_depth.view([B, 1, H * W])
    src_xyz = src_depth_flat * src_R.matmul(K_inv_src.matmul(coord)) + src_t
    src_xyz_at_tgt_cam = K_trg.matmul(
        tgt_R.transpose(1, 2).matmul(src_xyz - tgt_t))
    tgt_depth_from_src = src_xyz_at_tgt_cam[:, 2, :].view([B, 1, H, W])

    # Symmetric direction: lift target pixels and project into the source camera.
    tgt_depth_flat = tgt_depth.view([B, 1, H * W])
    tgt_xyz = tgt_depth_flat * tgt_R.matmul(K_inv_trg.matmul(coord)) + tgt_t
    tgt_xyz_at_src_cam = K_src.matmul(
        src_R.transpose(1, 2).matmul(tgt_xyz - src_t))
    src_depth_from_tgt = tgt_xyz_at_src_cam[:, 2, :].view([B, 1, H, W])
    return tgt_depth_from_src, src_depth_from_tgt
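# Illustrative usage sketch (all inputs assumed): with identical source and
# target cameras, the reprojected depth equals the input depth.
#
#   B, H, W = 1, 4, 5
#   R, t = torch.eye(3)[None], torch.zeros(1, 3, 1)
#   K = torch.eye(3)[None]
#   disp = torch.rand(B, 1, H, W) + 0.5
#   yy, xx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
#   coord = torch.stack([xx, yy, torch.ones_like(xx)]).reshape(1, 3, H * W)
#   d_tgt, d_src = reproject_depth(R, t, disp, R, t, disp, K, K, K, K, coord)
#   assert torch.allclose(d_tgt, 1 / (disp + 1e-6))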
|
|
|
|
|
|
|
def warp_by_disp(src_R, src_t, tgt_R, tgt_t, K, src_disp, coord, inv_K, debug_mode=False, use_depth=False):
    """Compute the flow field induced by the relative camera motion and the
    source disparity (or depth, if `use_depth` is set)."""
    B, C, H, W = src_disp.shape
    relative_R, relative_t = get_relative_transform(
        src_R, src_t, tgt_R, tgt_t)
    # Rotational homography: maps pixels under pure rotation.
    H_mat = K.matmul(relative_R.matmul(inv_K))
    flat_disp = src_disp.view([B, 1, H * W])
    if debug_mode:
        relative_t_flat = relative_t.expand([-1, -1, H * W])
        rot_coord = torch.matmul(H_mat, coord)
        tr_coord = flat_disp * torch.matmul(K, relative_t_flat)
        tgt_coord = rot_coord + tr_coord
        normalization_factor = (tgt_coord[:, 2:, :] + 1e-6)
        rot_coord_normalized = rot_coord / normalization_factor
        tr_coord_normalized = tr_coord / normalization_factor
        tgt_coord_normalized = rot_coord_normalized + tr_coord_normalized
        debug_info = {
            'tr_coord_normalized': tr_coord_normalized,
            'rot_coord_normalized': rot_coord_normalized,
            'tgt_coord_normalized': tgt_coord_normalized,
            'tr_coord': tr_coord,
            'rot_coord': rot_coord,
            'normalization_factor': normalization_factor,
            'relative_t_flat': relative_t_flat,
        }
        return (tgt_coord_normalized - coord).view([B, 3, H, W]), debug_info
    else:
        if use_depth:
            tgt_coord = flat_disp * torch.matmul(H_mat, coord) + \
                torch.matmul(K, relative_t)
        else:
            tgt_coord = torch.matmul(H_mat, coord) + flat_disp * \
                torch.matmul(K, relative_t)
        tgt_coord = tgt_coord / (tgt_coord[:, -1:, :] + 1e-6)
        return (tgt_coord - coord).view([B, 3, H, W]), tgt_coord
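# Illustrative check (sketch; inputs assumed): for identical poses the
# homography reduces to the identity, so the predicted flow should vanish.
#
#   B, H, W = 1, 4, 5
#   R, t = torch.eye(3)[None], torch.zeros(1, 3, 1)
#   K = torch.eye(3)[None]
#   disp = torch.rand(B, 1, H, W)
#   yy, xx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
#   coord = torch.stack([xx, yy, torch.ones_like(xx)]).reshape(1, 3, H * W)
#   flow, _ = warp_by_disp(R, t, R, t, K, disp, coord, K)
#   assert flow.abs().max() < 1e-4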
|
|
|
|
|
def unproject_depth(depth, K_inv, R, t, coord):
    # Back-project a depth map to world-space XYZ points.
    B, _, H, W = depth.shape
    depth_flat = depth.view([B, 1, H * W])
    xyz = depth_flat * R.matmul(K_inv.matmul(coord)) + t
    return xyz.reshape([B, 3, H, W])
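# Illustrative check (sketch; inputs assumed): with identity intrinsics and
# rotation and zero translation, unprojection scales the homogeneous pixel
# grid by depth, so the z-channel equals the input depth.
#
#   H, W = 2, 3
#   depth = torch.full([1, 1, H, W], 2.0)
#   yy, xx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
#   coord = torch.stack([xx, yy, torch.ones_like(xx)]).reshape(1, 3, H * W)
#   xyz = unproject_depth(depth, torch.eye(3)[None], torch.eye(3)[None],
#                         torch.zeros(1, 3, 1), coord)
#   assert torch.allclose(xyz[:, 2], depth[:, 0])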
|
|
|
|
|
@torch.jit.script
def _so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001):
    """
    A helper function that computes the so3 exponential map and,
    apart from the rotation matrix, also returns intermediate variables
    that can be re-used in other functions.
    """
    _, dim = log_rot.shape
    if dim != 3:
        raise ValueError("Input tensor shape has to be Nx3.")
    nrms = (log_rot * log_rot).sum(1)
    # Clamp the squared norm to avoid division by zero near the identity.
    rot_angles = torch.clamp(nrms, eps).sqrt()
    rot_angles_inv = 1.0 / rot_angles
    fac1 = rot_angles_inv * rot_angles.sin()
    fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
    skews = hat(log_rot)
    skews_square = torch.bmm(skews, skews)
    # Rodrigues' formula: R = I + sin(theta)/theta * K + (1 - cos(theta))/theta^2 * K^2.
    R = (
        fac1[:, None, None] * skews
        + fac2[:, None, None] * skews_square
        + torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
    )
    return R, rot_angles, skews, skews_square
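# Illustrative check (sketch): a rotation of pi/2 about the z-axis maps the
# x-axis to the y-axis.
#
#   log_rot = torch.tensor([[0.0, 0.0, math.pi / 2]])
#   R = _so3_exp_map(log_rot)[0]
#   x_axis = torch.tensor([[1.0], [0.0], [0.0]])
#   assert torch.allclose(R[0].matmul(x_axis),
#                         torch.tensor([[0.0], [1.0], [0.0]]), atol=1e-6)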
|
|
|
|
|
def quaternion_to_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))
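# Illustrative check (sketch): the identity quaternion (real part first) maps
# to the identity matrix.
#
#   q = torch.tensor([1.0, 0.0, 0.0, 0.0])
#   assert torch.allclose(quaternion_to_matrix(q), torch.eye(3))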
|
|
|
|
|
class CameraPoseDeltaCollection(torch.nn.Module):
    """Per-frame camera pose deltas, parameterized as so(3) vectors and translations."""

    def __init__(self, number_of_points=10) -> None:
        super().__init__()
        # Near-zero initial deltas so the exponential map starts close to identity.
        zero_rotation = torch.ones([1, 3]) * 1e-3
        zero_translation = torch.zeros([1, 3, 1]) + 1e-4
        for n in range(number_of_points):
            # Clone so each frame gets its own parameter storage rather than
            # sharing one underlying tensor across all frames.
            self.register_parameter(
                f"delta_rotation_{n}", nn.Parameter(zero_rotation.clone()))
            self.register_parameter(
                f"delta_translation_{n}", nn.Parameter(zero_translation.clone())
            )
        self.register_buffer("zero_rotation", torch.eye(3)[None, ...])
        self.register_buffer("zero_translation", torch.zeros([1, 3, 1]))
        self.traced_so3_exp_map = None
        self.number_of_points = number_of_points

    def get_rotation_and_translation_params(self):
        rotation_params = []
        translation_params = []
        for n in range(self.number_of_points):
            rotation_params.append(getattr(self, f"delta_rotation_{n}"))
            translation_params.append(getattr(self, f"delta_translation_{n}"))
        return rotation_params, translation_params

    def set_rotation_and_translation(self, index, rotation_so3, translation):
        delta_rotation = getattr(self, f"delta_rotation_{index}")
        delta_translation = getattr(self, f"delta_translation_{index}")
        delta_rotation.data = rotation_so3.detach().clone()
        delta_translation.data = translation.detach().clone()

    def set_first_frame_pose(self, R, t):
        self.zero_rotation.data = R.detach().clone().reshape([1, 3, 3])
        self.zero_translation.data = t.detach().clone().reshape([1, 3, 1])

    def get_raw_value(self, index):
        so3 = getattr(self, f"delta_rotation_{index}")
        translation = getattr(self, f"delta_translation_{index}")
        return so3, translation

    def forward(self, list_of_index):
        se_3 = []
        t_out = []
        for idx in list_of_index:
            delta_rotation, delta_translation = self.get_raw_value(idx)
            se_3.append(delta_rotation)
            t_out.append(delta_translation)
        se_3 = torch.cat(se_3, dim=0)
        t_out = torch.cat(t_out, dim=0)
        # Cache a traced version of the exponential map on first use.
        if self.traced_so3_exp_map is None:
            self.traced_so3_exp_map = torch.jit.trace(
                _so3_exp_map, (se_3,))
        R_out = _so3_exp_map(se_3)[0]
        return R_out, t_out

    def forward_index(self, index):
        delta_rotation, delta_translation = self.get_raw_value(index)
        if self.traced_so3_exp_map is None:
            self.traced_so3_exp_map = torch.jit.trace(
                _so3_exp_map, (delta_rotation,))
        R = _so3_exp_map(delta_rotation)[0]
        return R, delta_translation
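# Illustrative usage sketch (indices assumed): query poses for a subset of
# frames; rotations come from the so3 exponential of the per-frame deltas.
#
#   poses = CameraPoseDeltaCollection(number_of_points=5)
#   R, t = poses([0, 2, 4])
#   assert R.shape == (3, 3, 3) and t.shape == (3, 3, 1)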
|
|
|
|
|
class DepthScaleShiftCollection(torch.nn.Module):
    """Per-frame scale and shift corrections for depth maps, optionally as a
    spatial grid that is upsampled to the output resolution."""

    def __init__(self, n_points=10, use_inverse=False, grid_size=1):
        super().__init__()
        self.grid_size = grid_size
        for n in range(n_points):
            self.register_parameter(
                f"shift_{n}", nn.Parameter(torch.FloatTensor([0.0]))
            )
            self.register_parameter(
                f"scale_{n}", nn.Parameter(
                    torch.ones([1, 1, grid_size, grid_size]))
            )
        self.use_inverse = use_inverse
        self.output_shape = None

    def set_outputshape(self, output_shape):
        self.output_shape = output_shape

    def forward(self, index):
        shift = getattr(self, f"shift_{index}")
        scale = getattr(self, f"scale_{index}")
        if self.use_inverse:
            # Parameterize the scale in log space so it stays positive.
            scale = torch.exp(scale)
        if self.grid_size != 1:
            scale = F.interpolate(scale, self.output_shape,
                                  mode='bilinear', align_corners=True)
        return scale, shift

    def set_scale(self, index, scale):
        scale_param = getattr(self, f"scale_{index}")
        if self.use_inverse:
            scale = math.log(scale)
        scale_param.data.fill_(scale)

    def get_scale_data(self, index):
        scale = getattr(self, f"scale_{index}").data
        if self.use_inverse:
            scale = torch.exp(scale)
        if self.grid_size != 1:
            scale = F.interpolate(scale, self.output_shape,
                                  mode='bilinear', align_corners=True)
        return scale
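# Illustrative usage sketch: per-frame corrections start at the identity
# (scale 1, shift 0) and can be overwritten explicitly.
#
#   ss = DepthScaleShiftCollection(n_points=3)
#   scale, shift = ss(0)
#   assert float(scale) == 1.0 and float(shift) == 0.0
#   ss.set_scale(0, 2.0)
#   assert float(ss.get_scale_data(0)) == 2.0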
|
|
|
|
|
def check_R_shape(R):
    _, r1, r2 = R.shape
    assert r1 == 3 and r2 == 3


def check_t_shape(t):
    _, t1, t2 = t.shape
    assert t1 == 3 and t2 == 1
|
|
|
|
|
class DepthBasedWarping(nn.Module):
    """Warp disparity/depth maps across camera poses and generate flow fields."""

    def __init__(self) -> None:
        super().__init__()

    def generate_grid(self, H, W, device):
        yy, xx = torch.meshgrid(
            torch.arange(H, device=device, dtype=torch.float32),
            torch.arange(W, device=device, dtype=torch.float32),
        )
        # Homogeneous pixel coordinates of shape [1, 3, H*W].
        self.coord = torch.ones(
            [1, 3, H, W], device=device, dtype=torch.float32)
        self.coord[0, 0, ...] = xx
        self.coord[0, 1, ...] = yy
        self.coord = self.coord.reshape([1, 3, H * W])
        self.jitted_warp_by_disp = None

    def reproject_depth(self, src_R, src_t, src_disp, tgt_R, tgt_t, tgt_disp, K_src, K_inv_src, K_trg, K_inv_trg, eps=1e-6, check_shape=False):
        if check_shape:
            check_R_shape(src_R)
            check_R_shape(tgt_R)
            check_t_shape(src_t)
            check_t_shape(tgt_t)
            # Disparity maps are Bx1xHxW, so the 3x1 translation check does not apply.
            assert src_disp.dim() == 4 and tgt_disp.dim() == 4
        device = src_disp.device
        B, _, H, W = src_disp.shape
        if not hasattr(self, "coord"):
            self.generate_grid(H, W, device)
        elif self.coord.shape[-1] != H * W:
            del self.coord
            self.generate_grid(H, W, device)
        return reproject_depth(src_R, src_t, src_disp, tgt_R, tgt_t, tgt_disp, K_src, K_inv_src, K_trg, K_inv_trg, self.coord, eps=eps)

    def unproject_depth(self, disp, R, t, K_inv, eps=1e-6, check_shape=False):
        if check_shape:
            check_R_shape(R)
            check_t_shape(t)
        _, _, H, W = disp.shape
        device = disp.device
        if not hasattr(self, "coord"):
            self.generate_grid(H, W, device=device)
        elif self.coord.shape[-1] != H * W:
            del self.coord
            self.generate_grid(H, W, device=device)
        return unproject_depth(1 / (disp + eps), K_inv, R, t, self.coord)

    def forward(
        self,
        src_R,
        src_t,
        tgt_R,
        tgt_t,
        src_disp,
        K,
        inv_K,
        eps=1e-6,
        use_depth=False,
        check_shape=False,
        debug_mode=False,
    ):
        """Warp the source depth frame and generate a flow field.

        Args:
            src_R (FloatTensor): 1x3x3
            src_t (FloatTensor): 1x3x1
            tgt_R (FloatTensor): Nx3x3
            tgt_t (FloatTensor): Nx3x1
            src_disp (FloatTensor): Nx1xHxW
            K (FloatTensor): 1x3x3
            inv_K (FloatTensor): 1x3x3
        """
        if check_shape:
            check_R_shape(src_R)
            check_R_shape(tgt_R)
            check_t_shape(src_t)
            check_t_shape(tgt_t)
        _, _, H, W = src_disp.shape
        device = src_disp.device
        if not hasattr(self, "coord"):
            self.generate_grid(H, W, device=device)
        elif self.coord.shape[-1] != H * W:
            del self.coord
            self.generate_grid(H, W, device=device)
        return warp_by_disp(src_R, src_t, tgt_R, tgt_t, K, src_disp, self.coord, inv_K, debug_mode, use_depth)
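# Illustrative usage sketch (shapes and values assumed): compute the flow
# induced by the camera motion between a source frame and a target frame.
#
#   warper = DepthBasedWarping()
#   src_R, src_t = torch.eye(3)[None], torch.zeros(1, 3, 1)
#   tgt_R, tgt_t = torch.eye(3)[None], torch.ones(1, 3, 1) * 0.01
#   K = torch.eye(3)[None]
#   disp = torch.rand(1, 1, 8, 8) + 0.5
#   flow, _ = warper(src_R, src_t, tgt_R, tgt_t, disp, K, torch.linalg.inv(K))
#   assert flow.shape == (1, 3, 8, 8)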
|
|
|
|
|
class DepthToXYZ(nn.Module):
    """Back-project disparity maps to world-space XYZ point maps."""

    def __init__(self) -> None:
        super().__init__()

    def generate_grid(self, H, W, device):
        yy, xx = torch.meshgrid(
            torch.arange(H, device=device, dtype=torch.float32),
            torch.arange(W, device=device, dtype=torch.float32),
        )
        self.coord = torch.ones(
            [1, 3, H, W], device=device, dtype=torch.float32)
        self.coord[0, 0, ...] = xx
        self.coord[0, 1, ...] = yy
        self.coord = self.coord.reshape([1, 3, H * W])

    def forward(self, disp, K_inv, R, t, eps=1e-6, check_shape=False):
        """Unproject a disparity map to world-space points.

        Args:
            disp (FloatTensor): Nx1xHxW
            K_inv (FloatTensor): 1x3x3
            R (FloatTensor): Nx3x3
            t (FloatTensor): Nx3x1
        """
        if check_shape:
            check_R_shape(R)
            check_t_shape(t)
        _, _, H, W = disp.shape
        device = disp.device
        if not hasattr(self, "coord"):
            self.generate_grid(H, W, device=device)
        elif self.coord.shape[-1] != H * W:
            del self.coord
            self.generate_grid(H, W, device=device)
        return unproject_depth(1 / (disp + eps), K_inv, R, t, self.coord)
|
|
|
class OccMask(torch.nn.Module):
    """Estimate a validity mask from forward-backward flow consistency and
    out-of-bounds checks."""

    def __init__(self, th=3):
        super(OccMask, self).__init__()
        self.th = th
        self.base_coord = None

    def init_grid(self, shape, device):
        H, W = shape
        hh, ww = torch.meshgrid(torch.arange(
            H).float(), torch.arange(W).float())
        coord = torch.zeros([1, H, W, 2])
        coord[0, ..., 0] = ww
        coord[0, ..., 1] = hh
        self.base_coord = coord.to(device)
        self.W = W
        self.H = H

    @torch.no_grad()
    def get_oob_mask(self, base_coord, flow_1_2):
        target_range = base_coord + flow_1_2.permute([0, 2, 3, 1])
        oob_mask = (target_range[..., 0] < 0) | (target_range[..., 0] > self.W - 1) | (
            target_range[..., 1] < 0) | (target_range[..., 1] > self.H - 1)
        return ~oob_mask[:, None, ...]

    @torch.no_grad()
    def get_flow_inconsistency_tensor(self, base_coord, flow_1_2, flow_2_1):
        B, C, H, W = flow_1_2.shape
        sample_grids = base_coord + flow_1_2.permute([0, 2, 3, 1])
        sample_grids[..., 0] /= (W - 1) / 2
        sample_grids[..., 1] /= (H - 1) / 2
        sample_grids -= 1
        sampled_flow = F.grid_sample(
            flow_2_1, sample_grids, align_corners=True)
        # Forward-backward consistency: the composed flow should cancel out.
        return torch.abs((sampled_flow + flow_1_2).sum(1, keepdim=True))

    def forward(self, flow_1_2, flow_2_1):
        B, _, H, W = flow_1_2.shape
        if self.base_coord is None:
            self.init_grid([H, W], device=flow_1_2.device)
        base_coord = self.base_coord.expand([B, -1, -1, -1])
        oob_mask = self.get_oob_mask(base_coord, flow_1_2)
        flow_inconsistency_tensor = self.get_flow_inconsistency_tensor(
            base_coord, flow_1_2, flow_2_1)
        valid_flow_mask = flow_inconsistency_tensor < self.th
        return valid_flow_mask * oob_mask
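# Illustrative check (sketch): perfectly consistent zero flows in both
# directions mark every in-bounds pixel as valid.
#
#   occ = OccMask(th=3)
#   flow = torch.zeros(1, 2, 8, 8)
#   mask = occ(flow, flow)
#   assert mask.all()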