File size: 2,449 Bytes
4bde5d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import torch
from ..utils.base_model import BaseModel
def find_nn(sim, ratio_thresh, distance_thresh):
    """Find the best match along the last dimension of a similarity tensor.

    Similarities are converted to distances with ``2 * (1 - sim)``.
    NOTE(review): this presumably assumes ``sim`` holds dot products of
    L2-normalised descriptors, in which case that quantity is the squared
    Euclidean distance -- confirm against the callers.

    Args:
        sim: similarity tensor; candidates are searched along ``dim=-1``.
        ratio_thresh: if truthy, keep a match only when the best distance is
            at most ``ratio_thresh**2`` times the second-best distance. The
            test is skipped when fewer than two candidates exist.
        distance_thresh: if truthy, keep a match only when the best distance
            is at most ``distance_thresh**2``.

    Returns:
        A ``(matches, scores)`` tuple, both of shape ``sim.shape[:-1]``.
        ``matches`` holds the index of the best candidate, or ``-1`` where
        the match was rejected. ``scores`` holds ``(sim + 1) / 2`` of the
        best candidate, or ``0`` where the match was rejected.
    """
    num_candidates = sim.shape[-1]
    if num_candidates == 0:
        # topk cannot select from an empty dimension: report "no match".
        matches = torch.full(
            sim.shape[:-1], -1, dtype=torch.long, device=sim.device
        )
        scores = sim.new_zeros(sim.shape[:-1])
        return matches, scores

    # The ratio test needs a second-best neighbour to compare against.
    use_ratio = bool(ratio_thresh) and num_candidates >= 2
    sim_nn, ind_nn = sim.topk(2 if use_ratio else 1, dim=-1, largest=True)
    dist_nn = 2 * (1 - sim_nn)
    mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device)
    if use_ratio:
        # Thresholds are squared because dist_nn is compared un-rooted.
        mask = mask & (dist_nn[..., 0] <= (ratio_thresh**2) * dist_nn[..., 1])
    if distance_thresh:
        mask = mask & (dist_nn[..., 0] <= distance_thresh**2)
    matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1))
    scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0))
    return matches, scores
def mutual_check(m0, m1):
inds0 = torch.arange(m0.shape[-1], device=m0.device)
loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0)))
ok = (m0 > -1) & (inds0 == loop)
m0_new = torch.where(ok, m0, m0.new_tensor(-1))
return m0_new
class NearestNeighbor(BaseModel):
    """Match two descriptor sets by nearest-neighbour search.

    Similarities are the batched dot products of ``descriptors0`` and
    ``descriptors1``, both laid out as (batch, dim, num_descriptors) as
    fixed by the einsum ``bdn,bdm->bnm`` in ``_forward``.

    Configuration keys (read from ``self.conf``):
        ratio_threshold: optional ratio-test threshold passed to ``find_nn``.
        distance_threshold: optional distance threshold passed to ``find_nn``.
        do_mutual_check: if true, keep only mutually consistent matches.

    NOTE(review): ``self.conf`` is presumably populated from ``default_conf``
    by ``BaseModel`` -- confirm in the base class.
    """

    default_conf = {
        "ratio_threshold": None,
        "distance_threshold": None,
        "do_mutual_check": True,
    }
    required_inputs = ["descriptors0", "descriptors1"]

    def _init(self, conf):
        # Nothing to set up: the matcher has no parameters or state.
        pass

    def _forward(self, data):
        """Compute matches from ``descriptors0`` to ``descriptors1``.

        Args:
            data: mapping with tensors ``"descriptors0"`` (batch, dim, n)
                and ``"descriptors1"`` (batch, dim, m).

        Returns:
            A dict with ``"matches0"``, a (batch, n) index tensor holding the
            matched index in ``descriptors1`` or ``-1``, and
            ``"matching_scores0"``, a (batch, n) tensor of scores.
        """
        desc0 = data["descriptors0"]
        desc1 = data["descriptors1"]
        num0 = desc0.size(-1)
        num1 = desc1.size(-1)

        if num0 == 0 or num1 == 0:
            # Nothing can be matched. Use the same (batch, n) shape and the
            # same dtypes as the regular path so callers see one layout.
            matches0 = torch.full(
                (desc0.shape[0], num0),
                -1,
                dtype=torch.long,
                device=desc0.device,
            )
            return {
                "matches0": matches0,
                "matching_scores0": desc0.new_zeros(matches0.shape),
            }

        ratio_threshold = self.conf["ratio_threshold"]
        if num0 == 1 or num1 == 1:
            # The ratio test needs a second neighbour in both directions.
            ratio_threshold = None

        sim = torch.einsum("bdn,bdm->bnm", desc0, desc1)
        matches0, scores0 = find_nn(
            sim, ratio_threshold, self.conf["distance_threshold"]
        )
        if self.conf["do_mutual_check"]:
            matches1, scores1 = find_nn(
                sim.transpose(1, 2),
                ratio_threshold,
                self.conf["distance_threshold"],
            )
            matches0 = mutual_check(matches0, matches1)
        # NOTE(review): scores0 is returned unchanged, so entries rejected by
        # the mutual check keep their original score.
        return {
            "matches0": matches0,
            "matching_scores0": scores0,
        }
|