# Source: Vincentqyw, commit 9fb6531 ("add: rord libs"), 10.7 kB.
import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import torch
import torch.nn.functional as F
from lib.utils import (
grid_positions,
upscale_positions,
downscale_positions,
savefig,
imshow_image
)
from lib.exceptions import NoGradientError, EmptyTensorError
# Select matplotlib's non-interactive 'Agg' backend: the plotting code in
# loss_function only renders figures to files (savefig) and closes them, so no
# display is needed (e.g. on headless training machines).
matplotlib.use('Agg')
def loss_function(
        model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False, plot_path=None
):
    """Score-weighted, hardest-negative margin loss between two images.

    For every sample in ``batch``, the dense descriptors of image 1 are matched
    against image 2. The ground-truth correspondences come from warping a
    feature-map grid of image 1 into image 2 with the depth / intrinsics /
    pose / bbox annotations (see ``warp``).

    Each correspondence contributes
    ``relu(margin + positive_distance - hardest_negative_distance)``. The
    term is weighted by the product of the detection scores at both ends and
    normalised by the sum of those weights. The per-sample losses are averaged
    over the samples that contributed.

    Args:
        model: callable taking ``{'image1': ..., 'image2': ...}`` and returning
            a dict with ``'dense_features1'`` / ``'dense_features2'`` (one
            ``[c, h, w]`` map per sample) and ``'scores1'`` / ``'scores2'``
            (one score map per sample; ``scores2`` is indexed as
            ``[row, col]``).
        batch: dict with per-sample ``image1/2``, ``depth1/2``,
            ``intrinsics1/2``, ``pose1/2`` and ``bbox1/2``. When ``plot`` is
            True it must also provide ``batch_idx``, ``log_interval``,
            ``preprocessing``, ``train`` and ``epoch_idx``.
        device: torch device that the images and annotations are moved to.
        margin: margin of the hinge ``relu(margin + pos - neg)``.
        safe_radius: feature-map cells within this L-infinity distance of the
            true match are never used as negatives.
        scaling_steps: forwarded to ``upscale_positions`` /
            ``downscale_positions`` to convert between feature-map and image
            coordinates.
        plot: if True, save a diagnostic figure for every sample of every
            ``batch['log_interval']``-th batch.
        plot_path: directory for the figures. It must be a path when ``plot``
            is True, because ``os.path.join`` rejects None.

    Returns:
        Tensor of shape ``[1]`` on ``device``: the mean loss over the samples
        that contributed.

    Raises:
        NoGradientError: if no sample contributed, i.e. every warp raised
            EmptyTensorError or left fewer than 128 correspondences.
    """
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    # Accumulator of shape [1]; the final result keeps this shape.
    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        # One unit-norm descriptor per feature-map cell: [c, h * w].
        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        fmap_pos1 = grid_positions(h1, w1, device)
        pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps)
        try:
            pos1, pos2, ids = warp(
                pos1,
                depth1, intrinsics1, pose1, bbox1,
                depth2, intrinsics2, pose2, bbox2
            )
        except EmptyTensorError:
            # No usable correspondence for this pair: skip the sample.
            continue
        # Keep only the image-1 cells that survived the warp.
        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)
        ).long()
        descriptors2 = F.normalize(
            dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
            dim=0
        )
        # Squared L2 distance between unit vectors: ||a - b||^2 = 2 - 2 a.b,
        # computed pairwise for matched columns only.
        positive_distance = 2 - 2 * (
            descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
        ).squeeze()

        # L-infinity distance (in feature-map cells) from each true match in
        # image 2 to every cell of image 2.
        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(
            torch.abs(
                fmap_pos2.unsqueeze(2).float() -
                all_fmap_pos2.unsqueeze(1)
            ),
            dim=0
        )[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        # Cells inside the safe radius get +10. That exceeds the largest
        # possible distance (4) between unit vectors, so they can never be
        # the hardest negative.
        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1
        )[0]

        # Same hardest-negative search in the opposite direction
        # (image-2 descriptors against all cells of image 1).
        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(
            torch.abs(
                fmap_pos1.unsqueeze(2).float() -
                all_fmap_pos1.unsqueeze(1)
            ),
            dim=0
        )[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1
        )[0]

        # Positive minus the harder of the two negatives.
        diff = positive_distance - torch.min(
            negative_distance1, negative_distance2
        )

        # Detection scores of image 2 at the matched cells. This rebinds
        # scores2, which is re-read from `output` at the top of each iteration.
        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]

        # Score-weighted hinge, normalised by the total weight.
        loss = loss + (
            torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
            torch.sum(scores1 * scores2)
        )

        has_grad = True
        n_valid_samples += 1

        # Diagnostic figure: image 1 with warped points, score map 1,
        # image 2 with corresponding points, score map 2.
        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            pos1_aux = pos1.cpu().numpy()
            pos2_aux = pos2.cpu().numpy()
            k = pos1_aux.shape[1]
            # One random colour per correspondence, shared by both images.
            col = np.random.rand(k, 3)
            n_sp = 4
            plt.figure()
            plt.subplot(1, n_sp, 1)
            im1 = imshow_image(
                batch['image1'][idx_in_batch].cpu().numpy(),
                preprocessing=batch['preprocessing']
            )
            plt.imshow(im1)
            plt.scatter(
                pos1_aux[1, :], pos1_aux[0, :],
                s=0.25**2, c=col, marker=',', alpha=0.5
            )
            plt.axis('off')
            plt.subplot(1, n_sp, 2)
            plt.imshow(
                output['scores1'][idx_in_batch].data.cpu().numpy(),
                cmap='Reds'
            )
            plt.axis('off')
            plt.subplot(1, n_sp, 3)
            im2 = imshow_image(
                batch['image2'][idx_in_batch].cpu().numpy(),
                preprocessing=batch['preprocessing']
            )
            plt.imshow(im2)
            plt.scatter(
                pos2_aux[1, :], pos2_aux[0, :],
                s=0.25**2, c=col, marker=',', alpha=0.5
            )
            plt.axis('off')
            plt.subplot(1, n_sp, 4)
            plt.imshow(
                output['scores2'][idx_in_batch].data.cpu().numpy(),
                cmap='Reds'
            )
            plt.axis('off')
            savefig(os.path.join(plot_path, '%s.%02d.%02d.%d.png' % (
                'train' if batch['train'] else 'valid',
                batch['epoch_idx'],
                batch['batch_idx'] // batch['log_interval'],
                idx_in_batch
            )), dpi=300)
            plt.close()

    if not has_grad:
        raise NoGradientError

    # has_grad implies n_valid_samples >= 1, so this division is safe.
    loss = loss / n_valid_samples

    return loss
def interpolate_depth(pos, depth):
    """Bilinearly interpolate ``depth`` at sub-pixel positions.

    Args:
        pos: tensor of shape ``[2, N]``. Row 0 holds row coordinates ``i``
            and row 1 holds column coordinates ``j``, both in pixel units of
            ``depth``.
        depth: tensor of shape ``[h, w]``. Values ``<= 0`` are treated as
            missing.

    Returns:
        List ``[interpolated_depth, pos, ids]`` containing:

        - the ``[M]`` interpolated depths;
        - the ``[2, M]`` subset of the input positions that could be
          interpolated;
        - the ``[M]`` column indices of those positions in the input ``pos``.

        A position is kept only if its four neighbouring pixels all lie
        inside the map and all have strictly positive depth.

    Raises:
        EmptyTensorError: if no position survives the bounds check, or none
            survives the depth check.
    """
    device = pos.device

    ids = torch.arange(0, pos.size(1), device=device)

    h, w = depth.size()

    i = pos[0, :]
    j = pos[1, :]

    # Integer coordinates of the four neighbours. For an integral coordinate
    # floor == ceil: both neighbours coincide, and the "far" one receives
    # zero weight in the interpolation below.
    i_floor = torch.floor(i).long()
    i_ceil = torch.ceil(i).long()
    j_floor = torch.floor(j).long()
    j_ceil = torch.ceil(j).long()

    # Valid corners: all four neighbours lie inside the map. Boolean masks
    # are combined with an explicit logical AND (`&`).
    valid_corners = (i_floor >= 0) & (j_floor >= 0) & (i_ceil < h) & (j_ceil < w)

    i_floor = i_floor[valid_corners]
    i_ceil = i_ceil[valid_corners]
    j_floor = j_floor[valid_corners]
    j_ceil = j_ceil[valid_corners]
    ids = ids[valid_corners]
    if ids.size(0) == 0:
        raise EmptyTensorError

    # Valid depth: all four neighbours have a strictly positive depth.
    # Gather each corner once; the values are reused for the interpolation.
    depth_top_left = depth[i_floor, j_floor]
    depth_top_right = depth[i_floor, j_ceil]
    depth_bottom_left = depth[i_ceil, j_floor]
    depth_bottom_right = depth[i_ceil, j_ceil]
    valid_depth = (
        (depth_top_left > 0) & (depth_top_right > 0) &
        (depth_bottom_left > 0) & (depth_bottom_right > 0)
    )

    i_floor = i_floor[valid_depth]
    j_floor = j_floor[valid_depth]
    depth_top_left = depth_top_left[valid_depth]
    depth_top_right = depth_top_right[valid_depth]
    depth_bottom_left = depth_bottom_left[valid_depth]
    depth_bottom_right = depth_bottom_right[valid_depth]
    ids = ids[valid_depth]
    if ids.size(0) == 0:
        raise EmptyTensorError

    # Interpolation
    i = i[ids]
    j = j[ids]
    dist_i = i - i_floor.float()
    dist_j = j - j_floor.float()

    w_top_left = (1 - dist_i) * (1 - dist_j)
    w_top_right = (1 - dist_i) * dist_j
    w_bottom_left = dist_i * (1 - dist_j)
    w_bottom_right = dist_i * dist_j

    interpolated_depth = (
        w_top_left * depth_top_left +
        w_top_right * depth_top_right +
        w_bottom_left * depth_bottom_left +
        w_bottom_right * depth_bottom_right
    )

    pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)

    return [interpolated_depth, pos, ids]
def uv_to_pos(uv):
    """Reorder (u, v) coordinates into (row, col) = (v, u) order.

    Rows 1 and 0 of ``uv`` are each flattened and stacked, in that order,
    into a two-row tensor.
    """
    swapped_rows = [uv[row_idx, :].view(-1) for row_idx in (1, 0)]
    return torch.stack(swapped_rows, dim=0)
def warp(
        pos1,
        depth1, intrinsics1, pose1, bbox1,
        depth2, intrinsics2, pose2, bbox2
):
    """Warp pixel positions from image 1 into image 2 using depth and poses.

    Args:
        pos1: tensor ``[2, N]`` of (row, col) positions in image 1, in the
            pixel units of ``depth1``.
        depth1, depth2: depth maps indexed as ``[row, col]``. Values ``<= 0``
            are treated as missing by ``interpolate_depth``.
        intrinsics1, intrinsics2: 3x3 camera matrices. Focal lengths are read
            from ``[0, 0]`` / ``[1, 1]`` and the principal point from
            ``[0, 2]`` / ``[1, 2]``.
        pose1, pose2: 4x4 homogeneous transforms. A point is carried from
            camera 1 to camera 2 by ``pose2 @ inverse(pose1)``.
        bbox1, bbox2: (row, col) offsets added to the positions before
            projection (``bbox[0]`` pairs with v/row, ``bbox[1]`` with
            u/col). They are presumably the top-left corner of a crop;
            TODO confirm against the data loader.

    Returns:
        Tuple ``(pos1, pos2, ids)`` containing:

        - the ``[2, M]`` surviving positions in image 1;
        - their ``[2, M]`` (row, col) correspondences in image 2;
        - the indices of the survivors in the input ``pos1``.

    Raises:
        EmptyTensorError: if no position survives depth interpolation in
            either image, or none passes the depth-consistency check.
    """
    device = pos1.device

    Z1, pos1, ids = interpolate_depth(pos1, depth1)

    # COLMAP convention: shift by the bbox offset and by half a pixel before
    # back-projecting; (u, v) = (col, row).
    u1 = pos1[1, :] + bbox1[1] + .5
    v1 = pos1[0, :] + bbox1[0] + .5

    X1 = (u1 - intrinsics1[0, 2]) * (Z1 / intrinsics1[0, 0])
    Y1 = (v1 - intrinsics1[1, 2]) * (Z1 / intrinsics1[1, 1])

    XYZ1_hom = torch.cat([
        X1.view(1, -1),
        Y1.view(1, -1),
        Z1.view(1, -1),
        torch.ones(1, Z1.size(0), device=device)
    ], dim=0)
    # `@` replaces the deprecated torch.chain_matmul. It evaluates
    # (pose2 @ inv(pose1)) @ XYZ1_hom, the same order chain_matmul picks for
    # a 4x4 * 4x4 * 4xN product.
    XYZ2_hom = pose2 @ torch.inverse(pose1) @ XYZ1_hom
    XYZ2 = XYZ2_hom[: -1, :] / XYZ2_hom[-1, :].view(1, -1)

    # Project into image 2, then undo the bbox offset and half-pixel shift.
    uv2_hom = torch.matmul(intrinsics2, XYZ2)
    uv2 = uv2_hom[: -1, :] / uv2_hom[-1, :].view(1, -1)

    u2 = uv2[0, :] - bbox2[1] - .5
    v2 = uv2[1, :] - bbox2[0] - .5
    uv2 = torch.cat([u2.view(1, -1), v2.view(1, -1)], dim=0)

    annotated_depth, pos2, new_ids = interpolate_depth(uv_to_pos(uv2), depth2)

    # new_ids indexes the survivors of the first interpolation.
    ids = ids[new_ids]
    pos1 = pos1[:, new_ids]
    estimated_depth = XYZ2[2, new_ids]

    # Keep only points whose projected depth agrees with image 2's depth map
    # (absolute tolerance 0.05, in the depth maps' units). This presumably
    # filters occluded points.
    inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05

    ids = ids[inlier_mask]
    if ids.size(0) == 0:
        raise EmptyTensorError

    pos2 = pos2[:, inlier_mask]
    pos1 = pos1[:, inlier_mask]

    return pos1, pos2, ids