Spaces:
Running
Running
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from kornia.geometry.subpix import dsnt | |
from kornia.utils.grid import create_meshgrid | |
class FineMatching(nn.Module):
    """FineMatching with s2d paradigm.

    For every coarse match, the feature at the centre of the image-0 window
    is correlated against all features of the matching image-1 window. The
    softmax of that correlation is treated as a heatmap whose spatial
    expectation gives the refined image-1 keypoint.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feat_f0, feat_f1, data):
        """Refine coarse matches and write the results into ``data``.

        Args:
            feat_f0 (torch.Tensor): [M, WW, C] fine-level window features of image 0
            feat_f1 (torch.Tensor): [M, WW, C] fine-level window features of image 1
            data (dict): must contain 'hw0_i', 'hw0_f', 'mkpts0_c', 'mkpts1_c'
                and 'mconf'; 'scale1' and 'b_ids' are used when 'scale1' is
                present.
        Update:
            data (dict):{
                'expec_f' (torch.Tensor): [M, 3], <x, y> expectation and summed std,
                'descriptors0' (torch.Tensor): [M, C], detached,
                'descriptors1' (torch.Tensor): [M, C], detached,
                'mkpts0_f' (torch.Tensor): [M, 2],
                'mkpts1_f' (torch.Tensor): [M, 2]}
        Raises:
            ValueError: if M > 0 and WW is not a perfect square.
        """
        M, WW, C = feat_f0.shape
        W = int(math.sqrt(WW))
        # NOTE(review): presumably the ratio between input-image and
        # fine-feature-map resolution -- confirm against the producer of
        # 'hw0_i' / 'hw0_f'.
        scale = data['hw0_i'][0] / data['hw0_f'][0]
        # Kept on the module because get_fine_match() reads W and scale.
        self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale

        # corner case: if no coarse matches found
        if M == 0:
            assert not self.training, "M is always >0, when training, see coarse_matching.py"
            # Emit the same keys as the regular path (with empty tensors) so
            # that consumers of `data` do not need a special case.
            data.update({
                'expec_f': torch.empty(0, 3, device=feat_f0.device),
                'descriptors0': feat_f0.new_empty((0, C)),
                'descriptors1': feat_f1.new_empty((0, C)),
                'mkpts0_f': data['mkpts0_c'],
                'mkpts1_f': data['mkpts1_c'],
            })
            return

        # The heatmap is reshaped to [M, W, W] below; a non-square window
        # would otherwise fail obscurely or be silently mis-shaped.
        if W * W != WW:
            raise ValueError(
                f"FineMatching expects a square window, got WW={WW} "
                f"(nearest W={W}, W*W={W * W})"
            )

        # Correlate the centre feature of each image-0 window with every
        # position of the corresponding image-1 window.
        feat_f0_picked = feat_f0[:, WW // 2, :]  # [M, C]
        sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)  # [M, WW]
        softmax_temp = 1. / C**.5
        heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1)  # [M, WW]
        # Heatmap-weighted average of the image-1 window features.
        feat_f1_picked = (feat_f1 * heatmap.unsqueeze(-1)).sum(dim=1)  # [M, C]
        heatmap = heatmap.view(-1, W, W)

        # compute coordinates from heatmap (normalized coordinates)
        coords1_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]  # [M, 2]
        grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)  # [1, WW, 2]

        # compute std over <x, y>: var = E[g^2] - E[g]^2
        var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords1_normalized**2  # [M, 2]
        # clamp needed for numerical stability
        std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1)  # [M]

        # for fine-level supervision
        data.update({
            'expec_f': torch.cat([coords1_normalized, std.unsqueeze(1)], -1),
            'descriptors0': feat_f0_picked.detach(),
            'descriptors1': feat_f1_picked.detach(),
        })

        # compute absolute kpt coords
        self.get_fine_match(coords1_normalized, data)

    def get_fine_match(self, coords1_normed, data):
        """Convert normalized window offsets into refined keypoints.

        Reads ``self.W`` and ``self.scale`` as set by ``forward``.

        Args:
            coords1_normed (torch.Tensor): [M, 2] normalized offsets inside
                the image-1 window.
            data (dict): must contain 'mkpts0_c', 'mkpts1_c' and 'mconf';
                'scale1' and 'b_ids' are used when 'scale1' is present.
        Update:
            data (dict):{
                'mkpts0_f' (torch.Tensor): the coarse image-0 keypoints, unchanged,
                'mkpts1_f' (torch.Tensor): coarse image-1 keypoints plus the offset}
        """
        W, scale = self.W, self.scale

        # Image-0 keypoints are not refined: the window centre is the match.
        mkpts0_f = data['mkpts0_c']

        # Optional per-match rescaling, looked up by the batch index of each match.
        scale1 = scale * data['scale1'][data['b_ids']] if 'scale1' in data else scale
        # Normalized offset -> window units (W // 2) -> scaled by scale1.
        # Only the first len(mconf) rows are kept so the offset lines up
        # with 'mkpts1_c'.
        offset1 = (coords1_normed * (W // 2) * scale1)[:len(data['mconf'])]
        mkpts1_f = data['mkpts1_c'] + offset1

        data.update({
            "mkpts0_f": mkpts0_f,
            "mkpts1_f": mkpts1_f
        })