import subprocess
import sys
from pathlib import Path

import numpy as np
import torch

from .. import logger
from ..utils.base_model import BaseModel

# Make the vendored OmniGlue sources importable before importing them below.
thirdparty_path = Path(__file__).parent / "../../third_party"
sys.path.append(str(thirdparty_path))
from omniglue.src import omniglue  # noqa: E402

omniglue_path = thirdparty_path / "omniglue"


class OmniGlue(BaseModel):
    default_conf = {
        "match_threshold": 0.02,
        "max_keypoints": 2048,
    }
    required_inputs = ["image0", "image1"]
    dino_v2_link_dict = {
        "dinov2_vitb14_pretrain.pth": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth"
    }

    def _init(self, conf):
        logger.info("Loading OmniGlue model")
        og_model_path = omniglue_path / "models" / "omniglue.onnx"
        sp_model_path = omniglue_path / "models" / "sp_v6.onnx"
        dino_model_path = (
            omniglue_path / "models" / "dinov2_vitb14_pretrain.pth"
        )
        # Download the DINOv2 weights on first use if they are missing.
        if not dino_model_path.exists():
            link = self.dino_v2_link_dict.get(dino_model_path.name, None)
            if link is not None:
                cmd = ["wget", "--quiet", link, "-O", str(dino_model_path)]
                logger.info(f"Downloading the DINOv2 model with `{cmd}`.")
                subprocess.run(cmd, check=True)
            else:
                logger.error(f"Invalid DINOv2 model: {dino_model_path.name}")
        self.net = omniglue.OmniGlue(
            og_export=str(og_model_path),
            sp_export=str(sp_model_path),
            dino_export=str(dino_model_path),
            max_keypoints=self.conf["max_keypoints"],
        )
        logger.info("Loaded OmniGlue model")

    def _forward(self, data):
        # OmniGlue expects uint8 RGB images in HWC layout.
        image0_rgb_np = data["image0"][0].permute(1, 2, 0).cpu().numpy() * 255
        image1_rgb_np = data["image1"][0].permute(1, 2, 0).cpu().numpy() * 255
        image0_rgb_np = image0_rgb_np.astype(np.uint8)
        image1_rgb_np = image1_rgb_np.astype(np.uint8)
        match_kp0, match_kp1, match_confidences = self.net.FindMatches(
            image0_rgb_np, image1_rgb_np, self.conf["max_keypoints"]
        )

        # Keep only matches above the confidence threshold.
        match_threshold = self.conf["match_threshold"]
        keep_idx = []
        for i in range(match_kp0.shape[0]):
            if match_confidences[i] > match_threshold:
                keep_idx.append(i)
        scores = torch.from_numpy(match_confidences[keep_idx]).reshape(-1, 1)
        pred = {
            "keypoints0": torch.from_numpy(match_kp0[keep_idx]),
            "keypoints1": torch.from_numpy(match_kp1[keep_idx]),
            "mconf": scores,
        }

        # Retain only the top-k most confident matches; sort on the squeezed
        # 1-D scores so the indices address whole matches, not columns.
        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(scores) > top_k:
            keep = torch.argsort(scores.squeeze(-1), descending=True)[:top_k]
            scores = scores[keep]
            pred["keypoints0"], pred["keypoints1"], pred["mconf"] = (
                pred["keypoints0"][keep],
                pred["keypoints1"][keep],
                scores,
            )
        return pred
|
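
# Usage sketch (not executed): assumes the hloc-style BaseModel interface,
# i.e. the matcher is constructed from a conf dict and called with a dict of
# batched RGB tensors normalized to [0, 1]. Names below are illustrative only.
#
#     matcher = OmniGlue({"match_threshold": 0.02, "max_keypoints": 2048})
#     pred = matcher({"image0": img0, "image1": img1})  # img*: (1, 3, H, W)
#     kpts0, kpts1, conf = pred["keypoints0"], pred["keypoints1"], pred["mconf"]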