# Last change by Vincentqyw, commit 60ad158: "update: limit keypoints number"
# (residue from the web file viewer: raw | history blame | 3.99 kB)
import sys
from pathlib import Path
import subprocess
import logging
import torch
from PIL import Image
from ..utils.base_model import BaseModel
import torchvision.transforms as transforms
# Location of the vendored DeDoDe repository, two directory levels above this
# file's directory. It is appended to sys.path so the `DeDoDe` package can be
# imported; the append must run before the two DeDoDe imports that follow.
dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
sys.path.append(str(dedode_path))
from DeDoDe import dedode_detector_L, dedode_descriptor_B
from DeDoDe.utils import to_pixel_coords
# Module-wide device: CUDA when available, otherwise CPU. It is passed as
# `device=` when the DeDoDe networks are constructed in this module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = logging.getLogger(__name__)
class DeDoDe(BaseModel):
    """Keypoint extractor built from the DeDoDe detector and descriptor.

    Uses the DeDoDe-L detector and DeDoDe-B descriptor. Missing weight files
    are downloaded with ``wget`` into ``<dedode_path>/pretrained`` on first
    initialisation.
    """

    default_conf = {
        "name": "dedode",
        "model_detector_name": "dedode_detector_L.pth",
        "model_descriptor_name": "dedode_descriptor_B.pth",
        "max_keypoints": 2000,
        "match_threshold": 0.2,
        "dense": False,  # Now fixed to be false
    }
    required_inputs = [
        "image",
    ]
    weight_urls = {
        "dedode_detector_L.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
        "dedode_descriptor_B.pth": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
    }

    def _init(self, conf):
        """Download (if needed) and load the DeDoDe detector and descriptor.

        Args:
            conf: merged configuration; ``model_detector_name`` and
                ``model_descriptor_name`` select the weight files under
                ``<dedode_path>/pretrained``.

        Raises:
            KeyError: a weight file is missing locally and its name has no
                entry in ``weight_urls``.
            subprocess.CalledProcessError: ``wget`` failed.
        """
        pretrained_dir = dedode_path / "pretrained"
        model_detector_path = pretrained_dir / conf["model_detector_name"]
        model_descriptor_path = pretrained_dir / conf["model_descriptor_name"]

        # ImageNet mean/std, applied to the RGB input in `_forward`.
        self.normalizer = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

        # Download the models if they are not present yet.
        self._download_weights(
            model_detector_path, conf["model_detector_name"], "detector"
        )
        self._download_weights(
            model_descriptor_path, conf["model_descriptor_name"], "descriptor"
        )

        logger.info("Loading DeDoDe model...")
        # NOTE(review): torch.load unpickles the checkpoint; only load files
        # from a trusted source.
        weights_detector = torch.load(model_detector_path, map_location="cpu")
        weights_descriptor = torch.load(
            model_descriptor_path, map_location="cpu"
        )
        self.detector = dedode_detector_L(
            weights=weights_detector, device=device
        )
        self.descriptor = dedode_descriptor_B(
            weights=weights_descriptor, device=device
        )
        logger.info("Load DeDoDe model done.")

    def _download_weights(self, path, name, label):
        """Fetch weight file ``name`` into ``path`` with wget unless it exists.

        The download goes to a ``.part`` sibling that is renamed into place
        only after wget succeeds: ``wget -O`` creates the output file even
        when the transfer fails, and a leftover truncated file would otherwise
        pass the ``exists()`` check on every later run.

        Args:
            path: destination ``Path`` of the weight file.
            name: key into ``weight_urls``; looked up only when a download is
                actually needed, so custom local weight files still work.
            label: human-readable model kind used in the log message.
        """
        if path.exists():
            return
        path.parent.mkdir(parents=True, exist_ok=True)
        link = self.weight_urls[name]
        tmp_path = path.with_name(path.name + ".part")
        cmd = ["wget", link, "-O", str(tmp_path)]
        logger.info(f"Downloading the DeDoDe {label} model with `{cmd}`.")
        try:
            subprocess.run(cmd, check=True)
        except BaseException:
            # Remove the partial file so the next run retries the download.
            tmp_path.unlink(missing_ok=True)
            raise
        tmp_path.replace(path)

    def _forward(self, data):
        """Detect and describe keypoints in an image.

        Args:
            data: dict with key ``"image"``: an RGB image tensor of shape
                N x C x H x W (a C x H x W tensor is also accepted and treated
                as a batch of one). The value range expected by the networks
                is presumably [0, 1] — TODO confirm against the preprocessing
                config.

        Returns:
            dict with
                ``keypoints``: 1 x N x 2 keypoint pixel coordinates,
                ``descriptors``: 1 x 256 x N descriptors,
                ``scores``: 1 x N detection confidences.
        """
        image = data["image"]
        # Add a batch axis only when it is missing; unlike `squeeze()` this
        # never collapses the batch, channel, or a size-1 spatial axis.
        if image.dim() == 3:
            image = image[None]
        # Cast before normalising: `Normalize` rejects integer tensors.
        img0 = self.normalizer(image.float())
        H_A, W_A = img0.shape[2:]
        batch_A = {"image": img0}

        # Step 1: detect keypoints.
        if self.conf["dense"]:
            detections_A = self.detector.detect_dense(batch_A)
        else:
            detections_A = self.detector.detect(
                batch_A, num_keypoints=self.conf["max_keypoints"]
            )
        keypoints_A, P_A = detections_A["keypoints"], detections_A["confidence"]

        # Step 2: describe keypoints in the detector's own coordinate
        # convention, before converting them to pixels below.
        # dim: 1 x N x 256
        description_A = self.descriptor.describe_keypoints(
            batch_A, keypoints_A
        )["descriptions"]
        keypoints_A = to_pixel_coords(keypoints_A, H_A, W_A)
        return {
            "keypoints": keypoints_A,  # 1 x N x 2
            "descriptors": description_A.permute(0, 2, 1),  # 1 x 256 x N
            "scores": P_A,  # 1 x N
        }