File size: 3,918 Bytes
4d9207d
9223079
 
4d9207d
9223079
 
 
4d9207d
 
 
 
9223079
 
 
4d9207d
9223079
 
 
 
d46c0a9
9223079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b78237
 
 
9223079
 
 
 
 
 
 
 
 
 
 
4d9207d
9223079
 
 
 
 
 
4d9207d
2b78237
 
 
9223079
 
 
 
2b78237
 
 
 
 
 
 
 
 
4d9207d
9223079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b78237
 
 
9223079
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import subprocess
import sys
from pathlib import Path

import torch
import torchvision.transforms as transforms

from hloc import logger

from ..utils.base_model import BaseModel

# Location of the DeDoDe sources under third_party, relative to this file.
dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
# Make the `DeDoDe` package importable. This has to run before the imports
# below, which is why they are not grouped with the imports at the top.
sys.path.append(str(dedode_path))

from DeDoDe import dedode_descriptor_B, dedode_detector_L
from DeDoDe.utils import to_pixel_coords

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DeDoDe(BaseModel):
    """DeDoDe keypoint detector and descriptor wrapped as one extractor.

    The detector is built with ``dedode_detector_L`` and the descriptor with
    ``dedode_descriptor_B``. Their weights are looked up under
    ``dedode_path / "pretrained"`` and downloaded from ``weight_urls`` when
    they are not present yet.
    """

    default_conf = {
        "name": "dedode",
        "model_detector_name": "dedode_detector_L.pth",
        "model_descriptor_name": "dedode_descriptor_B.pth",
        "max_keypoints": 2000,
        "match_threshold": 0.2,
        "dense": False,  # Now fixed to be false
    }
    required_inputs = [
        "image",
    ]
    # Download location for every supported weight file, keyed by file name.
    weight_urls = {
        "dedode_detector_L.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
        "dedode_descriptor_B.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
    }

    def _fetch_weights(self, model_name, kind):
        """Return the local path of a weight file, downloading it if needed.

        Args:
            model_name: File name of the weights; must be a key of
                ``weight_urls`` when the file has to be downloaded.
            kind: Human-readable label ("detector" / "descriptor") used only
                in the log message.

        Returns:
            Path of the weight file inside ``dedode_path / "pretrained"``.

        Raises:
            KeyError: If the file is missing and ``model_name`` has no URL.
            subprocess.CalledProcessError: If ``wget`` exits with an error.
        """
        weights_path = dedode_path / "pretrained" / model_name
        if weights_path.exists():
            return weights_path

        link = self.weight_urls[model_name]
        weights_path.parent.mkdir(parents=True, exist_ok=True)

        # `wget -O` creates/truncates its output file even when the transfer
        # fails. Download to a side file and move it into place only on
        # success, so a failed run never leaves a broken file that would be
        # mistaken for valid weights the next time.
        partial_path = weights_path.with_name(weights_path.name + ".part")
        cmd = ["wget", "--quiet", link, "-O", str(partial_path)]
        logger.info(f"Downloading the DeDoDe {kind} model with `{cmd}`.")
        try:
            subprocess.run(cmd, check=True)
            partial_path.replace(weights_path)
        finally:
            # Only still present if the download or the move failed.
            if partial_path.exists():
                partial_path.unlink()
        return weights_path

    def _init(self, conf):
        """Build the input normalizer and load the detector and descriptor.

        Args:
            conf: Configuration mapping; reads ``model_detector_name`` and
                ``model_descriptor_name``.
        """
        self.normalizer = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        # Make sure both weight files are available locally.
        model_detector_path = self._fetch_weights(
            conf["model_detector_name"], "detector"
        )
        model_descriptor_path = self._fetch_weights(
            conf["model_descriptor_name"], "descriptor"
        )

        # Load on the CPU first; the DeDoDe factories receive `device`.
        # NOTE(review): torch.load unpickles the downloaded checkpoint;
        # consider `weights_only=True` if the installed torch supports it.
        weights_detector = torch.load(model_detector_path, map_location="cpu")
        weights_descriptor = torch.load(
            model_descriptor_path, map_location="cpu"
        )
        self.detector = dedode_detector_L(
            weights=weights_detector, device=device
        )
        self.descriptor = dedode_descriptor_B(
            weights=weights_descriptor, device=device
        )
        logger.info("Load DeDoDe model done.")

    def _forward(self, data):
        """Detect and describe keypoints in one image.

        Args:
            data: dict with key ``"image"``.
                image shape: N x C x H x W
                color mode: RGB

        Returns:
            dict with ``"keypoints"`` (1 x N x 2, pixel coordinates),
            ``"descriptors"`` (1 x 256 x N) and ``"scores"`` (1 x N).
        """
        # NOTE(review): squeeze() drops every size-1 dimension before the
        # batch axis is re-added, so this assumes a single 3-channel image
        # per call -- confirm against callers.
        img0 = self.normalizer(data["image"].squeeze()).float()[None]
        H_A, W_A = img0.shape[2:]

        # step 1: detect keypoints
        batch_A = {"image": img0}
        if self.conf["dense"]:
            detections_A = self.detector.detect_dense(batch_A)
        else:
            detections_A = self.detector.detect(
                batch_A, num_keypoints=self.conf["max_keypoints"]
            )
        keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]

        # step 2: describe keypoints
        # dim: 1 x N x 256
        description_A = self.descriptor.describe_keypoints(
            batch_A, keypoints_A
        )["descriptions"]
        # Convert to pixel coordinates only after describing: the descriptor
        # is given the detector's original keypoints.
        keypoints_A = to_pixel_coords(keypoints_A, H_A, W_A)

        return {
            "keypoints": keypoints_A,  # 1 x N x 2
            "descriptors": description_A.permute(0, 2, 1),  # 1 x 256 x N
            "scores": P_A,  # 1 x N
        }