import torch


def build_descriptor_loss(
    source_des, target_des, tar_points_un, top_kk=None, relax_field=4, eval_only=False
):
""" |
|
Desc Head Loss, per-pixel level triplet loss from https://arxiv.org/pdf/1902.11046.pdf. |
|
|
|
Parameters |
|
---------- |
|
source_des: torch.Tensor (B,256,H/8,W/8) |
|
Source image descriptors. |
|
target_des: torch.Tensor (B,256,H/8,W/8) |
|
Target image descriptors. |
|
source_points: torch.Tensor (B,H/8,W/8,2) |
|
Source image keypoints |
|
tar_points: torch.Tensor (B,H/8,W/8,2) |
|
Target image keypoints |
|
tar_points_un: torch.Tensor (B,2,H/8,W/8) |
|
Target image keypoints unnormalized |
|
eval_only: bool |
|
Computes only recall without the loss. |
|
Returns |
|
------- |
|
loss: torch.Tensor |
|
Descriptor loss. |
|
recall: torch.Tensor |
|
Descriptor match recall. |
|
""" |
|
    device = source_des.device
    loss = 0.0
    batch_size = source_des.size(0)
    recall = 0.0

    # Single pyramid level: matching tolerance, triplet margin, and loss weight.
    relax_field_size = [relax_field]
    margins = [1.0]
    weights = [1.0]

    # Dense mode uses every spatial location; otherwise only masked keypoints.
    is_dense = top_kk is None

    for b_id in range(batch_size):

        if is_dense:
            # Use all descriptors, flattened over the spatial dimensions.
            ref_desc = source_des[b_id].squeeze().view(256, -1)
            tar_desc = target_des[b_id].squeeze().view(256, -1)
            tar_points_raw = tar_points_un[b_id].view(2, -1)
        else:
            top_k = top_kk[b_id].squeeze()

            # Skip images with too few keypoints for a meaningful loss.
            n_feat = top_k.sum().item()
            if n_feat < 20:
                continue

            ref_desc = source_des[b_id].squeeze()[:, top_k]
            tar_desc = target_des[b_id].squeeze()[:, top_k]
            tar_points_raw = tar_points_un[b_id][:, top_k]

        # L2-normalize the descriptors so their dot product is a cosine similarity.
        ref_desc = ref_desc.div(torch.norm(ref_desc, p=2, dim=0))
        tar_desc = tar_desc.div(torch.norm(tar_desc, p=2, dim=0))
        dmat = torch.mm(ref_desc.t(), tar_desc)

        # Convert similarity to Euclidean distance on the unit sphere
        # (||a - b||^2 = 2 - 2<a, b>) and sort candidates by distance.
        dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1))
        _, idx = torch.sort(dmat, dim=1)

        for pyramid in range(len(relax_field_size)):

            candidates = idx.t()

            # Coordinates of the distance-sorted match candidates.
            match_k_x = tar_points_raw[0, candidates]
            match_k_y = tar_points_raw[1, candidates]

            # Ground-truth target coordinates.
            tru_x = tar_points_raw[0]
            tru_y = tar_points_raw[1]

            # Recall: fraction of points whose nearest descriptor is an exact match.
            if pyramid == 0:
                correct2 = (abs(match_k_x[0] - tru_x) == 0) & (
                    abs(match_k_y[0] - tru_y) == 0
                )
                correct2_cnt = correct2.float().sum()
                recall += float(1.0 / batch_size) * (
                    float(correct2_cnt) / float(ref_desc.size(1))
                )

            if eval_only:
                continue

            # A candidate is correct if it falls within the relaxation field.
            correct_k = (abs(match_k_x - tru_x) <= relax_field_size[pyramid]) & (
                abs(match_k_y - tru_y) <= relax_field_size[pyramid]
            )

            # Hardest-negative mining: for each anchor, pick the closest
            # candidate (smallest sorted row index) that is NOT a correct match.
            incorrect_index = (
                torch.arange(start=correct_k.shape[0] - 1, end=-1, step=-1)
                .unsqueeze(1)
                .repeat(1, correct_k.shape[1])
                .to(device)
            )
            incorrect_first = torch.argmax(
                incorrect_index * (1 - correct_k.long()), dim=0
            )

            incorrect_first_index = candidates.gather(
                0, incorrect_first.unsqueeze(0)
            ).squeeze()

            anchor_var = ref_desc
            pos_var = tar_desc
            neg_var = tar_desc[:, incorrect_first_index]

            # Per-image triplet loss, averaged over the batch.
            loss += float(1.0 / batch_size) * torch.nn.functional.triplet_margin_loss(
                anchor_var.t(), pos_var.t(), neg_var.t(), margin=margins[pyramid]
            ).mul(weights[pyramid])

    return loss, recall
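
# A minimal usage sketch (the tensors `desc_a`, `desc_b`, `coords`, and `mask`
# and their shapes are illustrative assumptions, not part of this module):
#
#     desc_a = torch.randn(2, 256, 32, 40)            # source descriptors
#     desc_b = torch.randn(2, 256, 32, 40)            # target descriptors
#     coords = torch.rand(2, 2, 32, 40) * 320         # unnormalized keypoints
#     mask = torch.ones(2, 32, 40, dtype=torch.bool)  # keep every location
#     loss, recall = build_descriptor_loss(desc_a, desc_b, coords, top_kk=mask)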


class KeypointLoss(object):
    """
    Loss function class encapsulating the location loss, the descriptor loss,
    the score loss, and the correspondence loss.
    """

    def __init__(self, config):
        self.score_weight = config.score_weight
        self.loc_weight = config.loc_weight
        self.desc_weight = config.desc_weight
        self.corres_weight = config.corres_weight
        self.corres_threshold = config.corres_threshold

    def __call__(self, data):
        B, _, hc, wc = data["source_score"].shape

        # Pairwise residuals between warped target coordinates and the target
        # grid; the per-point minimum is the localization error.
        loc_mat_abs = torch.abs(
            data["target_coord_warped"].view(B, 2, -1).unsqueeze(3)
            - data["target_coord"].view(B, 2, -1).unsqueeze(2)
        )
        l2_dist_loc_mat = torch.norm(loc_mat_abs, p=2, dim=1)
        l2_dist_loc_min, l2_dist_loc_min_index = l2_dist_loc_mat.min(dim=2)

        # Correspondence loss: binary cross-entropy on the confidence matrix.
        # Positives are row-wise nearest neighbors within 1 pixel; negatives
        # are pairs more than 4 pixels apart.
        loc_min_mat = torch.repeat_interleave(
            l2_dist_loc_min.unsqueeze(dim=-1), repeats=l2_dist_loc_mat.shape[-1], dim=-1
        )
        pos_mask = l2_dist_loc_mat.eq(loc_min_mat) & l2_dist_loc_mat.le(1.0)
        neg_mask = l2_dist_loc_mat.ge(4.0)

        pos_corres = -torch.log(data["confidence_matrix"][pos_mask])
        neg_corres = -torch.log(1.0 - data["confidence_matrix"][neg_mask])
        corres_loss = pos_corres.mean() + 5e5 * neg_corres.mean()

        # Location loss: mean localization error over points that are both
        # within the matching threshold and inside the border mask.
        dist_norm_valid_mask = l2_dist_loc_min.lt(self.corres_threshold) & data[
            "border_mask"
        ].view(B, hc * wc)

        loc_loss = l2_dist_loc_min[dist_norm_valid_mask].mean()

        # Descriptor loss: per-pixel triplet loss over in-border locations.
        desc_loss, _ = build_descriptor_loss(
            data["source_desc"],
            data["target_desc_warped"],
            data["target_coord_warped"].detach(),
            top_kk=data["border_mask"],
            relax_field=8,
        )

        # Score loss, part 1 (repeatability): pair each source score with the
        # target score at its nearest warped location, weighted by how the
        # localization error deviates from the mean error.
        target_score_associated = (
            data["target_score"]
            .view(B, hc * wc)
            .gather(1, l2_dist_loc_min_index)
            .view(B, hc, wc)
            .unsqueeze(1)
        )
        dist_norm_valid_mask = dist_norm_valid_mask.view(B, hc, wc).unsqueeze(1) & data[
            "border_mask"
        ].unsqueeze(1)
        l2_dist_loc_min = l2_dist_loc_min.view(B, hc, wc).unsqueeze(1)
        loc_err = l2_dist_loc_min[dist_norm_valid_mask]

        repeatable_constrain = (
            (
                target_score_associated[dist_norm_valid_mask]
                + data["source_score"][dist_norm_valid_mask]
            )
            * (loc_err - loc_err.mean())
        ).mean()

        # Score loss, part 2 (consistency): scores and awareness maps should
        # agree between the source image and the warped target image.
        consistent_constrain = (
            torch.nn.functional.mse_loss(
                data["target_score_warped"][data["border_mask"].unsqueeze(1)],
                data["source_score"][data["border_mask"].unsqueeze(1)],
            ).mean()
            * 2
        )
        aware_consistent_loss = (
            torch.nn.functional.mse_loss(
                data["target_aware_warped"][
                    data["border_mask"].unsqueeze(1).repeat(1, 2, 1, 1)
                ],
                data["source_aware"][
                    data["border_mask"].unsqueeze(1).repeat(1, 2, 1, 1)
                ],
            ).mean()
            * 2
        )

        score_loss = repeatable_constrain + consistent_constrain + aware_consistent_loss

        # Total loss: weighted sum of the four terms; the weighted terms are
        # also returned individually for logging.
        loss = (
            self.loc_weight * loc_loss
            + self.desc_weight * desc_loss
            + self.score_weight * score_loss
            + self.corres_weight * corres_loss
        )

        return (
            loss,
            self.loc_weight * loc_loss,
            self.desc_weight * desc_loss,
            self.score_weight * score_loss,
            self.corres_weight * corres_loss,
        )
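

if __name__ == "__main__":
    # Smoke test on random data. All shapes, dictionary contents, and config
    # values below are illustrative assumptions, not the original training
    # configuration.
    from types import SimpleNamespace

    torch.manual_seed(0)
    B, hc, wc = 2, 8, 8
    n = hc * wc

    # A stride-8 coordinate grid; "warped" coordinates are the grid plus a
    # small perturbation, so valid correspondences exist.
    ys, xs = torch.meshgrid(
        torch.arange(hc).float(), torch.arange(wc).float(), indexing="ij"
    )
    coord = torch.stack([xs, ys]).unsqueeze(0).repeat(B, 1, 1, 1) * 8

    data = {
        "source_score": torch.rand(B, 1, hc, wc),
        "target_score": torch.rand(B, 1, hc, wc),
        "target_score_warped": torch.rand(B, 1, hc, wc),
        "source_aware": torch.rand(B, 2, hc, wc),
        "target_aware_warped": torch.rand(B, 2, hc, wc),
        "source_desc": torch.randn(B, 256, hc, wc),
        "target_desc_warped": torch.randn(B, 256, hc, wc),
        "target_coord": coord,
        "target_coord_warped": coord + 0.1 * torch.randn_like(coord),
        # Confidence kept strictly inside (0, 1) so the logs are finite.
        "confidence_matrix": torch.rand(B, n, n) * 0.8 + 0.1,
        "border_mask": torch.ones(B, hc, wc, dtype=torch.bool),
    }
    config = SimpleNamespace(
        score_weight=1.0,
        loc_weight=1.0,
        desc_weight=1.0,
        corres_weight=1.0,
        corres_threshold=4.0,
    )

    total, loc, desc, score, corres = KeypointLoss(config)(data)
    print(total.item(), loc.item(), desc.item(), score.item(), corres.item())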