import torch
import torch.nn as nn
import kornia
from types import SimpleNamespace
from .utils import ImagePreprocessor


class DISK(nn.Module):
    """DISK keypoint detector and descriptor, wrapped around the kornia implementation."""

    default_conf = {
        'weights': 'depth',
        'max_num_keypoints': None,
        'desc_dim': 128,
        'nms_window_size': 5,
        'detection_threshold': 0.0,
        'pad_if_not_divisible': True,
    }

    preprocess_conf = {
        **ImagePreprocessor.default_conf,
        'resize': 1024,
        'grayscale': False,
    }

    required_data_keys = ['image']

    def __init__(self, **conf) -> None:
        super().__init__()
        self.conf = {**self.default_conf, **conf}
        self.conf = SimpleNamespace(**self.conf)
        self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)

    def forward(self, data: dict) -> dict:
        """Compute keypoints, scores, and descriptors for an image."""
        for key in self.required_data_keys:
            assert key in data, f'Missing key {key} in data'
        image = data['image']
        features = self.model(
            image,
            n=self.conf.max_num_keypoints,
            window_size=self.conf.nms_window_size,
            score_threshold=self.conf.detection_threshold,
            pad_if_not_divisible=self.conf.pad_if_not_divisible
        )
        keypoints = [f.keypoints for f in features]
        scores = [f.detection_scores for f in features]
        descriptors = [f.descriptors for f in features]
        del features

        keypoints = torch.stack(keypoints, 0)
        scores = torch.stack(scores, 0)
        descriptors = torch.stack(descriptors, 0)

        return {
            'keypoints': keypoints.to(image),
            'keypoint_scores': scores.to(image),
            'descriptors': descriptors.to(image),
        }

    def extract(self, img: torch.Tensor, **conf) -> dict:
        """Perform extraction with online resizing."""
        if img.dim() == 3:
            img = img[None]  # add batch dim
        assert img.dim() == 4 and img.shape[0] == 1
        shape = img.shape[-2:][::-1]  # original image size as (width, height)
        img, scales = ImagePreprocessor(
            **{**self.preprocess_conf, **conf})(img)
        feats = self.forward({'image': img})
        feats['image_size'] = torch.tensor(shape)[None].to(img).float()
        # Rescale keypoints back to the original image resolution (pixel-center convention).
        feats['keypoints'] = (feats['keypoints'] + .5) / scales[None] - .5
        return feats
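

if __name__ == "__main__":
    # Usage sketch (illustrative addition, not part of the original module).
    # Runs the extractor on a dummy RGB tensor to show the output shapes;
    # a real image with values in [0, 1] and shape (3, H, W) is passed the same way.
    # Because this file uses relative imports, run it as a module
    # (e.g. `python -m <package>.disk`) rather than as a standalone script.
    extractor = DISK(max_num_keypoints=2048).eval()
    image = torch.rand(3, 480, 640)  # stand-in for a real RGB image
    with torch.no_grad():
        feats = extractor.extract(image)
    print(feats['keypoints'].shape)        # (1, N, 2), original-image pixel coordinates
    print(feats['descriptors'].shape)      # (1, N, 128)
    print(feats['keypoint_scores'].shape)  # (1, N)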