import math import torch import torch.nn as nn import torch.nn.functional as F from kornia.geometry.subpix import dsnt from kornia.utils.grid import create_meshgrid class FineMatching(nn.Module): """FineMatching with s2d paradigm""" def __init__(self): super().__init__() def forward(self, feat_f0, feat_f1, data): """ Args: feat0 (torch.Tensor): [M, WW, C] feat1 (torch.Tensor): [M, WW, C] data (dict) Update: data (dict):{ 'expec_f' (torch.Tensor): [M, 3], 'mkpts0_f' (torch.Tensor): [M, 2], 'mkpts1_f' (torch.Tensor): [M, 2]} """ M, WW, C = feat_f0.shape W = int(math.sqrt(WW)) scale = data["hw0_i"][0] / data["hw0_f"][0] self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale # corner case: if no coarse matches found if M == 0: assert ( self.training == False ), "M is always >0, when training, see coarse_matching.py" # logger.warning('No matches found in coarse-level.') data.update( { "expec_f": torch.empty(0, 3, device=feat_f0.device), "mkpts0_f": data["mkpts0_c"], "mkpts1_f": data["mkpts1_c"], } ) return feat_f0_picked = feat_f0[:, WW // 2, :] sim_matrix = torch.einsum("mc,mrc->mr", feat_f0_picked, feat_f1) softmax_temp = 1.0 / C**0.5 heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1) feat_f1_picked = (feat_f1 * heatmap.unsqueeze(-1)).sum(dim=1) # [M, C] heatmap = heatmap.view(-1, W, W) # compute coordinates from heatmap coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[ 0 ] # [M, 2] grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape( 1, -1, 2 ) # [1, WW, 2] # compute std over var = ( torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords1_normalized**2 ) # [M, 2] std = torch.sum( torch.sqrt(torch.clamp(var, min=1e-10)), -1 ) # [M] clamp needed for numerical stability # for fine-level supervision data.update( { "expec_f": torch.cat([coords1_normalized, std.unsqueeze(1)], -1), "descriptors0": feat_f0_picked.detach(), "descriptors1": feat_f1_picked.detach(), } ) # compute absolute kpt coords self.get_fine_match(coords1_normalized, data) @torch.no_grad() def get_fine_match(self, coords1_normed, data): W, WW, C, scale = self.W, self.WW, self.C, self.scale # mkpts0_f and mkpts1_f # scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale mkpts0_f = data[ "mkpts0_c" ] # + (coords0_normed * (W // 2) * scale0 )[:len(data['mconf'])] scale1 = scale * data["scale1"][data["b_ids"]] if "scale1" in data else scale mkpts1_f = ( data["mkpts1_c"] + (coords1_normed * (W // 2) * scale1)[: len(data["mconf"])] ) data.update({"mkpts0_f": mkpts0_f, "mkpts1_f": mkpts1_f})