Vincentqyw
update: limit keypoints number
e15a186
raw
history blame
3.61 kB
import sys
from pathlib import Path
import subprocess
import logging
import torch
from ..utils.base_model import BaseModel
# Module-level setup for the GlueStick matcher wrapper.
logger = logging.getLogger(__name__)
# Location of the vendored GlueStick checkout, relative to this file. The path
# is not resolved, so it keeps its ``..`` components.
gluestick_path = Path(__file__).parent / "../../third_party/GlueStick"
# Make the vendored ``gluestick`` package importable. This must run before the
# two ``gluestick`` imports that follow.
sys.path.append(str(gluestick_path))
from gluestick import batch_to_np
from gluestick.models.two_view_pipeline import TwoViewPipeline
# Preferred torch device (CUDA when available).
# NOTE(review): ``device`` is not referenced by the GlueStick class in this
# file — presumably consumed by importers of this module; confirm before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class GlueStick(BaseModel):
    """Point-and-line matcher wrapping GlueStick's ``TwoViewPipeline``.

    ``_forward`` returns the pipeline predictions with keypoints and line
    segments reduced to matched pairs, merged with the input ``data``.
    """

    default_conf = {
        "name": "two_view_pipeline",
        "model_name": "checkpoint_GlueStick_MD.tar",
        # NOTE(review): not forwarded to the pipeline; line matching is always
        # enabled because ``_forward`` reads the line outputs unconditionally.
        "use_lines": True,
        "max_keypoints": 1000,
        "max_lines": 300,
        "force_num_keypoints": False,
    }
    required_inputs = [
        "image0",
        "image1",
    ]
    # Checkpoint file name -> download URL.
    gluestick_models = {
        "checkpoint_GlueStick_MD.tar": "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar",
    }

    def _init(self, conf):
        """Fetch the checkpoint if missing and build the GlueStick pipeline.

        Args:
            conf: configuration dict; reads ``model_name``, ``max_keypoints``,
                ``force_num_keypoints`` and ``max_lines``.

        Raises:
            KeyError: if the weights are not on disk and ``conf["model_name"]``
                has no entry in ``gluestick_models``.
            subprocess.CalledProcessError: if the ``wget`` download fails.
        """
        model_path = (
            gluestick_path / "resources" / "weights" / conf["model_name"]
        )
        if not model_path.exists():
            # ``parents=True``: ``resources/`` may itself be missing, in which
            # case a plain ``mkdir`` raises FileNotFoundError.
            model_path.parent.mkdir(parents=True, exist_ok=True)
            link = self.gluestick_models[conf["model_name"]]
            # Download under a temporary name and rename only on success.
            # ``wget -O`` can leave an empty or partial file behind when the
            # download fails or is interrupted. Writing straight to
            # ``model_path`` would make the ``exists()`` check above skip the
            # download on every later run and load a broken checkpoint.
            tmp_path = model_path.with_name(model_path.name + ".part")
            cmd = ["wget", link, "-O", str(tmp_path)]
            logger.info(f"Downloading the Gluestick model with `{cmd}`.")
            try:
                subprocess.run(cmd, check=True)
            except BaseException:
                # Also covers KeyboardInterrupt; never keep a partial file.
                if tmp_path.exists():
                    tmp_path.unlink()
                raise
            tmp_path.replace(model_path)

        logger.info("Loading GlueStick model...")
        gluestick_conf = {
            "name": "two_view_pipeline",
            # Hard-coded rather than read from conf["use_lines"]: ``_forward``
            # always reads ``lines*`` / ``line_matches0`` from the predictions.
            "use_lines": True,
            "extractor": {
                "name": "wireframe",
                "sp_params": {
                    "force_num_keypoints": conf["force_num_keypoints"],
                    "max_num_keypoints": conf["max_keypoints"],
                },
                "wireframe_params": {
                    "merge_points": True,
                    "merge_line_endpoints": True,
                },
                "max_n_lines": conf["max_lines"],
            },
            "matcher": {
                "name": "gluestick",
                "weights": str(model_path),
                "trainable": False,
            },
            "ground_truth": {
                "from_pose_depth": False,
            },
        }
        self.net = TwoViewPipeline(gluestick_conf)

    @staticmethod
    def _select_matched(items0, items1, matches):
        """Return the rows of ``items0`` and ``items1`` paired by ``matches``.

        ``matches[i]`` is the index into ``items1`` paired with ``items0[i]``,
        or -1 when ``items0[i]`` is unmatched. Row k of both outputs is the
        k-th matched pair.
        """
        valid = matches != -1
        return items0[valid], items1[matches[valid]]

    def _forward(self, data):
        """Run the pipeline and keep only matched keypoints and line segments.

        Returns:
            The predictions after ``batch_to_np`` (numpy arrays, which is why
            ``torch.from_numpy`` is applied below), with:
            - ``keypoints0`` / ``keypoints1``: torch tensors of matched
              keypoints, row-aligned pairs;
            - ``lines0`` / ``lines1``: matched line segments, row-aligned;
            - ``raw_lines0`` / ``raw_lines1``: all detected line segments;
            merged with ``data`` (keys from ``data`` win on collision).
        """
        pred = batch_to_np(self.net(data))

        matched_kps0, matched_kps1 = self._select_matched(
            pred["keypoints0"], pred["keypoints1"], pred["matches0"]
        )
        line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
        matched_lines0, matched_lines1 = self._select_matched(
            line_seg0, line_seg1, pred["line_matches0"]
        )

        pred["raw_lines0"], pred["raw_lines1"] = line_seg0, line_seg1
        pred["lines0"], pred["lines1"] = matched_lines0, matched_lines1
        pred["keypoints0"] = torch.from_numpy(matched_kps0)
        pred["keypoints1"] = torch.from_numpy(matched_kps1)
        return {**pred, **data}