"""
Implementation of the line segment detection module.
"""
|
import math |
|
import numpy as np |
|
import torch |
|
|
|
|
|
class LineSegmentDetectionModule(object):
    """Module extracting line segments from junctions and line heatmaps.

    Junctions are handled as ``(N, 2)`` tensors of ``(row, col)`` coordinates:
    column 0 is sampled against the heatmap height and column 1 against its
    width. A detected set of segments is represented as an ``(N, N)``
    symmetric adjacency matrix ("line map") over the junctions.
    """

    # With the "local_max" sampler, candidate segments are scored in groups of
    # at most this many (presumably to bound peak memory of the patch search).
    _LOCAL_MAX_GROUP_SIZE = 10000

    def __init__(
        self,
        detect_thresh,
        num_samples=64,
        sampling_method="local_max",
        inlier_thresh=0.0,
        heatmap_low_thresh=0.15,
        heatmap_high_thresh=0.2,
        max_local_patch_radius=3,
        lambda_radius=2.0,
        use_candidate_suppression=False,
        nms_dist_tolerance=3.0,
        use_heatmap_refinement=False,
        heatmap_refine_cfg=None,
        use_junction_refinement=False,
        junction_refine_cfg=None,
    ):
        """
        Parameters:
            detect_thresh: The probability threshold for mean activation (0. ~ 1.)
            num_samples: Number of sampling locations along the line segments.
            sampling_method: Sampling method on locations ("bilinear" or "local_max").
            inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
            heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery.
            heatmap_high_thresh: The higher threshold for NMS in junction recovery.
            max_local_patch_radius: The max patch to be considered in local maximum search.
            lambda_radius: The lambda factor in linear local maximum search formulation
            use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments.
            nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line.
            use_heatmap_refinement: Use heatmap refinement method or not.
            heatmap_refine_cfg: The configs for heatmap refinement methods
                (keys read by ``detect``: "mode", "ratio", "valid_thresh",
                plus "num_blocks" and "overlap_ratio" for mode "local").
            use_junction_refinement: Use junction refinement method or not.
            junction_refine_cfg: The configs for junction refinement methods
                (keys read: "num_perturbs", "perturb_interval").

        Raises:
            ValueError: If a refinement is enabled but its config is None.
        """
        # Line detection parameters.
        self.detect_thresh = detect_thresh

        # Line sampling parameters.
        self.num_samples = num_samples
        self.sampling_method = sampling_method
        self.inlier_thresh = inlier_thresh
        self.local_patch_radius = max_local_patch_radius
        self.lambda_radius = lambda_radius

        # Heatmap thresholds (stored only; not read by any method of this class).
        self.low_thresh = heatmap_low_thresh
        self.high_thresh = heatmap_high_thresh

        # Interpolation weights in [0, 1] used to place samples along a segment.
        self.sampler = np.linspace(0, 1, self.num_samples)
        self.torch_sampler = torch.linspace(0, 1, self.num_samples)

        # Long line segment suppression configuration.
        self.use_candidate_suppression = use_candidate_suppression
        self.nms_dist_tolerance = nms_dist_tolerance

        # Heatmap refinement configuration.
        self.use_heatmap_refinement = use_heatmap_refinement
        self.heatmap_refine_cfg = heatmap_refine_cfg
        if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
            raise ValueError("[Error] Missing heatmap refinement config.")

        # Junction refinement configuration.
        self.use_junction_refinement = use_junction_refinement
        self.junction_refine_cfg = junction_refine_cfg
        if self.use_junction_refinement and self.junction_refine_cfg is None:
            raise ValueError("[Error] Missing junction refinement config.")

    def convert_inputs(self, inputs, device):
        """Convert inputs to a float32 torch tensor on ``device``.

        Raises:
            ValueError: If ``inputs`` is neither a numpy ndarray nor a torch tensor.
        """
        if isinstance(inputs, np.ndarray):
            return torch.tensor(inputs, dtype=torch.float32, device=device)
        if isinstance(inputs, torch.Tensor):
            return inputs.to(device=device, dtype=torch.float32)
        raise ValueError(
            "[Error] Inputs must either be torch tensor or numpy ndarray."
        )

    def _sample_along_segments(self, junc_start, junc_end, sampler, H, W):
        """Return clamped (row, col) sample coordinates along each segment.

        ``sampler`` has shape (1, S); both outputs have shape (N, S). Weight 0
        maps to ``junc_end`` and weight 1 to ``junc_start``.
        """
        cand_h = junc_start[:, 0:1] * sampler + junc_end[:, 0:1] * (1 - sampler)
        cand_w = junc_start[:, 1:2] * sampler + junc_end[:, 1:2] * (1 - sampler)
        # Keep every sample inside the image.
        cand_h = torch.clamp(cand_h, min=0, max=H - 1)
        cand_w = torch.clamp(cand_w, min=0, max=W - 1)
        return cand_h, cand_w

    def detect(self, junctions, heatmap, device=torch.device("cpu")):
        """Main function performing line segment detection.

        Parameters:
            junctions: (N, 2) ndarray / tensor of (row, col) junction locations.
            heatmap: Line heatmap (ndarray / tensor); its first two dimensions
                are taken as height and width.
            device: Torch device on which the computation runs.

        Returns:
            line_map_pred: (N, N) symmetric int32 adjacency matrix of detected
                segments. When junction refinement runs, it is instead the
                float matrix built by ``segments_to_line_map`` over the
                refined junctions.
            junctions: float32 junction tensor (replaced by the refined,
                de-duplicated endpoints when junction refinement runs).
            heatmap: float32 heatmap actually used (refined if enabled).

        Raises:
            ValueError: If the inputs have an unsupported type or
                ``self.sampling_method`` is unknown.
        """
        # Convert inputs to torch tensors.
        junctions = self.convert_inputs(junctions, device=device)
        heatmap = self.convert_inputs(heatmap, device=device)

        # Optionally rescale the heatmap.
        if self.use_heatmap_refinement:
            cfg = self.heatmap_refine_cfg
            if cfg["mode"] == "global":
                heatmap = self.refine_heatmap(
                    heatmap, cfg["ratio"], cfg["valid_thresh"]
                )
            elif cfg["mode"] == "local":
                heatmap = self.refine_heatmap_local(
                    heatmap,
                    cfg["num_blocks"],
                    cfg["overlap_ratio"],
                    cfg["ratio"],
                    cfg["valid_thresh"],
                )

        # Initialize the empty line map.
        num_junctions = junctions.shape[0]
        line_map_pred = torch.zeros(
            [num_junctions, num_junctions], device=device, dtype=torch.int32
        )

        # At least two junctions are needed to form a segment.
        if num_junctions < 2:
            return line_map_pred, junctions, heatmap

        # Every unordered junction pair (strict upper triangle) is a candidate.
        candidate_map = torch.triu(
            torch.ones(
                [num_junctions, num_junctions], device=device, dtype=torch.int32
            ),
            diagonal=1,
        )

        H, W = heatmap.shape[0], heatmap.shape[1]

        # Optionally drop candidates that pass over another junction.
        if self.use_candidate_suppression:
            candidate_map = self.candidate_suppression(junctions, candidate_map)

        # (K, 2) index pairs of the remaining candidates, in row-major order.
        candidate_index_map = torch.nonzero(candidate_map)
        candidate_junc_start = junctions[candidate_index_map[:, 0], :]
        candidate_junc_end = junctions[candidate_index_map[:, 1], :]

        # Sampling locations along every candidate segment: (K, num_samples).
        sampler = self.torch_sampler.to(device)[None, ...]
        cand_h, cand_w = self._sample_along_segments(
            candidate_junc_start, candidate_junc_end, sampler, H, W
        )

        if self.sampling_method == "local_max":
            # Segment length normalized by the image diagonal.
            segments_length = torch.sqrt(
                torch.sum((candidate_junc_start - candidate_junc_end) ** 2, dim=-1)
            )
            normalized_seg_length = segments_length / (((H**2) + (W**2)) ** 0.5)

            num_cand = cand_h.shape[0]
            group_size = self._LOCAL_MAX_GROUP_SIZE
            if num_cand > group_size:
                # Slicing past the end is safe, so the last (shorter) group
                # needs no special case.
                sampled_feat = torch.cat(
                    [
                        self.detect_local_max(
                            heatmap,
                            cand_h[start : start + group_size, :],
                            cand_w[start : start + group_size, :],
                            H,
                            W,
                            normalized_seg_length[start : start + group_size],
                            device,
                        )
                        for start in range(0, num_cand, group_size)
                    ],
                    dim=0,
                )
            else:
                sampled_feat = self.detect_local_max(
                    heatmap, cand_h, cand_w, H, W, normalized_seg_length, device
                )
        elif self.sampling_method == "bilinear":
            sampled_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device)
        else:
            raise ValueError("[Error] Unknown sampling method.")

        # A candidate is kept when its mean activation exceeds the threshold...
        detection_results = torch.mean(sampled_feat, dim=-1) > self.detect_thresh

        # ...and, optionally, when enough individual samples exceed it too.
        if self.inlier_thresh > 0.0:
            inlier_ratio = (
                torch.sum(sampled_feat > self.detect_thresh, dim=-1).to(torch.float32)
                / self.num_samples
            )
            detection_results = detection_results & (
                inlier_ratio >= self.inlier_thresh
            )

        # Write the detections symmetrically into the line map.
        detected_junc_indexes = candidate_index_map[detection_results, :]
        line_map_pred[detected_junc_indexes[:, 0], detected_junc_indexes[:, 1]] = 1
        line_map_pred[detected_junc_indexes[:, 1], detected_junc_indexes[:, 0]] = 1

        # Optionally refine the endpoints of the detected segments.
        if self.use_junction_refinement and len(detected_junc_indexes) > 0:
            junctions, line_map_pred = self.refine_junction_perturb(
                junctions, line_map_pred, heatmap, H, W, device
            )

        return line_map_pred, junctions, heatmap

    def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
        """Global heatmap refinement method.

        Divides the heatmap by the mean of the top ``ratio`` fraction of the
        values above ``valid_thresh`` and clamps the result to [0, 1].

        If no value exceeds ``valid_thresh`` there is nothing to normalize
        against (the mean of an empty selection would be NaN), so the input
        is returned merely clamped to [0, 1].
        """
        valid_values = heatmap[heatmap > valid_thresh]
        if valid_values.numel() == 0:
            return torch.clamp(heatmap, min=0.0, max=1.0)

        sorted_values = torch.sort(valid_values, descending=True)[0]
        top_len = math.ceil(sorted_values.shape[0] * ratio)
        top_mean = torch.mean(sorted_values[:top_len])
        return torch.clamp(heatmap / top_mean, min=0.0, max=1.0)

    def refine_heatmap_local(
        self, heatmap, num_blocks=5, overlap_ratio=0.5, ratio=0.2, valid_thresh=2e-3
    ):
        """Local heatmap refinement method.

        Applies ``refine_heatmap`` on a ``num_blocks`` x ``num_blocks`` grid of
        overlapping blocks and averages the overlapping results. Blocks whose
        maximum does not exceed ``valid_thresh`` are accumulated unchanged.
        Expects a 2D heatmap.
        """
        H, W = heatmap.shape
        increase_ratio = 1 - overlap_ratio
        h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
        w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))

        # Per-pixel number of blocks that contributed, and their summed values.
        count_map = torch.zeros(heatmap.shape, dtype=torch.int, device=heatmap.device)
        heatmap_output = torch.zeros(
            heatmap.shape, dtype=torch.float, device=heatmap.device
        )

        for h_idx in range(num_blocks):
            h_start = round(h_idx * h_block * increase_ratio)
            # The last block of each axis always extends to the image border.
            h_end = h_start + h_block if h_idx < num_blocks - 1 else H
            for w_idx in range(num_blocks):
                w_start = round(w_idx * w_block * increase_ratio)
                w_end = w_start + w_block if w_idx < num_blocks - 1 else W

                subheatmap = heatmap[h_start:h_end, w_start:w_end]
                # A block can be empty when the image is smaller than the
                # block grid; ``max()`` of an empty tensor would raise.
                if subheatmap.numel() == 0:
                    continue
                if subheatmap.max() > valid_thresh:
                    subheatmap = self.refine_heatmap(
                        subheatmap, ratio, valid_thresh=valid_thresh
                    )

                heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
                count_map[h_start:h_end, w_start:w_end] += 1

        return torch.clamp(heatmap_output / count_map, max=1.0, min=0.0)

    def candidate_suppression(self, junctions, candidate_map):
        """Suppress overlapping long lines in the candidate segments.

        A candidate segment is removed when any junction other than its two
        endpoints projects inside the segment and lies within
        ``self.nms_dist_tolerance`` of its supporting line.

        ``candidate_map`` is modified in place and also returned.
        """
        dist_tolerance = self.nms_dist_tolerance

        # Endpoints of every candidate segment (strict upper triangle only).
        start_point_idxs, end_point_idxs = torch.nonzero(
            torch.triu(candidate_map, diagonal=1), as_tuple=True
        )
        num_segs = start_point_idxs.shape[0]
        if num_segs == 0:
            return candidate_map

        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        # Segment lengths and unit direction vectors.
        seg_vecs = end_points - start_points
        line_dists = torch.norm(seg_vecs, dim=-1)
        dir_vecs = seg_vecs / line_dists[..., None]

        # Vector from each segment start to every junction: (S, J, 2).
        cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)
        cand_vecs_norm = torch.norm(cand_vecs, dim=-1)

        # Signed length of the projection onto the segment direction: (S, J).
        dots = torch.sum(cand_vecs * dir_vecs[:, None, :], dim=-1)

        # Projection must fall inside the segment (as a fraction of its length).
        proj = dots / line_dists[..., None]
        proj_mask = (proj >= 0) & (proj <= 1)

        # Perpendicular distance to the supporting line. Rounding can push the
        # cosine of an (almost) collinear junction just outside [-1, 1], where
        # acos returns NaN and the junction would be silently missed.
        cos_angles = torch.clamp(dots / cand_vecs_norm, min=-1.0, max=1.0)
        cand_dists = cand_vecs_norm * torch.sin(torch.acos(cos_angles))
        junc_mask = (cand_dists <= dist_tolerance) & proj_mask

        # A segment's own endpoints never count against it.
        seg_range = torch.arange(num_segs, device=junc_mask.device)
        junc_mask[seg_range, start_point_idxs] = False
        junc_mask[seg_range, end_point_idxs] = False

        # Drop every candidate that passes over some other junction.
        final_mask = junc_mask.any(dim=1)
        candidate_map[start_point_idxs[final_mask], end_point_idxs[final_mask]] = 0

        return candidate_map

    def refine_junction_perturb(self, junctions, line_map_pred, heatmap, H, W, device):
        """Refine the line endpoints in a similar way as in LSD.

        Each detected segment is replaced by the best-scoring variant among all
        combinations of endpoint perturbations, scored by the mean bilinear
        heatmap response along the segment.

        Returns:
            junctions_new: Unique endpoints of the refined segments
                (``torch.unique`` over rows).
            line_map_new: Float adjacency matrix over ``junctions_new`` built
                by ``segments_to_line_map``.
            With no segment in ``line_map_pred``, an empty (0, 2) junction
            tensor and an empty (0, 0) line map are returned.
        """
        junction_refine_cfg = self.junction_refine_cfg

        # Symmetric perturbation offsets. They are built from an integer range
        # and then scaled, because ``torch.arange`` with a non-integer step can
        # yield one element too many through floating point rounding.
        num_perturbs = junction_refine_cfg["num_perturbs"]
        perturb_interval = junction_refine_cfg["perturb_interval"]
        side_perturbs = (num_perturbs - 1) // 2
        perturb_vec = (
            torch.arange(-side_perturbs, side_perturbs + 1, device=device)
            * perturb_interval
        )

        # Every combination of offsets for the four endpoint coordinates,
        # arranged as (num_combinations, endpoint, coordinate).
        perturb_tensor_flat = torch.cartesian_prod(
            perturb_vec, perturb_vec, perturb_vec, perturb_vec
        ).reshape(-1, 2, 2)

        # Endpoints of the detected segments (strict upper triangle only).
        start_point_idxs, end_point_idxs = torch.nonzero(
            torch.triu(line_map_pred, diagonal=1), as_tuple=True
        )
        num_segments = start_point_idxs.shape[0]
        if num_segments == 0:
            # Nothing to refine; ``torch.cat`` below would raise on an empty list.
            return junctions[:0].clone(), torch.zeros([0, 0], device=junctions.device)

        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]
        line_segments = torch.cat(
            [start_points.unsqueeze(dim=1), end_points.unsqueeze(dim=1)], dim=1
        )

        # (num_segments, num_combinations, 2, 2) perturbed segments, kept
        # inside the image.
        line_segment_candidates = (
            line_segments.unsqueeze(dim=1) + perturb_tensor_flat[None, ...]
        )
        line_segment_candidates[..., 0] = torch.clamp(
            line_segment_candidates[..., 0], min=0, max=H - 1
        )
        line_segment_candidates[..., 1] = torch.clamp(
            line_segment_candidates[..., 1], min=0, max=W - 1
        )

        # Score the variants one segment at a time.
        sampler = self.torch_sampler.to(device)[None, ...]
        refined_segment_lst = []
        for idx in range(num_segments):
            segment = line_segment_candidates[idx, ...]
            cand_h, cand_w = self._sample_along_segments(
                segment[:, 0, :], segment[:, 1, :], sampler, H, W
            )
            segment_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device)
            segment_results = torch.mean(segment_feat, dim=-1)
            max_idx = torch.argmax(segment_results)
            refined_segment_lst.append(segment[max_idx, ...][None, ...])

        refined_segments = torch.cat(refined_segment_lst, dim=0)

        # Rebuild the junction list and the line map from the refined segments.
        junctions_new = torch.cat(
            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0
        )
        junctions_new = torch.unique(junctions_new, dim=0)
        line_map_new = self.segments_to_line_map(junctions_new, refined_segments)

        return junctions_new, line_map_new

    def segments_to_line_map(self, junctions, segments):
        """Convert the list of segments to line map.

        Parameters:
            junctions: (N, 2) junction tensor.
            segments: (S, 2, 2) tensor of segments given by their two endpoints.

        Returns:
            (N, N) symmetric tensor (default float dtype) holding 1 for every
            pair of junctions whose coordinates exactly equal a segment's
            endpoints.
        """
        num_junctions = junctions.shape[0]
        line_map = torch.zeros(
            [num_junctions, num_junctions], device=junctions.device
        )

        for seg_idx in range(segments.shape[0]):
            junction1 = segments[seg_idx, 0, :]
            junction2 = segments[seg_idx, 1, :]

            # Junction rows whose two coordinates both match the endpoint.
            idx_junction1 = torch.where((junctions == junction1).sum(dim=1) == 2)[0]
            idx_junction2 = torch.where((junctions == junction2).sum(dim=1) == 2)[0]

            line_map[idx_junction1, idx_junction2] = 1
            line_map[idx_junction2, idx_junction1] = 1

        return line_map

    def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
        """Detection by bilinear sampling.

        Parameters:
            heatmap: Heatmap indexed as ``heatmap[row, col]``.
            cand_h, cand_w: Row / column sample coordinates of identical shape,
                expected to lie in [0, H - 1] and [0, W - 1].
            H, W: Heatmap height and width.
            device: Unused; kept for interface compatibility.

        Returns:
            Bilinearly interpolated heatmap values, same shape as ``cand_h``.
        """
        h_floor = torch.floor(cand_h)
        w_floor = torch.floor(cand_w)

        # Fractional offsets inside the pixel cell act as interpolation weights.
        h_frac = cand_h - h_floor
        w_frac = cand_w - w_floor

        # The far corner is floor + 1, not ceil(): on an exact integer
        # coordinate ceil == floor, which would zero out all four weights and
        # therefore the sample itself. Clamping keeps the corner in the image;
        # its weight is 0 whenever the clamp is active.
        h0 = h_floor.to(torch.long)
        w0 = w_floor.to(torch.long)
        h1 = torch.clamp(h0 + 1, max=H - 1)
        w1 = torch.clamp(w0 + 1, max=W - 1)

        cand_samples_feat = (
            heatmap[h0, w0] * (1 - h_frac) * (1 - w_frac)
            + heatmap[h0, w1] * (1 - h_frac) * w_frac
            + heatmap[h1, w0] * h_frac * (1 - w_frac)
            + heatmap[h1, w1] * h_frac * w_frac
        )

        return cand_samples_feat

    def detect_local_max(
        self, heatmap, cand_h, cand_w, H, W, normalized_seg_length, device
    ):
        """Detection by local maximum search.

        For every sample, takes the maximum heatmap value among the pixels of
        a disc-shaped patch (radius ``self.local_patch_radius``) around the
        rounded sample that also lie closer to the exact sample than
        ``sqrt(2) / 2 + self.lambda_radius * normalized_seg_length``.

        Parameters:
            heatmap: Heatmap indexed as ``heatmap[row, col]``.
            cand_h, cand_w: (N, S) row / column sample coordinates.
            H, W: Heatmap height and width.
            normalized_seg_length: (N,) segment lengths divided by the image
                diagonal.
            device: Device on which the patch offsets are created.

        Returns:
            (N, S) tensor of local maxima (an empty (0, S) tensor when N == 0).
        """
        # No candidates: return an empty result with the right sample count,
        # on the heatmap's device.
        if cand_h.shape[0] == 0:
            return heatmap.new_empty((0, cand_h.shape[1]))

        # Per-segment search radius; longer segments get a larger one.
        dist_thresh = 0.5 * (2**0.5) + self.lambda_radius * normalized_seg_length

        # Exact sample positions: (N, S, 2).
        cand_points = torch.stack([cand_h, cand_w], dim=-1)

        # Integer offsets of a disc-shaped patch centered on (0, 0): (P, 2).
        radius = self.local_patch_radius
        side = int(2 * radius + 1)
        patch_points = torch.nonzero(torch.ones([side, side], device=device))
        patch_center = torch.tensor(
            [[radius, radius]], device=device, dtype=torch.float32
        )
        patch_center_dist = torch.sqrt(
            torch.sum((patch_points - patch_center) ** 2, dim=-1)
        )
        patch_points = patch_points[patch_center_dist <= radius, :] - radius

        # Patch pixels around every rounded sample: (N, S, P, 2).
        patch_points_shifted = (
            torch.round(cand_points)[:, :, None, :] + patch_points[None, None, ...]
        )

        # Only pixels close enough to the exact sample position take part.
        patch_dist = torch.sqrt(
            torch.sum((cand_points[:, :, None, :] - patch_points_shifted) ** 2, dim=-1)
        )
        patch_dist_mask = patch_dist < dist_thresh[:, None, None]

        # Clamp the pixel coordinates to the image before indexing.
        points_H = torch.clamp(patch_points_shifted[..., 0], min=0, max=H - 1).to(
            torch.long
        )
        points_W = torch.clamp(patch_points_shifted[..., 1], min=0, max=W - 1).to(
            torch.long
        )

        # Excluded pixels contribute 0 to the maximum.
        sampled_feat = heatmap[points_H, points_W] * patch_dist_mask.to(torch.float32)
        sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1)

        return sampled_feat_lmax
|
|