|
import subprocess |
|
import sys |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
from .. import logger |
|
from ..utils.base_model import BaseModel |
|
|
|
# Put the vendored SOLD2 checkout on the import path. This has to run before
# the ``sold2`` import just below, which is why that import is not grouped
# with the other imports at the top of the module.
sold2_path = Path(__file__).parent.joinpath("../../third_party/SOLD2")
sys.path.append(str(sold2_path))

from sold2.model.line_matcher import LineMatcher  # noqa: E402

# Module-wide device handed to the matcher: the GPU when torch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
|
class SOLD2(BaseModel):
    """Line-segment detector and matcher built on SOLD2's ``LineMatcher``.

    Detects line segments in ``image0`` and ``image1`` and matches them with
    the pretrained SOLD2 network. The checkpoint is downloaded on first use
    into ``conf["checkpoint_dir"]``.
    """

    default_conf = {
        "weights": "sold2_wireframe.tar",
        # NOTE(review): not read anywhere in this class; kept so existing
        # configs that set it keep working.
        "match_threshold": 0.2,
        "checkpoint_dir": sold2_path / "pretrained",
        # Forwarded to ``line_detector_cfg["detect_thresh"]``.
        "detect_thresh": 0.25,
        # Forwarded to LineMatcher's ``multiscale`` argument.
        "multiscale": False,
        # Forwarded to ``line_detector_cfg["heatmap_refine_cfg"]``.
        "valid_thresh": 1e-3,
        "num_blocks": 20,
        "overlap_ratio": 0.5,
    }
    required_inputs = [
        "image0",
        "image1",
    ]

    # Download locations of the known pretrained checkpoints, keyed by the
    # value of ``conf["weights"]``.
    weight_urls = {
        "sold2_wireframe.tar": "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download",
    }

    def _init(self, conf):
        """Make sure the checkpoint exists locally, then build the matcher.

        Args:
            conf: configuration dict expected to contain every key of
                ``default_conf`` (presumably merged by ``BaseModel`` -- confirm).

        Raises:
            KeyError: if the checkpoint is missing and ``conf["weights"]`` has
                no entry in ``weight_urls``.
            subprocess.CalledProcessError: if the download fails.
        """
        # Path() also accepts a plain string for ``checkpoint_dir``.
        checkpoint_path = Path(conf["checkpoint_dir"]) / conf["weights"]
        if not checkpoint_path.exists():
            self._download_weights(conf["weights"], checkpoint_path)

        match_config = self._build_match_config(conf)
        self.net = LineMatcher(
            match_config["model_cfg"],
            checkpoint_path,
            device,
            match_config["line_detector_cfg"],
            match_config["line_matcher_cfg"],
            match_config["multiscale"],
        )

    def _download_weights(self, weights, checkpoint_path):
        """Download the ``weights`` checkpoint to ``checkpoint_path`` via wget.

        The file is written to a temporary ``.part`` sibling and renamed into
        place only after wget succeeds. Otherwise a failed or interrupted
        download (``wget -O`` creates the output file even on error) would
        leave a broken checkpoint that later runs accept because it exists.

        Raises:
            KeyError: if ``weights`` has no entry in ``weight_urls``.
            subprocess.CalledProcessError: if wget exits with an error.
        """
        link = self.weight_urls[weights]
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".part")
        cmd = ["wget", "--quiet", link, "-O", str(tmp_path)]
        logger.info(f"Downloading the SOLD2 model with `{cmd}`.")
        try:
            subprocess.run(cmd, check=True)
            tmp_path.replace(checkpoint_path)
        finally:
            # Only present if the download or rename failed.
            if tmp_path.exists():
                tmp_path.unlink()

    @staticmethod
    def _build_match_config(conf):
        """Return the LineMatcher configuration dict for ``conf``.

        The ``model_cfg`` architecture entries are fixed (presumably they must
        match the pretrained checkpoint -- confirm). The loss-related entries
        (``w_*``, ``*_loss_*``) look like training-only settings and are kept
        as-is. The detection and refinement knobs exposed in ``default_conf``
        are taken from ``conf``.
        """
        mode = "dynamic"
        return {
            "model_cfg": {
                "model_name": "lcnn_simple",
                "model_architecture": "simple",
                # Backbone config
                "backbone": "lcnn",
                "backbone_cfg": {
                    "input_channel": 1,
                    "depth": 4,
                    "num_stacks": 2,
                    "num_blocks": 1,
                    "num_classes": 5,
                },
                # Junction decoder config
                "junction_decoder": "superpoint_decoder",
                "junc_decoder_cfg": {},
                # Heatmap decoder config
                "heatmap_decoder": "pixel_shuffle",
                "heatmap_decoder_cfg": {},
                # Descriptor decoder config
                "descriptor_decoder": "superpoint_descriptor",
                "descriptor_decoder_cfg": {},
                # Shared configurations
                "grid_size": 8,
                "keep_border_valid": True,
                # Threshold of junction detection
                "detection_thresh": 0.0153846,
                "max_num_junctions": 300,
                # Threshold of heatmap detection
                "prob_thresh": 0.5,
                # Weighting related parameters
                "weighting_policy": mode,
                # [Heatmap loss]
                "w_heatmap": 0.0,
                "w_heatmap_class": 1,
                "heatmap_loss_func": "cross_entropy",
                "heatmap_loss_cfg": {"policy": mode},
                # [Junction loss]
                "w_junc": 0.0,
                "junction_loss_func": "superpoint",
                "junction_loss_cfg": {"policy": mode},
                # [Descriptor loss]
                "w_desc": 0.0,
                "descriptor_loss_func": "regular_sampling",
                "descriptor_loss_cfg": {
                    "dist_threshold": 8,
                    "grid_size": 4,
                    "margin": 1,
                    "policy": mode,
                },
            },
            "line_detector_cfg": {
                "detect_thresh": conf["detect_thresh"],
                "num_samples": 64,
                "sampling_method": "local_max",
                "inlier_thresh": 0.9,
                "use_candidate_suppression": True,
                "nms_dist_tolerance": 3.0,
                "use_heatmap_refinement": True,
                "heatmap_refine_cfg": {
                    "mode": "local",
                    "ratio": 0.2,
                    "valid_thresh": conf["valid_thresh"],
                    "num_blocks": conf["num_blocks"],
                    "overlap_ratio": conf["overlap_ratio"],
                },
            },
            "multiscale": conf["multiscale"],
            "line_matcher_cfg": {
                "cross_check": True,
                "num_samples": 5,
                "min_dist_pts": 8,
                "top_k_candidates": 10,
                "grid_size": 4,
            },
        }

    def _forward(self, data):
        """Detect and match line segments between ``image0`` and ``image1``.

        Returns the LineMatcher output merged with ``data`` (keys from
        ``data`` win on collision), extended with:
            raw_lines0 / raw_lines1: all segments detected in each image,
                exactly as returned by LineMatcher.
            lines0 / lines1: the matched segments, paired row-by-row, with the
                last axis reversed.
        """
        img0 = data["image0"]
        img1 = data["image1"]
        pred = self.net([img0, img1])
        line_seg1 = pred["line_segments"][0]
        line_seg2 = pred["line_segments"][1]
        matches = pred["matches"]

        # matches[i] is the index of the segment in image1 paired with
        # segment i of image0, or -1 when segment i is unmatched.
        valid_matches = matches != -1
        match_indices = matches[valid_matches]
        # NOTE(review): ``[:, :, ::-1]`` (negative step) assumes NumPy arrays
        # and presumably converts SOLD2's (row, col) endpoints to (x, y);
        # raw_lines* keep the original order -- confirm with consumers.
        matched_lines1 = line_seg1[valid_matches][:, :, ::-1]
        matched_lines2 = line_seg2[match_indices][:, :, ::-1]

        pred["raw_lines0"], pred["raw_lines1"] = line_seg1, line_seg2
        pred["lines0"], pred["lines1"] = matched_lines1, matched_lines2
        pred = {**pred, **data}
        return pred
|
|