Vincentqyw
add: rord libs
9fb6531
raw
history blame
5.72 kB
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import cv2
from sys import exit
import torch
import torch.nn.functional as F
from lib.utils import (
grid_positions,
upscale_positions,
downscale_positions,
savefig,
imshow_image
)
from lib.exceptions import NoGradientError, EmptyTensorError
matplotlib.use('Agg')
def loss_function(
model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False, plot_path=None
):
output = model({
'image1': batch['image1'].to(device),
'image2': batch['image2'].to(device)
})
loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
has_grad = False
n_valid_samples = 0
for idx_in_batch in range(batch['image1'].size(0)):
# Network output
dense_features1 = output['dense_features1'][idx_in_batch]
c, h1, w1 = dense_features1.size()
scores1 = output['scores1'][idx_in_batch].view(-1)
dense_features2 = output['dense_features2'][idx_in_batch]
_, h2, w2 = dense_features2.size()
scores2 = output['scores2'][idx_in_batch]
all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
descriptors1 = all_descriptors1
all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)
fmap_pos1 = grid_positions(h1, w1, device)
pos1 = batch['pos1'][idx_in_batch].to(device)
pos2 = batch['pos2'][idx_in_batch].to(device)
ids = idsAlign(pos1, device, h1, w1)
fmap_pos1 = fmap_pos1[:, ids]
descriptors1 = descriptors1[:, ids]
scores1 = scores1[ids]
# Skip the pair if not enough GT correspondences are available
if ids.size(0) < 128:
continue
# Descriptors at the corresponding positions
fmap_pos2 = torch.round(
downscale_positions(pos2, scaling_steps=scaling_steps)
).long()
descriptors2 = F.normalize(
dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
dim=0
)
positive_distance = 2 - 2 * (
descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
).squeeze()
all_fmap_pos2 = grid_positions(h2, w2, device)
position_distance = torch.max(
torch.abs(
fmap_pos2.unsqueeze(2).float() -
all_fmap_pos2.unsqueeze(1)
),
dim=0
)[0]
is_out_of_safe_radius = position_distance > safe_radius
distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
negative_distance2 = torch.min(
distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
dim=1
)[0]
all_fmap_pos1 = grid_positions(h1, w1, device)
position_distance = torch.max(
torch.abs(
fmap_pos1.unsqueeze(2).float() -
all_fmap_pos1.unsqueeze(1)
),
dim=0
)[0]
is_out_of_safe_radius = position_distance > safe_radius
distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
negative_distance1 = torch.min(
distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
dim=1
)[0]
diff = positive_distance - torch.min(
negative_distance1, negative_distance2
)
scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]
loss = loss + (
torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
(torch.sum(scores1 * scores2) )
)
has_grad = True
n_valid_samples += 1
if plot and batch['batch_idx'] % batch['log_interval'] == 0:
drawTraining(batch['image1'], batch['image2'], pos1, pos2, batch, idx_in_batch, output, save=True, plot_path=plot_path)
if not has_grad:
raise NoGradientError
loss = loss / (n_valid_samples )
return loss
def idsAlign(pos1, device, h1, w1):
pos1D = downscale_positions(pos1, scaling_steps=3)
row = pos1D[0, :]
col = pos1D[1, :]
ids = []
for i in range(row.shape[0]):
index = ((w1) * (row[i])) + (col[i])
ids.append(index)
ids = torch.round(torch.Tensor(ids)).long().to(device)
return ids
def drawTraining(image1, image2, pos1, pos2, batch, idx_in_batch, output, save=False, plot_path="train_viz"):
pos1_aux = pos1.cpu().numpy()
pos2_aux = pos2.cpu().numpy()
k = pos1_aux.shape[1]
col = np.random.rand(k, 3)
n_sp = 4
plt.figure()
plt.subplot(1, n_sp, 1)
im1 = imshow_image(
image1[0].cpu().numpy(),
preprocessing=batch['preprocessing']
)
plt.imshow(im1)
plt.scatter(
pos1_aux[1, :], pos1_aux[0, :],
s=0.25**2, c=col, marker=',', alpha=0.5
)
plt.axis('off')
plt.subplot(1, n_sp, 2)
plt.imshow(
output['scores1'][idx_in_batch].data.cpu().numpy(),
cmap='Reds'
)
plt.axis('off')
plt.subplot(1, n_sp, 3)
im2 = imshow_image(
image2[0].cpu().numpy(),
preprocessing=batch['preprocessing']
)
plt.imshow(im2)
plt.scatter(
pos2_aux[1, :], pos2_aux[0, :],
s=0.25**2, c=col, marker=',', alpha=0.5
)
plt.axis('off')
plt.subplot(1, n_sp, 4)
plt.imshow(
output['scores2'][idx_in_batch].data.cpu().numpy(),
cmap='Reds'
)
plt.axis('off')
if(save == True):
savefig(plot_path+'/%s.%02d.%02d.%d.png' % (
'train' if batch['train'] else 'valid',
batch['epoch_idx'],
batch['batch_idx'] // batch['log_interval'],
idx_in_batch
), dpi=300)
else:
plt.show()
plt.close()
im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2RGB)
im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2RGB)
for i in range(0, pos1_aux.shape[1], 5):
im1 = cv2.circle(im1, (pos1_aux[1, i], pos1_aux[0, i]), 1, (0, 0, 255), 2)
for i in range(0, pos2_aux.shape[1], 5):
im2 = cv2.circle(im2, (pos2_aux[1, i], pos2_aux[0, i]), 1, (0, 0, 255), 2)
im3 = cv2.hconcat([im1, im2])
for i in range(0, pos1_aux.shape[1], 5):
im3 = cv2.line(im3, (int(pos1_aux[1, i]), int(pos1_aux[0, i])), (int(pos2_aux[1, i]) + im1.shape[1], int(pos2_aux[0, i])), (0, 255, 0), 1)
if(save == True):
cv2.imwrite(plot_path+'/%s.%02d.%02d.%d.png' % (
'train_corr' if batch['train'] else 'valid',
batch['epoch_idx'],
batch['batch_idx'] // batch['log_interval'],
idx_in_batch
), im3)
else:
cv2.imshow('Image', im3)
cv2.waitKey(0)