# Source: author Realcat, commit 8811cfe ("update: moving lfs files"), 3.12 kB.
import sys
from pathlib import Path
import torch
from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel
# Expose the vendored GlueStick checkout (third_party/GlueStick) on the
# import path so that the ``gluestick`` package imports below resolve.
gluestick_path = Path(__file__).parent.joinpath("../../third_party/GlueStick")
sys.path.append(str(gluestick_path))
from gluestick import batch_to_np
from gluestick.models.two_view_pipeline import TwoViewPipeline
class GlueStick(BaseModel):
    """Point and line matcher wrapping GlueStick's ``TwoViewPipeline``.

    The pipeline is configured with the ``wireframe`` extractor and the
    ``gluestick`` matcher; the matcher weights (``model_name``) are fetched
    via ``self._download_model`` from ``MODEL_REPO_ID``.
    """

    default_conf = {
        "name": "two_view_pipeline",
        "model_name": "checkpoint_GlueStick_MD.tar",
        "use_lines": True,
        "max_keypoints": 1000,
        "max_lines": 300,
        "force_num_keypoints": False,
    }
    required_inputs = [
        "image0",
        "image1",
    ]

    # Initialize the point + line matching pipeline.
    def _init(self, conf):
        """Download the GlueStick weights and build the two-view pipeline.

        Args:
            conf: configuration dict; reads ``use_lines``, ``max_keypoints``,
                ``force_num_keypoints`` and ``max_lines`` (``model_name`` is
                read from ``self.conf``).
        """
        # Download the model.
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(
                Path(__file__).stem, self.conf["model_name"]
            ),
        )
        logger.info("Loading GlueStick model...")
        # Build the pipeline config straight from ``conf`` so every
        # user-facing option is honoured (``use_lines`` used to be hard-coded
        # to True) and no default values are duplicated here.
        gluestick_conf = {
            "name": "two_view_pipeline",
            "use_lines": conf.get("use_lines", True),
            "extractor": {
                "name": "wireframe",
                "sp_params": {
                    "force_num_keypoints": conf["force_num_keypoints"],
                    "max_num_keypoints": conf["max_keypoints"],
                },
                "wireframe_params": {
                    "merge_points": True,
                    "merge_line_endpoints": True,
                },
                "max_n_lines": conf["max_lines"],
            },
            "matcher": {
                "name": "gluestick",
                "weights": str(model_path),
                "trainable": False,
            },
            "ground_truth": {
                "from_pose_depth": False,
            },
        }
        self.net = TwoViewPipeline(gluestick_conf)

    @staticmethod
    def _gather_matches(feats0, feats1, matches0):
        """Return ``(feats0[i], feats1[matches0[i]])`` for all ``matches0[i] != -1``."""
        # -1 in a match vector marks a feature with no correspondence.
        valid = matches0 != -1
        return feats0[valid], feats1[matches0[valid]]

    def _forward(self, data):
        """Match keypoints and line segments between ``image0`` and ``image1``.

        Returns:
            The pipeline predictions (after ``batch_to_np``) merged with
            ``data``, where:
              - ``keypoints0``/``keypoints1`` hold only the matched keypoints,
                converted to torch tensors;
              - ``lines0``/``lines1`` hold only the matched line segments;
              - ``raw_lines0``/``raw_lines1`` keep all detected segments.
            Entries of ``data`` override prediction entries with the same key.
        """
        pred = batch_to_np(self.net(data))

        matched_kps0, matched_kps1 = self._gather_matches(
            pred["keypoints0"], pred["keypoints1"], pred["matches0"]
        )

        line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
        matched_lines0, matched_lines1 = self._gather_matches(
            line_seg0, line_seg1, pred["line_matches0"]
        )

        pred["raw_lines0"], pred["raw_lines1"] = line_seg0, line_seg1
        pred["lines0"], pred["lines1"] = matched_lines0, matched_lines1
        pred["keypoints0"] = torch.from_numpy(matched_kps0)
        pred["keypoints1"] = torch.from_numpy(matched_kps1)
        return {**pred, **data}