"""
Implements the full pipeline from raw images to line matches.
"""
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.functional import softmax

from .model_util import get_model
from .loss import get_loss_and_weights
from .metrics import super_nms
from .line_detection import LineSegmentDetectionModule
from .line_matching import WunschLineMatcher
from ..train import convert_junc_predictions
from ..misc.train_utils import adapt_checkpoint
from .line_detector import line_map_to_segments


class LineMatcher(object):
    """Full line matcher including line detection and matching
    with the Needleman-Wunsch algorithm."""

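    # Overall pipeline: CNN backbone -> junction extraction with NMS +
    # line heatmap -> line segment detection -> Needleman-Wunsch matching
    # of line descriptors.
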
    def __init__(
        self,
        model_cfg,
        ckpt_path,
        device,
        line_detector_cfg,
        line_matcher_cfg,
        multiscale=False,
        scales=[1.0, 2.0],
    ):
        # Get the loss weights (required to instantiate the model).
        _, loss_weights = get_loss_and_weights(model_cfg, device)
        self.device = device

        # Initialize the CNN backbone and load the checkpoint weights.
        self.model = get_model(model_cfg, loss_weights)
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        checkpoint = adapt_checkpoint(checkpoint["model_state_dict"])
        self.model.load_state_dict(checkpoint)
        self.model = self.model.to(self.device)
        self.model = self.model.eval()

        # Junction detection parameters.
        self.grid_size = model_cfg["grid_size"]
        self.junc_detect_thresh = model_cfg["detection_thresh"]
        self.max_num_junctions = model_cfg.get("max_num_junctions", 300)

        # Initialize the line detector.
        self.line_detector = LineSegmentDetectionModule(**line_detector_cfg)
        self.multiscale = multiscale
        self.scales = scales

        # Initialize the line matcher.
        self.line_matcher = WunschLineMatcher(**line_matcher_cfg)

        # Print the line detector configuration for debugging.
        for key, val in line_detector_cfg.items():
            print(f"[Debug] {key}: {val}")
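
    # Construction sketch (illustrative only; the checkpoint path is a
    # placeholder, and the only model_cfg keys known to be read directly here
    # are "grid_size", "detection_thresh" and the optional
    # "max_num_junctions"):
    #
    #   matcher = LineMatcher(model_cfg, "checkpoint.tar",
    #                         torch.device("cuda"), line_detector_cfg,
    #                         line_matcher_cfg)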

    def line_detection(
        self, input_image, valid_mask=None, desc_only=False, profile=False
    ):
        """Detect line segments (and compute descriptors) in one image."""
        # The input must be a 4D torch tensor in [batch, channel, H, W]
        # format. Check the type first, so that non-tensor inputs fail with
        # a clear error instead of an AttributeError on `.shape`.
        if (not isinstance(input_image, torch.Tensor)) or (
            len(input_image.shape) != 4
        ):
            raise ValueError("[Error] the input image should be a 4D torch tensor")

        # Move the input to the corresponding device.
        input_image = input_image.to(self.device)

        # Forward pass of the CNN backbone.
        start_time = time.time()
        with torch.no_grad():
            net_outputs = self.model(input_image)

        outputs = {"descriptor": net_outputs["descriptors"]}

        if not desc_only:
            # Extract the junctions, with non-maximum suppression.
            junc_np = convert_junc_predictions(
                net_outputs["junctions"],
                self.grid_size,
                self.junc_detect_thresh,
                self.max_num_junctions,
            )
            if valid_mask is None:
                junctions = np.where(junc_np["junc_pred_nms"].squeeze())
            else:
                junctions = np.where(junc_np["junc_pred_nms"].squeeze() * valid_mask)
            junctions = np.concatenate(
                [junctions[0][..., None], junctions[1][..., None]], axis=-1
            )

            # Convert the line heatmap to a single channel in [0, 1].
            if net_outputs["heatmap"].shape[1] == 2:
                # Two-channel output: keep the positive-class softmax score.
                heatmap = (
                    softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
                    .cpu()
                    .numpy()
                    .transpose(0, 2, 3, 1)
                )
            else:
                # Single-channel output: apply a sigmoid.
                heatmap = (
                    torch.sigmoid(net_outputs["heatmap"])
                    .cpu()
                    .numpy()
                    .transpose(0, 2, 3, 1)
                )
            heatmap = heatmap[0, :, :, 0]

            # Run the junction- and heatmap-based line detector.
            line_map, junctions, heatmap = self.line_detector.detect(
                junctions, heatmap, device=self.device
            )
            if isinstance(line_map, torch.Tensor):
                line_map = line_map.cpu().numpy()
            if isinstance(junctions, torch.Tensor):
                junctions = junctions.cpu().numpy()
            outputs["heatmap"] = heatmap.cpu().numpy()
            outputs["junctions"] = junctions

            # If the line map is given over multiple detection and inlier
            # thresholds, convert each map to segments separately.
            if len(line_map.shape) > 2:
                num_detect_thresh = line_map.shape[0]
                num_inlier_thresh = line_map.shape[1]
                line_segments = []
                for detect_idx in range(num_detect_thresh):
                    line_segments_inlier = []
                    for inlier_idx in range(num_inlier_thresh):
                        line_map_tmp = line_map[detect_idx, inlier_idx, :, :]
                        line_segments_tmp = line_map_to_segments(
                            junctions, line_map_tmp
                        )
                        line_segments_inlier.append(line_segments_tmp)
                    line_segments.append(line_segments_inlier)
            else:
                line_segments = line_map_to_segments(junctions, line_map)

            outputs["line_segments"] = line_segments

        end_time = time.time()

        if profile:
            outputs["time"] = end_time - start_time

        return outputs
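
    # Usage sketch for a single image (a minimal example; the image loading
    # and normalization shown are assumptions, not part of this module):
    #
    #   gray = cv2.imread("img.png", 0).astype(np.float32) / 255.0
    #   torch_img = torch.tensor(gray, device=matcher.device)[None, None]
    #   out = matcher.line_detection(torch_img, profile=True)
    #   segments = out["line_segments"]  # endpoints in (row, col) order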

    def multiscale_line_detection(
        self,
        input_image,
        valid_mask=None,
        desc_only=False,
        profile=False,
        scales=[1.0, 2.0],
        aggregation="mean",
    ):
        """Detect line segments by aggregating predictions over scales."""
        # The input must be a 4D torch tensor in [batch, channel, H, W]
        # format. Check the type first, so that non-tensor inputs fail with
        # a clear error instead of an AttributeError on `.shape`.
        if (not isinstance(input_image, torch.Tensor)) or (
            len(input_image.shape) != 4
        ):
            raise ValueError("[Error] the input image should be a 4D torch tensor")

        # Move the input to the corresponding device.
        input_image = input_image.to(self.device)
        img_size = input_image.shape[2:4]
        desc_size = tuple(np.array(img_size) // 4)

        # Run the network at every scale and resize the outputs back to a
        # common resolution.
        start_time = time.time()
        junctions, heatmaps, descriptors = [], [], []
        for s in scales:
            # Resize the image to the current scale.
            resized_img = F.interpolate(input_image, scale_factor=s, mode="bilinear")

            # Forward pass of the CNN backbone.
            with torch.no_grad():
                net_outputs = self.model(resized_img)

            descriptors.append(
                F.interpolate(
                    net_outputs["descriptors"], size=desc_size, mode="bilinear"
                )
            )

            if not desc_only:
                junc_prob = convert_junc_predictions(
                    net_outputs["junctions"], self.grid_size
                )["junc_pred"]
                junctions.append(
                    cv2.resize(
                        junc_prob.squeeze(),
                        (img_size[1], img_size[0]),
                        interpolation=cv2.INTER_LINEAR,
                    )
                )

                # Convert the line heatmap to a single channel in [0, 1].
                if net_outputs["heatmap"].shape[1] == 2:
                    heatmap = softmax(net_outputs["heatmap"], dim=1)[:, 1:, :, :]
                else:
                    heatmap = torch.sigmoid(net_outputs["heatmap"])
                heatmaps.append(F.interpolate(heatmap, size=img_size, mode="bilinear"))

        # Aggregate the descriptors across scales.
        if aggregation == "mean":
            descriptors = torch.stack(descriptors, dim=0).mean(0)
        else:
            descriptors = torch.stack(descriptors, dim=0).max(0)[0]
        outputs = {"descriptor": descriptors}

        if not desc_only:
            # Aggregate the junction and heatmap predictions across scales.
            if aggregation == "mean":
                junctions = np.stack(junctions, axis=0).mean(0)[None]
                heatmap = torch.stack(heatmaps, dim=0).mean(0)[0, 0, :, :]
                heatmap = heatmap.cpu().numpy()
            else:
                junctions = np.stack(junctions, axis=0).max(0)[None]
                heatmap = torch.stack(heatmaps, dim=0).max(0)[0][0, 0, :, :]
                heatmap = heatmap.cpu().numpy()

            # Extract the junctions, with non-maximum suppression.
            junc_pred_nms = super_nms(
                junctions[..., None],
                self.grid_size,
                self.junc_detect_thresh,
                self.max_num_junctions,
            )
            if valid_mask is None:
                junctions = np.where(junc_pred_nms.squeeze())
            else:
                junctions = np.where(junc_pred_nms.squeeze() * valid_mask)
            junctions = np.concatenate(
                [junctions[0][..., None], junctions[1][..., None]], axis=-1
            )

            # Run the junction- and heatmap-based line detector.
            line_map, junctions, heatmap = self.line_detector.detect(
                junctions, heatmap, device=self.device
            )
            if isinstance(line_map, torch.Tensor):
                line_map = line_map.cpu().numpy()
            if isinstance(junctions, torch.Tensor):
                junctions = junctions.cpu().numpy()
            outputs["heatmap"] = heatmap.cpu().numpy()
            outputs["junctions"] = junctions

            # If the line map is given over multiple detection and inlier
            # thresholds, convert each map to segments separately.
            if len(line_map.shape) > 2:
                num_detect_thresh = line_map.shape[0]
                num_inlier_thresh = line_map.shape[1]
                line_segments = []
                for detect_idx in range(num_detect_thresh):
                    line_segments_inlier = []
                    for inlier_idx in range(num_inlier_thresh):
                        line_map_tmp = line_map[detect_idx, inlier_idx, :, :]
                        line_segments_tmp = line_map_to_segments(
                            junctions, line_map_tmp
                        )
                        line_segments_inlier.append(line_segments_tmp)
                    line_segments.append(line_segments_inlier)
            else:
                line_segments = line_map_to_segments(junctions, line_map)

            outputs["line_segments"] = line_segments

        end_time = time.time()

        if profile:
            outputs["time"] = end_time - start_time

        return outputs
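
    # Multi-scale sketch: "mean" aggregation averages the per-scale junction
    # and heatmap predictions, any other value takes their element-wise max:
    #
    #   out = matcher.multiscale_line_detection(torch_img,
    #                                           scales=[1.0, 2.0],
    #                                           aggregation="mean")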

    def __call__(self, images, valid_masks=[None, None], profile=False):
        """Detect and match lines between a pair of images."""
        # Line detection in both images.
        if self.multiscale:
            forward_outputs = [
                self.multiscale_line_detection(
                    images[0], valid_masks[0], profile=profile, scales=self.scales
                ),
                self.multiscale_line_detection(
                    images[1], valid_masks[1], profile=profile, scales=self.scales
                ),
            ]
        else:
            forward_outputs = [
                self.line_detection(images[0], valid_masks[0], profile=profile),
                self.line_detection(images[1], valid_masks[1], profile=profile),
            ]
        line_seg1 = forward_outputs[0]["line_segments"]
        line_seg2 = forward_outputs[1]["line_segments"]
        desc1 = forward_outputs[0]["descriptor"]
        desc2 = forward_outputs[1]["descriptor"]

        # Match the lines between the two images.
        start_time = time.time()
        matches = self.line_matcher.forward(line_seg1, line_seg2, desc1, desc2)
        end_time = time.time()

        outputs = {"line_segments": [line_seg1, line_seg2], "matches": matches}

        if profile:
            outputs["line_detection_time"] = (
                forward_outputs[0]["time"] + forward_outputs[1]["time"]
            )
            outputs["line_matching_time"] = end_time - start_time

        return outputs
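
# Background: a generic Needleman-Wunsch alignment over a pairwise similarity
# matrix, the algorithm named in the LineMatcher docstring. This is an
# illustrative sketch only; it does not reproduce the actual scoring or gap
# handling of WunschLineMatcher.
def _needleman_wunsch_score_demo(scores, gap=0.0):
    """Best alignment score for an (n, m) similarity matrix `scores`."""
    n, m = scores.shape
    dp = np.zeros((n + 1, m + 1))
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            dp[i, j] = max(
                dp[i - 1, j - 1] + scores[i - 1, j - 1],  # align i with j
                dp[i - 1, j] + gap,  # skip element i of the first sequence
                dp[i, j - 1] + gap,  # skip element j of the second sequence
            )
    return dp[n, m]


# End-to-end usage sketch (the configs and checkpoint path are placeholders):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   matcher = LineMatcher(model_cfg, ckpt_path, device,
#                         line_detector_cfg, line_matcher_cfg)
#   out = matcher([img1, img2], profile=True)
#   seg1, seg2 = out["line_segments"]
#   matches = out["matches"]  # per-line match indices from the NW matcher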