File size: 3,918 Bytes
8320ccc
9223079
 
8320ccc
9223079
 
 
8320ccc
 
 
 
9223079
 
 
8320ccc
9223079
 
 
 
2eaeef9
9223079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e15a186
 
 
9223079
 
 
 
 
 
 
 
 
 
 
8320ccc
9223079
 
 
 
 
 
8320ccc
e15a186
 
 
9223079
 
 
 
e15a186
 
 
 
 
 
 
 
 
8320ccc
9223079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e15a186
 
 
9223079
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import subprocess
import sys
from pathlib import Path

import torch
import torchvision.transforms as transforms

from hloc import logger

from ..utils.base_model import BaseModel

# Location of the vendored DeDoDe checkout (relative to this file); its
# "pretrained" subdirectory is where checkpoint files are looked up / saved.
dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
# Make the vendored package importable. This must run before the DeDoDe
# imports below, which is why they are not at the top of the module.
sys.path.append(str(dedode_path))

from DeDoDe import dedode_descriptor_B, dedode_detector_L
from DeDoDe.utils import to_pixel_coords

# Device passed to the DeDoDe model constructors: CUDA when available,
# otherwise CPU. Chosen once at import time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DeDoDe(BaseModel):
    """hloc local-feature extractor wrapping DeDoDe (detector L, descriptor B).

    Detects keypoints in one image, describes them, and returns keypoints in
    pixel coordinates together with descriptors and detection confidences.
    Missing checkpoint files are downloaded with ``wget`` on first use.
    """

    default_conf = {
        "name": "dedode",
        "model_detector_name": "dedode_detector_L.pth",
        "model_descriptor_name": "dedode_descriptor_B.pth",
        "max_keypoints": 2000,
        # Not read by this extractor; kept so existing configs remain valid.
        "match_threshold": 0.2,
        # True -> detector.detect_dense; False -> top-k detection limited by
        # max_keypoints. Kept False by default.
        "dense": False,
    }
    required_inputs = [
        "image",
    ]
    # Download URLs for the pretrained checkpoints, keyed by file name.
    weight_urls = {
        "dedode_detector_L.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
        "dedode_descriptor_B.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
    }

    def _download_weights(self, path, model_name, what):
        """Fetch checkpoint ``model_name`` to ``path`` unless it already exists.

        The file is first written to ``<path>.part`` and renamed into place
        only after wget succeeds, so a failed or interrupted download never
        leaves a truncated checkpoint that later runs would treat as present.

        Args:
            path: destination ``Path`` of the checkpoint.
            model_name: key into ``weight_urls``; looked up only when a
                download is actually needed, so custom checkpoints without a
                known URL still work if the file is already in place.
            what: human-readable model kind used in the log message.

        Raises:
            KeyError: if a download is needed but ``model_name`` has no URL.
            subprocess.CalledProcessError: if wget exits with an error.
        """
        if path.exists():
            return
        path.parent.mkdir(parents=True, exist_ok=True)
        link = self.weight_urls[model_name]
        tmp_path = path.with_name(path.name + ".part")
        cmd = ["wget", "--quiet", link, "-O", str(tmp_path)]
        logger.info(f"Downloading the DeDoDe {what} model with `{cmd}`.")
        try:
            subprocess.run(cmd, check=True)
            # Atomic on the same filesystem: path is either absent or complete.
            tmp_path.replace(path)
        finally:
            # Only left behind when the download or the rename failed.
            if tmp_path.exists():
                tmp_path.unlink()

    def _init(self, conf):
        """Download (if needed) and load the DeDoDe detector and descriptor."""
        pretrained_dir = dedode_path / "pretrained"
        model_detector_path = pretrained_dir / conf["model_detector_name"]
        model_descriptor_path = pretrained_dir / conf["model_descriptor_name"]

        # Per-channel mean/std applied to the RGB input in _forward.
        self.normalizer = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        self._download_weights(
            model_detector_path, conf["model_detector_name"], "detector"
        )
        self._download_weights(
            model_descriptor_path, conf["model_descriptor_name"], "descriptor"
        )

        # NOTE: torch.load unpickles the checkpoint; only point the config at
        # trusted weight files.
        weights_detector = torch.load(model_detector_path, map_location="cpu")
        weights_descriptor = torch.load(
            model_descriptor_path, map_location="cpu"
        )
        self.detector = dedode_detector_L(
            weights=weights_detector, device=device
        )
        self.descriptor = dedode_descriptor_B(
            weights=weights_descriptor, device=device
        )
        logger.info("Load DeDoDe model done.")

    def _forward(self, data):
        """Detect and describe keypoints in a single image.

        Args:
            data: dict with key ``"image"``, an RGB image tensor of shape
                N x C x H x W. The batch dimension is squeezed away and a
                batch of one is rebuilt, so only N == 1 is supported.

        Returns:
            dict with
              - ``"keypoints"``: 1 x N x 2, in pixel coordinates,
              - ``"descriptors"``: 1 x 256 x N,
              - ``"scores"``: 1 x N detection confidences.
        """
        # Drop the batch dim, normalize, then restore a batch of one.
        img0 = self.normalizer(data["image"].squeeze()).float()[None]
        H_A, W_A = img0.shape[2:]
        batch_A = {"image": img0}

        # Step 1: detect keypoints.
        if self.conf["dense"]:
            detections_A = self.detector.detect_dense(batch_A)
        else:
            detections_A = self.detector.detect(
                batch_A, num_keypoints=self.conf["max_keypoints"]
            )
        keypoints_A = detections_A["keypoints"]
        P_A = detections_A["confidence"]

        # Step 2: describe keypoints (dim: 1 x N x 256). The descriptor gets
        # the detector's keypoints unchanged; conversion to pixel coordinates
        # happens only afterwards.
        description_A = self.descriptor.describe_keypoints(
            batch_A, keypoints_A
        )["descriptions"]
        keypoints_A = to_pixel_coords(keypoints_A, H_A, W_A)

        return {
            "keypoints": keypoints_A,  # 1 x N x 2
            "descriptors": description_A.permute(0, 2, 1),  # 1 x 256 x N
            "scores": P_A,  # 1 x N
        }