|
import sys |
|
from pathlib import Path |
|
import subprocess |
|
import torch |
|
from ..utils.base_model import BaseModel |
|
from .. import logger |
|
|
|
gluestick_path = Path(__file__).parent / "../../third_party/GlueStick" |
|
sys.path.append(str(gluestick_path)) |
|
|
|
from gluestick import batch_to_np |
|
from gluestick.models.two_view_pipeline import TwoViewPipeline |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
class GlueStick(BaseModel): |
|
default_conf = { |
|
"name": "two_view_pipeline", |
|
"model_name": "checkpoint_GlueStick_MD.tar", |
|
"use_lines": True, |
|
"max_keypoints": 1000, |
|
"max_lines": 300, |
|
"force_num_keypoints": False, |
|
} |
|
required_inputs = [ |
|
"image0", |
|
"image1", |
|
] |
|
|
|
gluestick_models = { |
|
"checkpoint_GlueStick_MD.tar": "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar", |
|
} |
|
|
|
|
|
def _init(self, conf): |
|
model_path = ( |
|
gluestick_path / "resources" / "weights" / conf["model_name"] |
|
) |
|
|
|
|
|
if not model_path.exists(): |
|
model_path.parent.mkdir(exist_ok=True) |
|
link = self.gluestick_models[conf["model_name"]] |
|
cmd = ["wget", link, "-O", str(model_path)] |
|
logger.info(f"Downloading the Gluestick model with `{cmd}`.") |
|
subprocess.run(cmd, check=True) |
|
logger.info(f"Loading GlueStick model...") |
|
|
|
gluestick_conf = { |
|
"name": "two_view_pipeline", |
|
"use_lines": True, |
|
"extractor": { |
|
"name": "wireframe", |
|
"sp_params": { |
|
"force_num_keypoints": False, |
|
"max_num_keypoints": 1000, |
|
}, |
|
"wireframe_params": { |
|
"merge_points": True, |
|
"merge_line_endpoints": True, |
|
}, |
|
"max_n_lines": 300, |
|
}, |
|
"matcher": { |
|
"name": "gluestick", |
|
"weights": str(model_path), |
|
"trainable": False, |
|
}, |
|
"ground_truth": { |
|
"from_pose_depth": False, |
|
}, |
|
} |
|
gluestick_conf["extractor"]["sp_params"]["max_num_keypoints"] = conf[ |
|
"max_keypoints" |
|
] |
|
gluestick_conf["extractor"]["sp_params"]["force_num_keypoints"] = conf[ |
|
"force_num_keypoints" |
|
] |
|
gluestick_conf["extractor"]["max_n_lines"] = conf["max_lines"] |
|
self.net = TwoViewPipeline(gluestick_conf) |
|
|
|
def _forward(self, data): |
|
pred = self.net(data) |
|
|
|
pred = batch_to_np(pred) |
|
kp0, kp1 = pred["keypoints0"], pred["keypoints1"] |
|
m0 = pred["matches0"] |
|
|
|
line_seg0, line_seg1 = pred["lines0"], pred["lines1"] |
|
line_matches = pred["line_matches0"] |
|
|
|
valid_matches = m0 != -1 |
|
match_indices = m0[valid_matches] |
|
matched_kps0 = kp0[valid_matches] |
|
matched_kps1 = kp1[match_indices] |
|
|
|
valid_matches = line_matches != -1 |
|
match_indices = line_matches[valid_matches] |
|
matched_lines0 = line_seg0[valid_matches] |
|
matched_lines1 = line_seg1[match_indices] |
|
|
|
pred["raw_lines0"], pred["raw_lines1"] = line_seg0, line_seg1 |
|
pred["lines0"], pred["lines1"] = matched_lines0, matched_lines1 |
|
pred["keypoints0"], pred["keypoints1"] = torch.from_numpy( |
|
matched_kps0 |
|
), torch.from_numpy(matched_kps1) |
|
pred = {**pred, **data} |
|
return pred |
|
|