|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops.einops import rearrange |
|
|
|
# Large finite constant, presumably a stand-in for infinity.
# NOTE(review): not referenced by any code visible in this file chunk.
INF = 1e9
|
|
|
|
|
def mask_border(m, b: int, v):
    """Overwrite, in place, a border of width ``b`` on axes 1-4 of ``m``.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1]; modified in place.
        b (int): border width; nothing happens when ``b <= 0``.
        v (m.dtype): value written into the border cells.
    """
    if b <= 0:
        return

    leading = slice(None, b)
    trailing = slice(-b, None)
    # All leading borders are written first (axes 1..4), then all trailing
    # borders, which is the same order as writing each slice out by hand.
    for border in (leading, trailing):
        for axis in range(1, 5):
            index = (slice(None),) * axis + (border,)
            m[index] = v
|
|
|
|
|
def mask_border_with_padding(m, bd, v, p_m0, p_m1):
    """Overwrite, in place, a border of width ``bd`` on ``m``, per-sample aware.

    The leading border (first ``bd`` indices of axes 1-4) is written for the
    whole batch at once. The trailing border is written per sample, starting
    ``bd`` cells before that sample's extent as derived from the masks
    ``p_m0`` / ``p_m1``, and running to the end of the axis.

    Args:
        m (torch.Tensor): 5-D tensor indexed as [batch, a0, a1, a2, a3];
            modified in place.
            # presumably [N, H0, W0, H1, W1] as in ``mask_border`` - confirm
        bd (int): border width; nothing happens when ``bd <= 0``.
        v: value written into the border cells.
        p_m0, p_m1 (torch.Tensor): per-sample masks; extents are taken as
            ``sum(1).max(-1)`` and ``sum(-1).max(-1)``.
            # assumes shape [N, H, W] with truthy = valid - TODO confirm
    """
    if bd <= 0:
        return

    m[:, :bd] = v
    m[:, :, :bd] = v
    m[:, :, :, :bd] = v
    m[:, :, :, :, :bd] = v

    # Convert the per-sample extents to Python ints once, instead of slicing
    # with 0-dim tensors inside the loop.
    h0s = p_m0.sum(1).max(-1)[0].int().tolist()
    w0s = p_m0.sum(-1).max(-1)[0].int().tolist()
    h1s = p_m1.sum(1).max(-1)[0].int().tolist()
    w1s = p_m1.sum(-1).max(-1)[0].int().tolist()

    for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
        # Clamp at 0: when the extent is smaller than ``bd`` the start would
        # be negative, which Python interprets as "from the end" and would
        # leave most of the axis unmasked.
        m[b_idx, max(h0 - bd, 0):] = v
        m[b_idx, :, max(w0 - bd, 0):] = v
        m[b_idx, :, :, max(h1 - bd, 0):] = v
        m[b_idx, :, :, :, max(w1 - bd, 0):] = v
|
|
|
|
|
def compute_max_candidates(p_m0, p_m1):
    """Sum, over the batch, the smaller per-sample area of the two masks.

    For each mask the per-sample extents are ``sum(1).max(-1)`` and
    ``sum(-1).max(-1)``; their product is that sample's area. The smaller of
    the two areas is taken per sample and the results are summed.

    Args:
        p_m0, p_m1 (torch.Tensor): padded masks

    Returns:
        torch.Tensor: 0-dim tensor holding the summed candidate count.
    """
    area0 = p_m0.sum(1).max(-1)[0] * p_m0.sum(-1).max(-1)[0]
    area1 = p_m1.sum(1).max(-1)[0] * p_m1.sum(-1).max(-1)[0]

    paired = torch.stack([area0, area1], -1)
    smaller = torch.min(paired, -1).values
    return torch.sum(smaller)
|
|
|
|
|
class CoarseMatching(nn.Module):
    """Select coarse-level matches from a confidence matrix.

    ``forward`` reads ``data["conf_matrix"]`` and writes the selected match
    indices, coarse keypoint coordinates and confidences back into ``data``.
    """

    def __init__(self, config):
        """Store matching thresholds and training-time sampling settings.

        Args:
            config (dict): must provide 'thr', 'border_rm',
                'train_coarse_percent', 'train_pad_num_gt_min' and
                'match_type'. For 'dual_softmax' it must also provide
                'dsmax_temperature'; for 'sinkhorn' it must also provide
                'skh_init_bin_score', 'skh_iters' and 'skh_prefilter'.

        Raises:
            ImportError: 'sinkhorn' was requested but ``.superglue`` cannot
                be imported.
            NotImplementedError: ``config['match_type']`` is neither
                'dual_softmax' nor 'sinkhorn'.
        """
        super().__init__()
        self.config = config

        # Confidence threshold and border width used by get_coarse_match.
        self.thr = config["thr"]
        self.border_rm = config["border_rm"]

        # Training-time sampling / padding settings.
        self.train_coarse_percent = config["train_coarse_percent"]
        self.train_pad_num_gt_min = config["train_pad_num_gt_min"]

        # Matcher-specific options; they are stored here but not read by the
        # methods visible in this class.
        self.match_type = config["match_type"]
        if self.match_type == "dual_softmax":
            self.temperature = config["dsmax_temperature"]
        elif self.match_type == "sinkhorn":
            try:
                from .superglue import log_optimal_transport
            except ImportError as err:
                # Chain explicitly so the underlying import failure is kept.
                raise ImportError("download superglue.py first!") from err
            self.log_optimal_transport = log_optimal_transport
            self.bin_score = nn.Parameter(
                torch.tensor(config["skh_init_bin_score"], requires_grad=True)
            )
            self.skh_iters = config["skh_iters"]
            self.skh_prefilter = config["skh_prefilter"]
        else:
            raise NotImplementedError(
                "unsupported match_type: {!r}".format(self.match_type)
            )

    def forward(self, data):
        """Run coarse match selection and merge the result into ``data``.

        Args:
            data (dict): must contain 'conf_matrix' plus the keys required
                by ``get_coarse_match``.
        Update:
            data (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
            NOTE: M' != M during training.
        """
        conf_matrix = data["conf_matrix"]

        data.update(**self.get_coarse_match(conf_matrix, data))

    @torch.no_grad()
    def get_coarse_match(self, conf_matrix, data):
        """Pick mutually-best, above-threshold entries of ``conf_matrix``.

        Args:
            conf_matrix (torch.Tensor): [N, L, S]
            data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'];
                optionally 'mask0'/'mask1' (padding masks) and
                'scale0'/'scale1'. In training mode 'spv_b_ids',
                'spv_i_ids' and 'spv_j_ids' are also read.
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
        """
        axes_lengths = {
            "h0c": data["hw0_c"][0],
            "w0c": data["hw0_c"][1],
            "h1c": data["hw1_c"][0],
            "w1c": data["hw1_c"][1],
        }
        _device = conf_matrix.device

        # 1. Threshold, then clear border cells on the 5-D (grid) view.
        mask = conf_matrix > self.thr
        mask = rearrange(
            mask, "b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c", **axes_lengths
        )
        if "mask0" not in data:
            mask_border(mask, self.border_rm, False)
        else:
            mask_border_with_padding(
                mask, self.border_rm, False, data["mask0"], data["mask1"]
            )
        mask = rearrange(
            mask, "b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)", **axes_lengths
        )

        # 2. Keep only entries that are the maximum of both their row
        #    (dim 2) and their column (dim 1).
        mask = (
            mask
            * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0])
            * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
        )

        # 3. For each (batch, i) that still has a surviving entry, take the
        #    index j of that entry and its confidence.
        mask_v, all_j_ids = mask.max(dim=2)
        b_ids, i_ids = torch.where(mask_v)
        j_ids = all_j_ids[b_ids, i_ids]
        mconf = conf_matrix[b_ids, i_ids, j_ids]

        # 4. Training only: sample the predictions and pad them with
        #    supervision entries so the total count is controlled.
        if self.training:
            if "mask0" not in data:
                num_candidates_max = mask.size(0) * max(mask.size(1), mask.size(2))
            else:
                num_candidates_max = compute_max_candidates(
                    data["mask0"], data["mask1"]
                )
            num_matches_train = int(num_candidates_max * self.train_coarse_percent)
            num_matches_pred = len(b_ids)
            assert (
                self.train_pad_num_gt_min < num_matches_train
            ), "min-num-gt-pad should be less than num-train-matches"

            # Keep every prediction when there is room; otherwise draw
            # (with replacement) a fixed number of them.
            if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
                pred_indices = torch.arange(num_matches_pred, device=_device)
            else:
                pred_indices = torch.randint(
                    num_matches_pred,
                    (num_matches_train - self.train_pad_num_gt_min,),
                    device=_device,
                )

            # Draw the padding entries from the supervision ids; they get a
            # confidence of exactly 0, which is how they are told apart
            # from predictions further down.
            gt_pad_indices = torch.randint(
                len(data["spv_b_ids"]),
                (max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min),),
                device=_device,
            )
            mconf_gt = torch.zeros(len(data["spv_b_ids"]), device=_device)

            def _select_and_pad(pred, spv):
                # Sampled predictions first, supervision padding after.
                return torch.cat([pred[pred_indices], spv[gt_pad_indices]], dim=0)

            b_ids = _select_and_pad(b_ids, data["spv_b_ids"])
            i_ids = _select_and_pad(i_ids, data["spv_i_ids"])
            j_ids = _select_and_pad(j_ids, data["spv_j_ids"])
            mconf = _select_and_pad(mconf, mconf_gt)

        # These ids cover predictions and, in training, the padding entries.
        coarse_matches = {"b_ids": b_ids, "i_ids": i_ids, "j_ids": j_ids}

        # 5. Convert flat indices to (column, row) coordinates and scale them.
        #    NOTE: the ratio is computed from image 0 only and reused for
        #    image 1.
        scale = data["hw0_i"][0] / data["hw0_c"][0]
        scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
        scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
        mkpts0_c = (
            torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1)
            * scale0
        )
        mkpts1_c = (
            torch.stack([j_ids % data["hw1_c"][1], j_ids // data["hw1_c"][1]], dim=1)
            * scale1
        )

        # 6. Entries with non-zero confidence are the real predictions;
        #    compute that selector once and reuse it.
        is_pred = mconf != 0
        coarse_matches.update(
            {
                "gt_mask": mconf == 0,
                "m_bids": b_ids[is_pred],
                "mkpts0_c": mkpts0_c[is_pred],
                "mkpts1_c": mkpts1_c[is_pred],
                "mconf": mconf[is_pred],
            }
        )

        return coarse_matches
|
|