import matplotlib
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn.functional as F

from lib.utils import (
    grid_positions,
    upscale_positions,
    downscale_positions,
    savefig,
    imshow_image
)
from lib.exceptions import NoGradientError, EmptyTensorError

# Non-interactive backend: training may run headless, figures are only saved.
matplotlib.use('Agg')


def loss_function(
        model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False
):
    # Joint forward pass on both images of every pair in the batch.
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Ground-truth annotations for the pair: depth maps, camera
        # intrinsics, absolute poses and crop offsets (bbox = [row, col]).
        depth1 = batch['depth1'][idx_in_batch].to(device)
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output: dense descriptors and soft detection scores.
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        # L2-normalize the descriptors, so that 2 - 2 * <d1, d2> below is
        # their squared Euclidean distance.
        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the feature-map grid of image 1 into image 2 using depth,
        # intrinsics and relative pose; positions without valid ground
        # truth are dropped along the way.
        fmap_pos1 = grid_positions(h1, w1, device)
        pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
        try:
            pos1, pos2, ids = warp(
                pos1,
                depth1, intrinsics1, pose1, bbox1,
                depth2, intrinsics2, pose2, bbox2
            )
        except EmptyTensorError:
            continue
        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if too few ground-truth correspondences survive.
        if ids.size(0) < 128:
            continue

        # Descriptors of image 2 sampled at the warped positions.
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)
        ).long()
        descriptors2 = F.normalize(
            dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
            dim=0
        )
        # Squared descriptor distance between corresponding points
        # (unit-norm descriptors, hence 2 - 2 * cosine similarity).
        positive_distance = 2 - 2 * (
            descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
        ).squeeze()

        # Hardest negative in image 2: the closest descriptor among all
        # positions lying outside `safe_radius` (Chebyshev distance on the
        # feature map) of the true correspondence. Candidates inside the
        # radius are pushed out of the minimum by the +10 offset, which
        # exceeds the largest possible descriptor distance (4).
        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(
            torch.abs(
                fmap_pos2.unsqueeze(2).float() -
                all_fmap_pos2.unsqueeze(1)
            ),
            dim=0
        )[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1
        )[0]

        # Symmetric hardest negative in image 1.
        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(
            torch.abs(
                fmap_pos1.unsqueeze(2).float() -
                all_fmap_pos1.unsqueeze(1)
            ),
            dim=0
        )[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1
        )[0]

        # Triplet-style residual: positive distance minus the harder of the
        # two negatives.
        diff = positive_distance - torch.min(
            negative_distance1, negative_distance2
        )

        # Detection scores of image 2 at the corresponding locations.
        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        # Score-weighted margin loss: the network learns to assign high
        # scores to repeatable, discriminative locations.
        loss = loss + (
            torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
            torch.sum(scores1 * scores2)
        )

        has_grad = True
        n_valid_samples += 1

        # Periodically save a visualization of the ground-truth
        # correspondences and the soft detection scores.
        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            pos1_aux = pos1.cpu().numpy()
            pos2_aux = pos2.cpu().numpy()
            k = pos1_aux.shape[1]
            col = np.random.rand(k, 3)
            n_sp = 4
            plt.figure()
            plt.subplot(1, n_sp, 1)
            im1 = imshow_image(
                batch['image1'][idx_in_batch].cpu().numpy(),
                preprocessing=batch['preprocessing']
            )
            plt.imshow(im1)
            plt.scatter(
                pos1_aux[1, :], pos1_aux[0, :],
                s=0.25**2, c=col, marker=',', alpha=0.5
            )
            plt.axis('off')
            plt.subplot(1, n_sp, 2)
            plt.imshow(
                output['scores1'][idx_in_batch].data.cpu().numpy(),
                cmap='Reds'
            )
            plt.axis('off')
            plt.subplot(1, n_sp, 3)
            im2 = imshow_image(
                batch['image2'][idx_in_batch].cpu().numpy(),
                preprocessing=batch['preprocessing']
            )
            plt.imshow(im2)
            plt.scatter(
                pos2_aux[1, :], pos2_aux[0, :],
                s=0.25**2, c=col, marker=',', alpha=0.5
            )
            plt.axis('off')
            plt.subplot(1, n_sp, 4)
            plt.imshow(
                output['scores2'][idx_in_batch].data.cpu().numpy(),
                cmap='Reds'
            )
            plt.axis('off')
            savefig('train_vis/%s.%02d.%02d.%d.png' % (
                'train' if batch['train'] else 'valid',
                batch['epoch_idx'],
                batch['batch_idx'] // batch['log_interval'],
                idx_in_batch
            ), dpi=300)
            plt.close()

    # No pair in the batch yielded valid correspondences: there is nothing
    # to backpropagate through.
    if not has_grad:
        raise NoGradientError

    # Average over the pairs that contributed to the loss.
    loss = loss / n_valid_samples

    return loss
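
# A minimal sketch of how `loss_function` is typically driven from a training
# loop. Illustrative only: `dataloader` and `optimizer` are hypothetical
# stand-ins, not part of this module.
#
#     for batch in dataloader:
#         optimizer.zero_grad()
#         try:
#             loss = loss_function(model, batch, device)
#         except NoGradientError:
#             continue  # no pair in the batch had valid correspondences
#         loss.backward()
#         optimizer.step()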


def interpolate_depth(pos, depth):
    device = pos.device

    ids = torch.arange(0, pos.size(1), device=device)

    h, w = depth.size()

    # Row (i) and column (j) coordinates of the query positions.
    i = pos[0, :]
    j = pos[1, :]

    # The four integer corners surrounding each sub-pixel position. Boolean
    # masks are combined with `&` (element-wise AND); some PyTorch releases
    # do not implement the original `torch.min` for boolean tensors.
    i_top_left = torch.floor(i).long()
    j_top_left = torch.floor(j).long()
    valid_top_left = (i_top_left >= 0) & (j_top_left >= 0)

    i_top_right = torch.floor(i).long()
    j_top_right = torch.ceil(j).long()
    valid_top_right = (i_top_right >= 0) & (j_top_right < w)

    i_bottom_left = torch.ceil(i).long()
    j_bottom_left = torch.floor(j).long()
    valid_bottom_left = (i_bottom_left < h) & (j_bottom_left >= 0)

    i_bottom_right = torch.ceil(i).long()
    j_bottom_right = torch.ceil(j).long()
    valid_bottom_right = (i_bottom_right < h) & (j_bottom_right < w)

    # Keep only positions whose four corners all fall inside the depth map.
    valid_corners = (
        valid_top_left & valid_top_right &
        valid_bottom_left & valid_bottom_right
    )

    i_top_left = i_top_left[valid_corners]
    j_top_left = j_top_left[valid_corners]

    i_top_right = i_top_right[valid_corners]
    j_top_right = j_top_right[valid_corners]

    i_bottom_left = i_bottom_left[valid_corners]
    j_bottom_left = j_bottom_left[valid_corners]

    i_bottom_right = i_bottom_right[valid_corners]
    j_bottom_right = j_bottom_right[valid_corners]

    ids = ids[valid_corners]
    if ids.size(0) == 0:
        raise EmptyTensorError

    # A position is usable only if all four corner depths are valid (> 0).
    valid_depth = (
        (depth[i_top_left, j_top_left] > 0) &
        (depth[i_top_right, j_top_right] > 0) &
        (depth[i_bottom_left, j_bottom_left] > 0) &
        (depth[i_bottom_right, j_bottom_right] > 0)
    )

    i_top_left = i_top_left[valid_depth]
    j_top_left = j_top_left[valid_depth]

    i_top_right = i_top_right[valid_depth]
    j_top_right = j_top_right[valid_depth]

    i_bottom_left = i_bottom_left[valid_depth]
    j_bottom_left = j_bottom_left[valid_depth]

    i_bottom_right = i_bottom_right[valid_depth]
    j_bottom_right = j_bottom_right[valid_depth]

    ids = ids[valid_depth]
    if ids.size(0) == 0:
        raise EmptyTensorError

    # Bilinear interpolation weights, derived from the distance to the
    # top-left corner.
    i = i[ids]
    j = j[ids]
    dist_i_top_left = i - i_top_left.float()
    dist_j_top_left = j - j_top_left.float()
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    interpolated_depth = (
        w_top_left * depth[i_top_left, j_top_left] +
        w_top_right * depth[i_top_right, j_top_right] +
        w_bottom_left * depth[i_bottom_left, j_bottom_left] +
        w_bottom_right * depth[i_bottom_right, j_bottom_right]
    )

    pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)

    return [interpolated_depth, pos, ids]


def uv_to_pos(uv):
    # Swap (u, v) = (column, row) image coordinates into the (i, j) =
    # (row, column) convention used throughout this module.
    return torch.cat([uv[1, :].view(1, -1), uv[0, :].view(1, -1)], dim=0)


def warp(
        pos1,
        depth1, intrinsics1, pose1, bbox1,
        depth2, intrinsics2, pose2, bbox2
):
    device = pos1.device

    Z1, pos1, ids = interpolate_depth(pos1, depth1)

    # Crop-local (i, j) positions to full-image (u, v) pixel centers.
    u1 = pos1[1, :] + bbox1[1] + .5
    v1 = pos1[0, :] + bbox1[0] + .5

    # Back-project to 3D in the first camera's frame.
    X1 = (u1 - intrinsics1[0, 2]) * (Z1 / intrinsics1[0, 0])
    Y1 = (v1 - intrinsics1[1, 2]) * (Z1 / intrinsics1[1, 1])

    XYZ1_hom = torch.cat([
        X1.view(1, -1),
        Y1.view(1, -1),
        Z1.view(1, -1),
        torch.ones(1, Z1.size(0), device=device)
    ], dim=0)
    # Rigid transformation into the second camera's frame (equivalent to the
    # deprecated torch.chain_matmul(pose2, torch.inverse(pose1), XYZ1_hom)).
    XYZ2_hom = pose2 @ torch.inverse(pose1) @ XYZ1_hom
    XYZ2 = XYZ2_hom[:-1, :] / XYZ2_hom[-1, :].view(1, -1)

    # Project into image 2 and convert back to crop-local coordinates.
    uv2_hom = torch.matmul(intrinsics2, XYZ2)
    uv2 = uv2_hom[:-1, :] / uv2_hom[-1, :].view(1, -1)

    u2 = uv2[0, :] - bbox2[1] - .5
    v2 = uv2[1, :] - bbox2[0] - .5
    uv2 = torch.cat([u2.view(1, -1), v2.view(1, -1)], dim=0)

    annotated_depth, pos2, new_ids = interpolate_depth(uv_to_pos(uv2), depth2)

    ids = ids[new_ids]
    pos1 = pos1[:, new_ids]
    estimated_depth = XYZ2[2, new_ids]

    # Occlusion / consistency check: the depth implied by the warp must
    # agree with the annotated depth map of image 2 within 0.05 depth units.
    inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05

    ids = ids[inlier_mask]
    if ids.size(0) == 0:
        raise EmptyTensorError

    pos2 = pos2[:, inlier_mask]
    pos1 = pos1[:, inlier_mask]

    return pos1, pos2, ids
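

if __name__ == '__main__':
    # Minimal smoke test, not part of the original module; run from the
    # repository root so that `lib` is importable. With two identical
    # cameras (identity poses, shared intrinsics, constant unit depth and
    # zero bbox offsets), `warp` must map every valid position onto itself.
    depth = torch.ones(8, 8)
    intrinsics = torch.tensor([
        [1., 0., 4.],
        [0., 1., 4.],
        [0., 0., 1.]
    ])
    pose = torch.eye(4)
    bbox = torch.zeros(2)

    pos1 = torch.tensor([
        [2., 3., 5.],  # i (row) coordinates
        [2., 5., 6.]   # j (column) coordinates
    ])
    warped1, warped2, ids = warp(
        pos1,
        depth, intrinsics, pose, bbox,
        depth, intrinsics, pose, bbox
    )
    assert torch.allclose(warped1, warped2), 'identity warp should be exact'
    print('identity warp OK for %d positions' % ids.size(0))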