Spaces:
Running
Running
File size: 6,714 Bytes
4ecd006 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import contextlib
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.geometry.subpix import dsnt
from kornia.utils.grid import create_meshgrid
from loguru import logger
class FineMatching(nn.Module):
    """FineMatching with s2d paradigm.

    Two stages per coarse match:
      1. ``get_fine_ds_match``: dual-softmax correlation between the W x W
         window of image 0 and the inner W x W part of image 1's
         (W + 2) x (W + 2) window; the arg-max pixel pair moves both coarse
         keypoints.
      2. ``get_fine_match_local``: a 3 x 3 similarity heatmap from a separate
         slice of channels is softmaxed and its spatial expectation adds a
         sub-pixel offset to the image-1 keypoint.
    """

    def __init__(self, config):
        """Read fine-matching hyper-parameters from ``config``.

        Args:
            config (dict): needs ``config['match_fine']['local_regress_temperature']``
                (softmax temperature of the stage-2 heatmap),
                ``config['match_fine']['local_regress_slicedim']`` (number of
                trailing feature channels used for stage 2) and
                ``config['half']`` (build the stage-1 offset grid in float16).
        """
        super().__init__()
        self.config = config
        self.local_regress_temperature = config['match_fine']['local_regress_temperature']
        self.local_regress_slicedim = config['match_fine']['local_regress_slicedim']
        self.fp16 = config['half']

    def forward(self, feat_0, feat_1, data):
        """Refine coarse matches in ``data`` to sub-pixel keypoints.

        Args:
            feat_0 (torch.Tensor): [M, WW, C] window features of image 0,
                WW = W * W.
            feat_1 (torch.Tensor): [M, (W + 2) ** 2, C] window features of
                image 1 (the reshapes to (W + 2, W + 2) below require this).
            data (dict): reads 'hw0_i', 'hw0_f', 'mkpts0_c', 'mkpts1_c' and,
                when M > 0, 'mconf'; optionally 'scale0' / 'scale1' (indexed
                with 'b_ids'; 'bs' is read when 'scale1' is present).

        Update (n = len(data['mconf'])):
            data (dict): {
                'conf_matrix_f' (torch.Tensor): [M, WW, WW] (when M == 0, or
                    while training),
                'sim_matrix_ff' (torch.Tensor): [M, WW, (W + 2) ** 2]
                    (training only),
                'idx_l', 'idx_r' (torch.Tensor): [n, 1] stage-1 indices,
                'mkpts0_c', 'mkpts1_c' (torch.Tensor): [n, 2] overwritten
                    with the stage-1 keypoints,
                'mkpts0_f', 'mkpts1_f' (torch.Tensor): [n, 2] final keypoints}
        """
        M, WW, C = feat_0.shape
        W = int(math.sqrt(WW))
        scale = data['hw0_i'][0] / data['hw0_f'][0]
        self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale

        # corner case: if no coarse matches found
        if M == 0:
            assert not self.training, "M is always > 0 while training, see coarse_matching.py"
            data.update({
                'conf_matrix_f': torch.empty(0, WW, WW, device=feat_0.device),
                'mkpts0_f': data['mkpts0_c'],
                'mkpts1_f': data['mkpts1_c'],
            })
            return

        # compute pixel-level confidence matrix
        slice_dim = self.local_regress_slicedim
        # A 'cuda' autocast only affects CUDA tensors; constructing it on a
        # machine without CUDA merely warns and disables itself, so skip it.
        amp_ctx = (torch.autocast(device_type='cuda', enabled=True)
                   if feat_0.is_cuda else contextlib.nullcontext())
        with amp_ctx:
            # Leading channels -> stage-1 matching; trailing `slice_dim`
            # channels -> stage-2 heatmap.
            feat_f0, feat_f1 = feat_0[..., :-slice_dim], feat_1[..., :-slice_dim]
            feat_ff0, feat_ff1 = feat_0[..., -slice_dim:], feat_1[..., -slice_dim:]
            # Stage-1 features are normalised by the full channel count C.
            feat_f0, feat_f1 = feat_f0 / C**.5, feat_f1 / C**.5
            conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1)
            conf_matrix_ff = torch.einsum('mlc,mrc->mlr', feat_ff0, feat_ff1 / slice_dim**.5)

        # Dual softmax, then drop the 1-pixel border of image 1's window so
        # both sides are W x W.
        softmax_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2)
        softmax_matrix_f = softmax_matrix_f.reshape(M, WW, W + 2, W + 2)
        softmax_matrix_f = softmax_matrix_f[..., 1:-1, 1:-1].reshape(M, WW, WW)

        # for fine-level supervision
        if self.training:
            data.update({'sim_matrix_ff': conf_matrix_ff})
            data.update({'conf_matrix_f': softmax_matrix_f})

        # stage 1: pixel-level absolute kpt coords (also sets idx_l / idx_r)
        self.get_fine_ds_match(softmax_matrix_f, data)

        n_valid = len(data['mconf'])
        idx_l, idx_r = data['idx_l'], data['idx_r']
        if idx_l.numel() == 0:
            data.update({
                'mkpts0_f': data['mkpts0_c'],
                'mkpts1_f': data['mkpts1_c'],
            })
            return

        # generate second-stage 3x3 grid around each stage-1 image-1 pixel
        m_ids = torch.arange(M, device=idx_l.device, dtype=torch.long)[:n_valid]
        idx_r_iids, idx_r_jids = (idx_r // W).reshape(-1), (idx_r % W).reshape(-1)
        idx_l = idx_l.reshape(-1)
        delta = create_meshgrid(3, 3, True, conf_matrix_ff.device).to(torch.long)[0]  # [3, 3, 2]

        m_ids = m_ids[:, None, None].expand(-1, 3, 3)  # [n, 3, 3]
        idx_l = idx_l[:, None, None].expand(-1, 3, 3)
        # NOTE(review): idx_r_* index the cropped W x W grid, whose cell (i, j)
        # is cell (i + 1, j + 1) of the padded (W + 2) grid indexed below, so
        # this window is centred on padded (i, j), and i or j == 0 gives index
        # -1 (wraps to the last row/column). Kept as-is: the stage-2 head may
        # have been trained with this indexing -- confirm before changing.
        idx_r_iids = idx_r_iids[:, None, None] + delta[..., 1]
        idx_r_jids = idx_r_jids[:, None, None] + delta[..., 0]

        # compute second-stage heatmap
        conf_matrix_ff = conf_matrix_ff.reshape(M, WW, W + 2, W + 2)
        conf_matrix_ff = conf_matrix_ff[m_ids, idx_l, idx_r_iids, idx_r_jids].reshape(-1, 9)
        conf_matrix_ff = F.softmax(conf_matrix_ff / self.local_regress_temperature, -1)
        heatmap = conf_matrix_ff.reshape(-1, 3, 3)

        # compute coordinates from heatmap: [n, 2], normalised to [-1, 1]
        coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]

        # Image-1 scale factor; guarded by 'scale1' (previously checked
        # 'scale0', which raised KeyError or ignored a lone 'scale1').
        if 'scale1' not in data:
            scale1 = scale
        elif data['bs'] == 1:
            scale1 = scale * data['scale1']
        else:
            scale1 = (scale * data['scale1'][data['b_ids']][:n_valid]).expand(-1, 2)

        # compute subpixel-level absolute kpt coords
        self.get_fine_match_local(coords_normalized, data, scale1)

    def get_fine_match_local(self, coords_normed, data, scale1):
        """Stage 2: add the regressed sub-pixel offset to image-1 keypoints.

        Args:
            coords_normed (torch.Tensor): [n, 2] offsets normalised over a
                3 x 3 window.
            data (dict): reads 'mkpts0_c' / 'mkpts1_c' (in ``forward`` these
                are the stage-1 keypoints written by ``get_fine_ds_match``).
            scale1: Python number or tensor broadcastable to [n, 2] that
                converts window cells into image-1 pixels.

        Update:
            data (dict): {'mkpts0_f': data['mkpts0_c'],
                          'mkpts1_f': mkpts1_c + offset}
        """
        # Normalised +-1 maps to the 3 x 3 window radius (3 // 2 = 1 cell).
        radius = 3 // 2
        data.update({
            "mkpts0_f": data['mkpts0_c'],
            "mkpts1_f": data['mkpts1_c'] + (coords_normed * radius * scale1),
        })

    @staticmethod
    def _scale_offsets(delta, scale, n_valid):
        """Scale [n, 1, 2] window offsets ``delta`` into image pixels.

        ``scale`` is a Python number / single-element tensor, or a per-match
        tensor of which only the first ``n_valid`` rows are used.
        """
        if torch.is_tensor(scale) and scale.numel() > 1:
            return delta * scale[:n_valid][:, None, :]
        return delta * scale

    @torch.no_grad()
    def get_fine_ds_match(self, conf_matrix, data):
        """Stage 1: move coarse keypoints to the best pixel pair per match.

        Args:
            conf_matrix (torch.Tensor): [M, WW, WW] dual-softmax scores.
            data (dict): reads 'mconf', 'mkpts0_c', 'mkpts1_c' and, when
                present, 'scale0' / 'scale1' together with 'b_ids'.

        Update (n = len(data['mconf'])):
            data (dict): {
                'idx_l', 'idx_r' (torch.Tensor): [n, 1] flat pixel indices
                    into the W x W windows of image 0 / image 1,
                'mkpts0_c', 'mkpts1_c' (torch.Tensor): [n, 2] overwritten with
                    coarse keypoints + scaled grid offsets}
        """
        W, WW, scale = self.W, self.WW, self.scale
        n_valid = len(data['mconf'])
        m = conf_matrix.shape[0]

        # Joint arg-max over all (left, right) pixel pairs of each match.
        idx = conf_matrix.reshape(m, -1)[:n_valid].max(dim=-1).indices[:, None]
        idx_l, idx_r = idx // WW, idx % WW
        data.update({'idx_l': idx_l, 'idx_r': idx_r})

        # Per-pixel offsets j - W // 2 + 0.5 along each axis of the window.
        if self.fp16:
            grid = create_meshgrid(W, W, False, conf_matrix.device, dtype=torch.float16)  # kornia >= 0.5.1
        else:
            grid = create_meshgrid(W, W, False, conf_matrix.device)
        grid = (grid - W // 2 + 0.5).reshape(1, -1, 2).expand(m, -1, -1)
        delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2))  # [n, 1, 2]
        delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2))

        # Each image has its own optional scale factor.
        scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
        scale1 = scale * data['scale1'][data['b_ids']] if 'scale1' in data else scale

        mkpts0_f = data['mkpts0_c'][:, None, :] + self._scale_offsets(delta_l, scale0, n_valid)
        mkpts1_f = data['mkpts1_c'][:, None, :] + self._scale_offsets(delta_r, scale1, n_valid)
        data.update({
            "mkpts0_c": mkpts0_f.reshape(-1, 2),
            "mkpts1_c": mkpts1_f.reshape(-1, 2),
        })