|
import torch |
|
from torch import nn |
|
|
|
|
|
class NN2(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, data): |
|
desc1, desc2 = data["descriptors0"].cuda(), data["descriptors1"].cuda() |
|
kpts1, kpts2 = data["keypoints0"].cuda(), data["keypoints1"].cuda() |
|
|
|
|
|
|
|
|
|
if kpts1.shape[1] <= 1 or kpts2.shape[1] <= 1: |
|
shape0, shape1 = kpts1.shape[:-1], kpts2.shape[:-1] |
|
return { |
|
"matches0": kpts1.new_full(shape0, -1, dtype=torch.int), |
|
"matches1": kpts2.new_full(shape1, -1, dtype=torch.int), |
|
"matching_scores0": kpts1.new_zeros(shape0), |
|
"matching_scores1": kpts2.new_zeros(shape1), |
|
} |
|
|
|
sim = torch.matmul(desc1.squeeze().T, desc2.squeeze()) |
|
ids1 = torch.arange(0, sim.shape[0], device=desc1.device) |
|
nn12 = torch.argmax(sim, dim=1) |
|
|
|
nn21 = torch.argmax(sim, dim=0) |
|
mask = torch.eq(ids1, nn21[nn12]) |
|
matches = torch.stack( |
|
[torch.masked_select(ids1, mask), torch.masked_select(nn12, mask)] |
|
) |
|
|
|
indices0 = torch.ones((1, desc1.shape[-1]), dtype=int) * -1 |
|
mscores0 = torch.ones((1, desc1.shape[-1]), dtype=float) * -1 |
|
|
|
|
|
|
|
|
|
matches_0 = matches[0].cpu().int().numpy() |
|
matches_1 = matches[1].cpu().int() |
|
for i in range(matches.shape[-1]): |
|
indices0[0, matches_0[i]] = matches_1[i].int() |
|
mscores0[0, matches_0[i]] = sim[matches_0[i], matches_1[i]] |
|
|
|
return { |
|
"matches0": indices0, |
|
"matches1": indices0, |
|
"matching_scores0": mscores0, |
|
"matching_scores1": mscores0, |
|
} |
|
|