import torch
import torch.nn.functional as F

from .geom import rnd_sample, interpolate, get_dist_mat


def make_detector_loss(pos0, pos1, dense_feat_map0, dense_feat_map1,
                       score_map0, score_map1, batch_size, num_corr, loss_type, config):
    joint_loss = 0.
    accuracy = 0.
    all_valid_pos0 = []
    all_valid_pos1 = []
    all_valid_match = []
    for i in range(batch_size):
        # Randomly subsample up to num_corr ground-truth correspondences.
        valid_pos0, valid_pos1 = rnd_sample([pos0[i], pos1[i]], num_corr)
        valid_num = valid_pos0.shape[0]

        # Sample descriptors from the dense feature maps, which live at
        # 1/4 of the input resolution.
        valid_feat0 = interpolate(valid_pos0 / 4, dense_feat_map0[i])
        valid_feat1 = interpolate(valid_pos1 / 4, dense_feat_map1[i])

        valid_feat0 = F.normalize(valid_feat0, p=2, dim=-1)
        valid_feat1 = F.normalize(valid_feat1, p=2, dim=-1)

        # Detection scores at the sampled positions (full resolution).
        valid_score0 = interpolate(valid_pos0, torch.squeeze(score_map0[i], dim=-1), nd=False)
        valid_score1 = interpolate(valid_pos1, torch.squeeze(score_map1[i], dim=-1), nd=False)

        # Weight each correspondence by the product of its detection scores.
        if config['network']['det']['corr_weight']:
            corr_weight = valid_score0 * valid_score1
        else:
            corr_weight = None

        # Negatives that lie within safe_radius pixels of the true match are
        # masked out so that near-correct matches are not penalized.
        safe_radius = config['network']['det']['safe_radius']
        if safe_radius > 0:
            radius_mask_row = get_dist_mat(
                valid_pos1, valid_pos1, "euclidean_dist_no_norm")
            radius_mask_row = torch.le(radius_mask_row, safe_radius)
            radius_mask_col = get_dist_mat(
                valid_pos0, valid_pos0, "euclidean_dist_no_norm")
            radius_mask_col = torch.le(radius_mask_col, safe_radius)
            # Zero the diagonal: a point is never a negative of itself.
            radius_mask_row = radius_mask_row.float() - torch.eye(valid_num, device=radius_mask_row.device)
            radius_mask_col = radius_mask_col.float() - torch.eye(valid_num, device=radius_mask_col.device)
        else:
            radius_mask_row = None
            radius_mask_col = None

        # Skip pairs with too few correspondences to form a stable loss.
        if valid_num < 32:
            si_loss, si_accuracy = 0., 1.
            matched_mask = torch.zeros((1, valid_num), device=valid_pos0.device).bool()
        else:
            si_loss, si_accuracy, matched_mask = make_structured_loss(
                torch.unsqueeze(valid_feat0, 0), torch.unsqueeze(valid_feat1, 0),
                loss_type=loss_type,
                radius_mask_row=radius_mask_row, radius_mask_col=radius_mask_col,
                corr_weight=torch.unsqueeze(corr_weight, 0) if corr_weight is not None else None
            )

        joint_loss += si_loss / batch_size
        accuracy += si_accuracy / batch_size
        all_valid_match.append(torch.squeeze(matched_mask, dim=0))
        all_valid_pos0.append(valid_pos0)
        all_valid_pos1.append(valid_pos1)

    return joint_loss, accuracy, all_valid_pos0, all_valid_pos1, all_valid_match
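
# Usage sketch for make_detector_loss (illustrative only; shapes are assumed,
# since the exact layouts depend on `interpolate` in .geom): `pos0[i]` and
# `pos1[i]` hold matched pixel coordinates for image pair i, the dense feature
# maps are assumed NHWC at 1/4 resolution, and the score maps (B, H, W, 1).
# A hypothetical smoke test, not the canonical call:
#
#   cfg = {'network': {'det': {'corr_weight': True, 'safe_radius': 3}}}
#   pos0 = [torch.rand(256, 2) * 255]
#   pos1 = [torch.rand(256, 2) * 255]
#   loss, acc, *_ = make_detector_loss(pos0, pos1, feat_map0, feat_map1,
#                                      score_map0, score_map1, batch_size=1,
#                                      num_corr=256, loss_type='HARD_CONTRASTIVE',
#                                      config=cfg)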


def make_structured_loss(feat_anc, feat_pos,
                         loss_type='RATIO', inlier_mask=None,
                         radius_mask_row=None, radius_mask_col=None,
                         corr_weight=None, dist_mat=None):
    """
    Structured loss construction.
    Args:
        feat_anc, feat_pos: Feature matrices of shape (batch_size, num_corr, dim),
            where row i of feat_anc matches row i of feat_pos.
        loss_type: Loss type; 'HARD_TRIPLET', 'HARD_CONTRASTIVE' and 'CIRCLE'
            are implemented here.
        inlier_mask: Optional (batch_size, num_corr) boolean mask of valid
            correspondences; defaults to all ones.
        radius_mask_row, radius_mask_col: Optional masks of negatives to ignore
            because they fall within the safe radius of the true match.
        corr_weight: Optional per-correspondence weights.
        dist_mat: Optional precomputed distance matrix.
    Returns:
        loss: Scalar structured loss.
        accuracy: Fraction of correspondences whose positive beats all negatives.
        matched_mask: (batch_size, num_corr) boolean mask of correctly matched
            correspondences.
    """
    batch_size = feat_anc.shape[0]
    num_corr = feat_anc.shape[1]
    if inlier_mask is None:
        inlier_mask = torch.ones((batch_size, num_corr), device=feat_anc.device).bool()
    inlier_num = torch.count_nonzero(inlier_mask, dim=-1)

    if loss_type == 'L2NET' or loss_type == 'CIRCLE':
        dist_type = 'cosine_dist'
    elif loss_type.find('HARD') >= 0:
        dist_type = 'euclidean_dist'
    else:
        raise NotImplementedError()

    if dist_mat is None:
        # get_dist_mat works on unbatched inputs, so the batch dimension
        # (assumed to be 1 here) is squeezed away and restored afterwards.
        dist_mat = get_dist_mat(feat_anc.squeeze(0), feat_pos.squeeze(0), dist_type).unsqueeze(0)
    # Distances (or similarities) of the ground-truth pairs sit on the diagonal.
    pos_vec = torch.diagonal(dist_mat, dim1=-2, dim2=-1)

    if loss_type.find('HARD') >= 0:
        neg_margin = 1
        # Push the diagonal (the positives) out of the hardest-negative search.
        dist_mat_without_min_on_diag = dist_mat + \
            10 * torch.unsqueeze(torch.eye(num_corr, device=dist_mat.device), dim=0)
        # Also exclude near-duplicate descriptors (distance below 0.008).
        mask = torch.le(dist_mat_without_min_on_diag, 0.008).float()
        dist_mat_without_min_on_diag += mask * 10

        if radius_mask_row is not None:
            hard_neg_dist_row = dist_mat_without_min_on_diag + 10 * radius_mask_row
        else:
            hard_neg_dist_row = dist_mat_without_min_on_diag
        if radius_mask_col is not None:
            hard_neg_dist_col = dist_mat_without_min_on_diag + 10 * radius_mask_col
        else:
            hard_neg_dist_col = dist_mat_without_min_on_diag

        # Hardest (closest) negative along each row and each column.
        hard_neg_dist_row = torch.min(hard_neg_dist_row, dim=-1)[0]
        hard_neg_dist_col = torch.min(hard_neg_dist_col, dim=-2)[0]

        if loss_type == 'HARD_TRIPLET':
            loss_row = torch.clamp(neg_margin + pos_vec - hard_neg_dist_row, min=0)
            loss_col = torch.clamp(neg_margin + pos_vec - hard_neg_dist_col, min=0)
        elif loss_type == 'HARD_CONTRASTIVE':
            pos_margin = 0.2
            pos_loss = torch.clamp(pos_vec - pos_margin, min=0)
            loss_row = pos_loss + torch.clamp(neg_margin - hard_neg_dist_row, min=0)
            loss_col = pos_loss + torch.clamp(neg_margin - hard_neg_dist_col, min=0)
        else:
            raise NotImplementedError()

    elif loss_type == 'CIRCLE':
        # Circle loss on the cosine similarity matrix.
        log_scale = 512
        m = 0.1
        # Exclude the positives (diagonal) and safe-radius negatives.
        neg_mask_row = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0)
        if radius_mask_row is not None:
            neg_mask_row += radius_mask_row
        neg_mask_col = torch.unsqueeze(torch.eye(num_corr, device=feat_anc.device), 0)
        if radius_mask_col is not None:
            neg_mask_col += radius_mask_col

        pos_margin = 1 - m
        neg_margin = m
        pos_optimal = 1 + m
        neg_optimal = -m

        # Large offset so masked entries never win the log-sum-exp.
        neg_mat_row = dist_mat - 128 * neg_mask_row
        neg_mat_col = dist_mat - 128 * neg_mask_col

        lse_positive = torch.logsumexp(-log_scale * (pos_vec[..., None] - pos_margin) *
                                       torch.clamp(pos_optimal - pos_vec[..., None], min=0).detach(), dim=-1)

        lse_negative_row = torch.logsumexp(log_scale * (neg_mat_row - neg_margin) *
                                           torch.clamp(neg_mat_row - neg_optimal, min=0).detach(), dim=-1)

        lse_negative_col = torch.logsumexp(log_scale * (neg_mat_col - neg_margin) *
                                           torch.clamp(neg_mat_col - neg_optimal, min=0).detach(), dim=-2)

        loss_row = F.softplus(lse_positive + lse_negative_row) / log_scale
        loss_col = F.softplus(lse_positive + lse_negative_col) / log_scale

    else:
        raise NotImplementedError()

    # A correspondence counts as matched when its positive entry beats every
    # non-masked entry in both its row and its column of dist_mat.
    if dist_type == 'cosine_dist':
        err_row = dist_mat - torch.unsqueeze(pos_vec, -1)
        err_col = dist_mat - torch.unsqueeze(pos_vec, -2)
    elif dist_type == 'euclidean_dist' or dist_type == 'euclidean_dist_no_norm':
        err_row = torch.unsqueeze(pos_vec, -1) - dist_mat
        err_col = torch.unsqueeze(pos_vec, -2) - dist_mat
    else:
        raise NotImplementedError()
    if radius_mask_row is not None:
        err_row = err_row - 10 * radius_mask_row
    if radius_mask_col is not None:
        err_col = err_col - 10 * radius_mask_col
    err_row = torch.sum(torch.clamp(err_row, min=0), dim=-1)
    err_col = torch.sum(torch.clamp(err_col, min=0), dim=-2)

    loss = 0
    accuracy = 0

    tot_loss = (loss_row + loss_col) / 2
    if corr_weight is not None:
        tot_loss = tot_loss * corr_weight

    for i in range(batch_size):
        if corr_weight is not None:
            # Weighted mean: normalize by the total inlier weight.
            loss += torch.sum(tot_loss[i][inlier_mask[i]]) / \
                (torch.sum(corr_weight[i][inlier_mask[i]]) + 1e-6)
        else:
            loss += torch.mean(tot_loss[i][inlier_mask[i]])
        cnt_err_row = torch.count_nonzero(err_row[i][inlier_mask[i]]).float()
        cnt_err_col = torch.count_nonzero(err_col[i][inlier_mask[i]]).float()
        tot_err = cnt_err_row + cnt_err_col
        if inlier_num[i] != 0:
            # Average the row and column error rates; the division by
            # batch_size happens once, below.
            accuracy += 1. - tot_err / inlier_num[i] / 2.
        else:
            accuracy += 1.

    matched_mask = torch.logical_and(torch.eq(err_row, 0), torch.eq(err_col, 0))
    matched_mask = torch.logical_and(matched_mask, inlier_mask)

    loss /= batch_size
    accuracy /= batch_size

    return loss, accuracy, matched_mask
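

def _structured_loss_smoke_test(num_corr=128, dim=128, device='cpu'):
    """Minimal sanity check for make_structured_loss (illustrative only).

    This helper is not part of the training pipeline. It feeds identical
    L2-normalized random descriptors as anchors and positives, so every
    correspondence matches itself perfectly and the reported accuracy
    should be exactly 1.0.
    """
    feat = F.normalize(torch.randn(1, num_corr, dim, device=device), p=2, dim=-1)
    loss, accuracy, matched_mask = make_structured_loss(
        feat, feat, loss_type='HARD_TRIPLET')
    assert float(accuracy) == 1.0 and bool(matched_mask.all())
    return loss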


def make_noise_score_map_loss(score_map, noise_score_map, indices, batch_size, thld=0.):
    H, W = score_map.shape[1:3]
    loss = 0
    for i in range(batch_size):
        # Rasterize the keypoint coordinates into a binary (H, W) mask.
        # indices[i] is assumed to hold (row, col) pixel coordinates, one per row.
        kpts_coords = indices[i].long().T
        mask = torch.zeros([H, W], device=score_map.device)
        mask[kpts_coords[0], kpts_coords[1]] = 1

        # Dilate the mask by one pixel with a 3x3 box filter.
        kernel = torch.ones([1, 1, 3, 3], device=score_map.device)
        mask = F.conv2d(mask.unsqueeze(0).unsqueeze(0), kernel, padding=1)[0, 0] > 0

        # Around keypoints, keep the noisy score map close to the clean one;
        # elsewhere, penalize it only where it exceeds the clean map by thld.
        loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask)
        loss2 = torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze()
                          * torch.logical_not(mask)) / (H * W - torch.sum(mask))

        loss += loss1
        loss += loss2

        if i == 0:
            first_mask = mask

    return loss, first_mask
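
# Worked example of the background term above (illustrative numbers): with
# thld=0.1, a background pixel whose noisy score exceeds its clean score by
# 0.25 adds clamp(0.25 - 0.1, min=0) = 0.15 to the sum that is then averaged
# over all background pixels, while any pixel whose noisy score stays at or
# below clean + thld contributes nothing.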


def make_noise_score_map_loss_labelmap(score_map, noise_score_map, labelmap, batch_size, thld=0.):
    H, W = score_map.shape[1:3]
    loss = 0
    for i in range(batch_size):
        # Dilate the given keypoint label map by one pixel with a 3x3 box filter.
        kernel = torch.ones([1, 1, 3, 3], device=score_map.device)
        mask = F.conv2d(labelmap[i].unsqueeze(0).to(score_map.device).float(), kernel, padding=1)[0, 0] > 0

        # Same two terms as make_noise_score_map_loss: stay close to the clean
        # map near keypoints, and cap the noisy map at clean + thld elsewhere.
        loss1 = torch.sum(torch.abs(score_map[i] - noise_score_map[i]).squeeze() * mask) / torch.sum(mask)
        loss2 = torch.sum(torch.clamp(noise_score_map[i] - score_map[i] - thld, min=0).squeeze()
                          * torch.logical_not(mask)) / (H * W - torch.sum(mask))

        loss += loss1
        loss += loss2

        if i == 0:
            first_mask = mask

    return loss, first_mask
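
# Usage note (assumed semantics): `labelmap[i]` is expected to be a (1, H, W)
# binary map marking keypoint pixels, so this variant skips the coordinate
# rasterization step of make_noise_score_map_loss and dilates the given
# labels directly.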


def make_score_map_peakiness_loss(score_map, scores, batch_size):
    loss = 0

    for i in range(batch_size):
        # Reward keypoint scores that stand out from the map average.
        loss += torch.mean(scores[i]) - torch.mean(score_map[i])

    loss /= batch_size
    return 1 - loss
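
# The peakiness objective above encourages score maps whose keypoint scores
# stand out from the map average: per image, loss_i = mean(scores_i) -
# mean(score_map_i), and the function returns 1 - mean_i(loss_i), which is
# minimized by making detected keypoints as peaked as possible. For example,
# if the keypoint scores average 0.9 while the whole map averages 0.2, the
# returned value is 1 - 0.7 = 0.3.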