|
import sys
from pathlib import Path

import torch

from ..utils.base_model import BaseModel

# The SuperGlue reference implementation is not an installed package: its
# directory has to be put on sys.path before the import below can resolve,
# which is why that import comes after this statement.
sys.path.append(str(Path(__file__).parent / "../../third_party"))

from SuperGluePretrainedNetwork.models import superpoint
|
|
|
|
|
|
|
|
|
def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
    """Interpolate descriptors at keypoint locations.

    Args:
        keypoints: tensor whose last dimension holds ``(x, y)`` coordinates,
            expressed in a frame ``s`` times larger than the descriptor map
            (x is normalized by ``w * s``, y by ``h * s``). It must hold
            ``b * n * 2`` values; any layout (including non-contiguous
            ones) is accepted.
        descriptors: dense descriptor map of shape ``(b, c, h, w)``.
        s: scale factor between the keypoint frame and the descriptor map.

    Returns:
        Tensor of shape ``(b, c, n)`` with one L2-normalized (over the
        channel dimension) descriptor per keypoint.
    """
    b, c, h, w = descriptors.shape

    # Shift by half a pixel, then normalize by the full (w * s, h * s)
    # extent so coordinates land in [0, 1] before mapping to the [-1, 1]
    # range that grid_sample expects with align_corners=False.
    extent = keypoints.new_tensor([w, h]) * s
    grid = (keypoints + 0.5) / extent
    grid = grid * 2 - 1

    # reshape (not view): view raises when the keypoint tensor's memory
    # layout cannot be flattened without a copy.
    sampled = torch.nn.functional.grid_sample(
        descriptors, grid.reshape(b, 1, -1, 2), mode="bilinear", align_corners=False
    )

    # Drop the dummy height dimension and normalize each descriptor.
    return torch.nn.functional.normalize(sampled.reshape(b, c, -1), p=2, dim=1)
|
|
|
|
|
class SuperPoint(BaseModel):
    """Thin wrapper exposing the third-party SuperPoint network as a BaseModel.

    The configuration is forwarded unchanged to ``superpoint.SuperPoint``;
    the only key interpreted here is ``fix_sampling``.
    """

    # Defaults for the options read here or forwarded to the wrapped network.
    default_conf = {
        "nms_radius": 4,
        "keypoint_threshold": 0.005,
        "max_keypoints": -1,
        "remove_borders": 4,
        "fix_sampling": False,
    }
    # Keys that must be present in the input data (presumably checked by
    # BaseModel, which is not visible here -- TODO confirm).
    required_inputs = ["image"]
    # Not used inside this class; presumably read by downstream code as the
    # keypoint detection uncertainty -- TODO confirm unit and consumer.
    detection_noise = 2.0

    def _init(self, conf):
        """Build the wrapped network, optionally swapping its descriptor sampler.

        Args:
            conf: configuration mapping; must contain ``"fix_sampling"`` and
                is passed as-is to ``superpoint.SuperPoint``.
        """
        if conf["fix_sampling"]:
            # NOTE(review): this rebinds an attribute of the shared
            # `superpoint` module and is never undone, so it looks like it
            # affects every SuperPoint network in the process, not only
            # this instance -- confirm this is intended.
            superpoint.sample_descriptors = sample_descriptors_fix_sampling
        self.net = superpoint.SuperPoint(conf)

    def _forward(self, data):
        """Run the wrapped network on ``data`` and return its output unchanged."""
        return self.net(data)
|
|