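# Scraped page header (not part of the module):
#   author: Vincentqyw
#   commit: f269db9 -- "update: limit keypoints number"
#   page links: raw / history blame; file size: 5 kB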
import kornia
from kornia.feature.laf import (
laf_from_center_scale_ori,
extract_patches_from_pyramid,
)
import numpy as np
import torch
import pycolmap
from ..utils.base_model import BaseModel
# Small constant guarding divisions by zero and the square root of zero.
EPS = 1e-6


def sift_to_rootsift(x):
    """Convert SIFT descriptors to RootSIFT descriptors.

    Each descriptor (last axis) is L1-normalized, clipped from below at
    ``EPS``, square-rooted element-wise and finally L2-normalized.

    Args:
        x: array-like of descriptors; normalization runs over the last axis.

    Returns:
        A NumPy array with the same shape as ``x``.
    """
    # Accept any array-like; an ndarray input is passed through unchanged.
    x = np.asarray(x)
    x = x / (np.linalg.norm(x, ord=1, axis=-1, keepdims=True) + EPS)
    # Clip before the square root so the result stays strictly positive.
    x = np.sqrt(x.clip(min=EPS))
    x = x / (np.linalg.norm(x, axis=-1, keepdims=True) + EPS)
    return x
class DoG(BaseModel):
    """Keypoint extractor built on pycolmap's SIFT implementation.

    Keypoints, scores and SIFT descriptors come from ``pycolmap.Sift``.
    Depending on ``conf["descriptor"]`` the returned descriptors are the
    SIFT ones, their RootSIFT renormalization, or SOSNet / HardNet
    descriptors computed with kornia on patches around each keypoint.
    """

    default_conf = {
        # Forwarded to pycolmap.SiftExtractionOptions.
        "options": {
            "first_octave": 0,
            "peak_threshold": 0.01,
        },
        # One of "sift", "rootsift", "sosnet", "hardnet".
        "descriptor": "rootsift",
        # -1 keeps every detection; otherwise only the top-scoring ones.
        "max_keypoints": -1,
        # Side length (PS) of the patches fed to the learned descriptors.
        "patch_size": 32,
        # Multiplier applied to the keypoint scale when building the LAFs.
        "mr_size": 12,
    }
    required_inputs = ["image"]
    detection_noise = 1.0
    # Maximum number of patches passed to the learned descriptor per call.
    max_batch_size = 1024

    def _init(self, conf):
        """Create the learned descriptor, if one is requested.

        Raises:
            ValueError: if ``conf["descriptor"]`` is not a known name.
        """
        if conf["descriptor"] == "sosnet":
            self.describe = kornia.feature.SOSNet(pretrained=True)
        elif conf["descriptor"] == "hardnet":
            self.describe = kornia.feature.HardNet(pretrained=True)
        elif conf["descriptor"] not in ["sift", "rootsift"]:
            raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')

        self.sift = None  # lazily instantiated on the first image
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        """Move the module and remember the requested device.

        The device is taken from the ``device`` keyword or, failing that,
        from the first positional argument that is a ``torch.device`` or a
        string. It is later used to choose the device of the SIFT extractor.
        """
        device = kwargs.get("device")
        if device is None:
            match = [a for a in args if isinstance(a, (torch.device, str))]
            if match:
                device = match[0]
        if device is not None:
            self.device = torch.device(device)
        return super().to(*args, **kwargs)

    def _describe_patches(self, image, keypoints, scales, oris):
        """Compute learned (SOSNet / HardNet) descriptors for the keypoints.

        Args:
            image: the input image tensor, as passed to ``_forward``.
            keypoints: NumPy keypoint array returned by the SIFT extractor.
            scales: NumPy array of keypoint scales.
            oris: NumPy array of keypoint orientations in degrees.

        Returns:
            A tensor with one 128-dimensional row per extracted patch,
            allocated like the patches tensor.
        """
        # NOTE(review): the 0.5 offset and the sign flip of the orientation
        # presumably bridge COLMAP's and kornia's conventions -- confirm.
        center = keypoints[:, :2] + 0.5
        laf_scale = scales * self.conf["mr_size"] / 2
        laf_ori = -oris
        lafs = laf_from_center_scale_ori(
            torch.from_numpy(center)[None],
            torch.from_numpy(laf_scale)[None, :, None, None],
            torch.from_numpy(laf_ori)[None, :, None],
        ).to(image.device)
        patches = extract_patches_from_pyramid(
            image, lafs, PS=self.conf["patch_size"]
        )[0]

        descriptors = patches.new_zeros((len(patches), 128))
        # Describe in chunks to bound memory; the loop is empty when there
        # are no patches, so no separate emptiness check is needed.
        for start_idx in range(0, len(patches), self.max_batch_size):
            end_idx = min(len(patches), start_idx + self.max_batch_size)
            descriptors[start_idx:end_idx] = self.describe(
                patches[start_idx:end_idx]
            )
        return descriptors

    def _forward(self, data):
        """Detect and describe keypoints in ``data["image"]``.

        Only the first image of the batch and its single channel are
        processed; pixel values must lie in ``[0, 1]`` (up to ``EPS``).

        Returns:
            A dict with the keys ``keypoints``, ``scales``, ``oris``,
            ``scores`` and ``descriptors``. Every value has a leading
            dimension of size one prepended; the descriptors are transposed
            before that.

        Raises:
            AssertionError: if the image is not single-channel or its
                values fall outside ``[0, 1]``.
            ValueError: if the configured descriptor is unknown.
        """
        image = data["image"]
        assert image.shape[1] == 1
        image_np = image.cpu().numpy()[0, 0]
        assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS

        descriptor = self.conf["descriptor"]

        if self.sift is None:
            # Built on first use so that the device recorded by `to` is known.
            use_gpu = pycolmap.has_cuda and self.device.type == "cuda"
            options = {**self.conf["options"]}
            if descriptor == "rootsift":
                options["normalization"] = pycolmap.Normalization.L1_ROOT
            else:
                options["normalization"] = pycolmap.Normalization.L2
            self.sift = pycolmap.Sift(
                options=pycolmap.SiftExtractionOptions(options),
                device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"),
            )

        keypoints, scores, descriptors = self.sift.extract(image_np)
        scales = keypoints[:, 2]
        oris = np.rad2deg(keypoints[:, 3])

        if descriptor in ["sift", "rootsift"]:
            # We still renormalize because COLMAP does not normalize well,
            # maybe due to numerical errors
            if descriptor == "rootsift":
                descriptors = sift_to_rootsift(descriptors)
            descriptors = torch.from_numpy(descriptors)
        elif descriptor in ("sosnet", "hardnet"):
            descriptors = self._describe_patches(image, keypoints, scales, oris)
        else:
            raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}')

        keypoints = torch.from_numpy(keypoints[:, :2])  # keep only x, y
        scales = torch.from_numpy(scales)
        oris = torch.from_numpy(oris)
        scores = torch.from_numpy(scores)

        max_keypoints = self.conf["max_keypoints"]
        if max_keypoints != -1:
            # TODO: check that the scores from PyCOLMAP are 100% correct,
            # follow https://github.com/mihaidusmanu/pycolmap/issues/8
            # Clamp k: torch.topk fails when asked for more elements than
            # are available.
            k = min(scores.shape[0], max_keypoints)
            _, indices = torch.topk(scores, k)
            keypoints = keypoints[indices]
            scales = scales[indices]
            oris = oris[indices]
            scores = scores[indices]
            descriptors = descriptors[indices]

        return {
            "keypoints": keypoints[None],
            "scales": scales[None],
            "oris": oris[None],
            "scores": scores[None],
            "descriptors": descriptors.T[None],
        }