|
import torch |
|
import torch.nn as nn |
|
import kornia |
|
from types import SimpleNamespace |
|
from .utils import ImagePreprocessor |
|
|
|
|
|
class DISK(nn.Module): |
|
default_conf = { |
|
'weights': 'depth', |
|
'max_num_keypoints': None, |
|
'desc_dim': 128, |
|
'nms_window_size': 5, |
|
'detection_threshold': 0.0, |
|
'pad_if_not_divisible': True, |
|
} |
|
|
|
preprocess_conf = { |
|
**ImagePreprocessor.default_conf, |
|
'resize': 1024, |
|
'grayscale': False, |
|
} |
|
|
|
required_data_keys = ['image'] |
|
|
|
def __init__(self, **conf) -> None: |
|
super().__init__() |
|
self.conf = {**self.default_conf, **conf} |
|
self.conf = SimpleNamespace(**self.conf) |
|
self.model = kornia.feature.DISK.from_pretrained(self.conf.weights) |
|
|
|
def forward(self, data: dict) -> dict: |
|
""" Compute keypoints, scores, descriptors for image """ |
|
for key in self.required_data_keys: |
|
assert key in data, f'Missing key {key} in data' |
|
image = data['image'] |
|
features = self.model( |
|
image, |
|
n=self.conf.max_num_keypoints, |
|
window_size=self.conf.nms_window_size, |
|
score_threshold=self.conf.detection_threshold, |
|
pad_if_not_divisible=self.conf.pad_if_not_divisible |
|
) |
|
keypoints = [f.keypoints for f in features] |
|
scores = [f.detection_scores for f in features] |
|
descriptors = [f.descriptors for f in features] |
|
del features |
|
|
|
keypoints = torch.stack(keypoints, 0) |
|
scores = torch.stack(scores, 0) |
|
descriptors = torch.stack(descriptors, 0) |
|
|
|
return { |
|
'keypoints': keypoints.to(image), |
|
'keypoint_scores': scores.to(image), |
|
'descriptors': descriptors.to(image), |
|
} |
|
|
|
def extract(self, img: torch.Tensor, **conf) -> dict: |
|
""" Perform extraction with online resizing""" |
|
if img.dim() == 3: |
|
img = img[None] |
|
assert img.dim() == 4 and img.shape[0] == 1 |
|
shape = img.shape[-2:][::-1] |
|
img, scales = ImagePreprocessor( |
|
**{**self.preprocess_conf, **conf})(img) |
|
feats = self.forward({'image': img}) |
|
feats['image_size'] = torch.tensor(shape)[None].to(img).float() |
|
feats['keypoints'] = (feats['keypoints'] + .5) / scales[None] - .5 |
|
return feats |
|
|