Spaces:
Running
Running
File size: 5,132 Bytes
9223079 e64cfb1 f269db9 e64cfb1 f269db9 9223079 773eef4 9223079 773eef4 9223079 773eef4 9223079 773eef4 9223079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import kornia
import numpy as np
import pycolmap
import torch
from kornia.feature.laf import (
extract_patches_from_pyramid,
laf_from_center_scale_ori,
)
from ..utils.base_model import BaseModel
# Small constant that guards the divisions and the square root below, and is
# reused as a tolerance elsewhere in this module.
EPS = 1e-6


def sift_to_rootsift(x):
    """Convert SIFT descriptors to RootSIFT.

    The descriptors are L1-normalized along the last axis, clipped from
    below at ``EPS``, square-rooted element-wise, and finally L2-normalized
    along the last axis.

    Args:
        x: NumPy array of descriptors; normalization acts on the last axis.

    Returns:
        Array with the same shape as ``x`` holding the RootSIFT descriptors.
    """
    l1_norm = np.linalg.norm(x, ord=1, axis=-1, keepdims=True)
    l1_normalized = x / (l1_norm + EPS)
    # Clip before the square root so entries never fall below EPS.
    rooted = np.sqrt(l1_normalized.clip(min=EPS))
    l2_norm = np.linalg.norm(rooted, axis=-1, keepdims=True)
    return rooted / (l2_norm + EPS)
class DoG(BaseModel):
    """SIFT keypoint extractor backed by pycolmap, with selectable descriptors.

    Keypoints always come from ``pycolmap.Sift``. Depending on
    ``conf["descriptor"]`` the descriptors are either the SIFT descriptors
    returned by pycolmap ("sift"), those descriptors re-normalized with
    ``sift_to_rootsift`` ("rootsift"), or the output of a kornia network
    ("sosnet" / "hardnet") evaluated on patches cut around each keypoint.
    """

    default_conf = {
        # Forwarded to ``pycolmap.SiftExtractionOptions``.
        "options": {
            "first_octave": 0,
            "peak_threshold": 0.01,
        },
        # One of "sift", "rootsift", "sosnet", "hardnet".
        "descriptor": "rootsift",
        # -1 keeps every detected keypoint.
        "max_keypoints": -1,
        # ``PS`` argument used when extracting patches for sosnet/hardnet.
        "patch_size": 32,
        # Keypoint scale is multiplied by ``mr_size / 2`` to get the LAF scale.
        "mr_size": 12,
    }
    required_inputs = ["image"]
    detection_noise = 1.0  # not read inside this class
    max_batch_size = 1024  # number of patches described per network call

    def _init(self, conf):
        """Set up the optional learned descriptor and the lazy SIFT state.

        Args:
            conf: Configuration mapping; only ``conf["descriptor"]`` is read.

        Raises:
            ValueError: If ``conf["descriptor"]`` is not one of the four
                supported names.
        """
        if conf["descriptor"] == "sosnet":
            self.describe = kornia.feature.SOSNet(pretrained=True)
        elif conf["descriptor"] == "hardnet":
            self.describe = kornia.feature.HardNet(pretrained=True)
        elif conf["descriptor"] not in ["sift", "rootsift"]:
            raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')

        self.sift = None  # lazily instantiated on the first image
        # Empty parameter whose ``.device`` reveals where the module lives.
        self.dummy_param = torch.nn.Parameter(torch.empty(0))
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        """Record the requested device, then delegate to the parent ``to``.

        The device is read from the ``device`` keyword if given, otherwise
        from the first positional argument that is a ``torch.device`` or a
        string. ``self.device`` is left unchanged when neither is present.
        """
        device = kwargs.get("device")
        if device is None:
            device = next(
                (a for a in args if isinstance(a, (torch.device, str))),
                None,
            )
        if device is not None:
            self.device = torch.device(device)
        return super().to(*args, **kwargs)

    def _forward(self, data):
        """Detect keypoints and compute descriptors for one image.

        Args:
            data: Mapping with an ``"image"`` tensor that has exactly one
                channel on axis 1 and values within ``[0, 1]`` (both are
                asserted). Only element ``[0, 0]`` of the tensor is
                processed.

        Returns:
            dict with ``"keypoints"`` (x, y), ``"scales"``, ``"oris"``
            (degrees), ``"scores"`` (all zeros) and ``"descriptors"``. Every
            entry has a leading axis of size 1; descriptors are transposed
            so the keypoint axis comes last.

        Raises:
            ValueError: If ``conf["descriptor"]`` is not a supported name.
        """
        image = data["image"]
        # Check the layout before the host copy and indexing that rely on it.
        assert image.shape[1] == 1
        image_np = image.cpu().numpy()[0, 0]
        assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS

        if self.sift is None:
            # NOTE(review): the extractor is created once, on the device the
            # module occupies at the first call; a later move does not
            # rebuild it.
            device = self.dummy_param.device
            use_gpu = pycolmap.has_cuda and device.type == "cuda"
            options = {**self.conf["options"]}
            if self.conf["descriptor"] == "rootsift":
                options["normalization"] = pycolmap.Normalization.L1_ROOT
            else:
                options["normalization"] = pycolmap.Normalization.L2
            self.sift = pycolmap.Sift(
                options=pycolmap.SiftExtractionOptions(options),
                device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"),
            )

        keypoints, descriptors = self.sift.extract(image_np)
        # Columns 2 and 3 hold scale and orientation (radians -> degrees).
        scales = keypoints[:, 2]
        oris = np.rad2deg(keypoints[:, 3])

        if self.conf["descriptor"] in ["sift", "rootsift"]:
            # We still renormalize because COLMAP does not normalize well,
            # maybe due to numerical errors
            if self.conf["descriptor"] == "rootsift":
                descriptors = sift_to_rootsift(descriptors)
            descriptors = torch.from_numpy(descriptors)
        elif self.conf["descriptor"] in ("sosnet", "hardnet"):
            # NOTE(review): the 0.5 offset and the negated orientation
            # presumably adapt COLMAP conventions to kornia's -- confirm.
            center = keypoints[:, :2] + 0.5
            laf_scale = scales * self.conf["mr_size"] / 2
            laf_ori = -oris
            lafs = laf_from_center_scale_ori(
                torch.from_numpy(center)[None],
                torch.from_numpy(laf_scale)[None, :, None, None],
                torch.from_numpy(laf_ori)[None, :, None],
            ).to(image.device)
            patches = extract_patches_from_pyramid(
                image, lafs, PS=self.conf["patch_size"]
            )[0]
            descriptors = patches.new_zeros((len(patches), 128))
            if len(patches) > 0:
                # Describe patches in chunks of at most ``max_batch_size``.
                for start_idx in range(0, len(patches), self.max_batch_size):
                    end_idx = min(len(patches), start_idx + self.max_batch_size)
                    descriptors[start_idx:end_idx] = self.describe(
                        patches[start_idx:end_idx]
                    )
        else:
            raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}')

        keypoints = torch.from_numpy(keypoints[:, :2])  # keep only x, y
        scales = torch.from_numpy(scales)
        oris = torch.from_numpy(oris)
        scores = keypoints.new_zeros(len(keypoints))  # no scores for SIFT yet

        if self.conf["max_keypoints"] != -1:
            # TODO: check that the scores from PyCOLMAP are 100% correct,
            # follow https://github.com/mihaidusmanu/pycolmap/issues/8
            max_number = min(scores.shape[0], self.conf["max_keypoints"])
            _, indices = torch.topk(scores, max_number)
            keypoints = keypoints[indices]
            scales = scales[indices]
            oris = oris[indices]
            scores = scores[indices]
            descriptors = descriptors[indices]

        return {
            "keypoints": keypoints[None],
            "scales": scales[None],
            "oris": oris[None],
            "scores": scores[None],
            "descriptors": descriptors.T[None],
        }
|