"""
Implementation of the line segment detection module.
"""
|
import math

import numpy as np
import torch


class LineSegmentDetectionModule(object):
    """ Module extracting line segments from junctions and line heatmaps.

    Every pair of junctions is treated as a candidate segment. Points are
    sampled along each candidate, the heatmap is read at those points (by
    bilinear interpolation or by a local-maximum search) and the candidate is
    kept when its mean response exceeds ``detect_thresh``.

    Junctions are handled as (N, 2) arrays whose columns are indexed as
    (h, w), i.e. (row, column) of the heatmap.
    """

    # Maximum number of candidates passed to detect_local_max in one call;
    # larger candidate sets are processed chunk by chunk (presumably to bound
    # peak memory of the per-sample patch tensors).
    _LOCAL_MAX_GROUP_SIZE = 10000

    def __init__(
            self, detect_thresh, num_samples=64, sampling_method="local_max",
            inlier_thresh=0., heatmap_low_thresh=0.15, heatmap_high_thresh=0.2,
            max_local_patch_radius=3, lambda_radius=2.,
            use_candidate_suppression=False, nms_dist_tolerance=3.,
            use_heatmap_refinement=False, heatmap_refine_cfg=None,
            use_junction_refinement=False, junction_refine_cfg=None):
        """
        Parameters:
            detect_thresh: The probability threshold for mean activation
                (0. ~ 1.)
            num_samples: Number of sampling locations along the line
                segments.
            sampling_method: Sampling method on locations ("bilinear" or
                "local_max").
            inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.)
                => 0. means no threshold.
            heatmap_low_thresh: The lowest threshold for the pixel to be
                considered as candidate in junction recovery. Stored as
                ``low_thresh``; not read by any method of this class.
            heatmap_high_thresh: The higher threshold for NMS in junction
                recovery. Stored as ``high_thresh``; not read by any method
                of this class.
            max_local_patch_radius: The max patch to be considered in local
                maximum search.
            lambda_radius: The lambda factor in linear local maximum search
                formulation.
            use_candidate_suppression: Apply candidate suppression to break
                long segments into short sub-segments.
            nms_dist_tolerance: The distance tolerance for nms. Decide
                whether the junctions are on the line.
            use_heatmap_refinement: Use heatmap refinement method or not.
            heatmap_refine_cfg: The configs for heatmap refinement methods
                (keys read: "mode", "ratio", "valid_thresh" and, in "local"
                mode, "num_blocks" and "overlap_ratio").
            use_junction_refinement: Use junction refinement method or not.
            junction_refine_cfg: The configs for junction refinement methods
                (keys read: "num_perturbs", "perturb_interval").

        Raises:
            ValueError: If a refinement is enabled without its config.
        """
        # Line detection parameters
        self.detect_thresh = detect_thresh

        # Line sampling parameters
        self.num_samples = num_samples
        self.sampling_method = sampling_method
        self.inlier_thresh = inlier_thresh
        self.local_patch_radius = max_local_patch_radius
        self.lambda_radius = lambda_radius

        # Heatmap thresholds for junction recovery
        self.low_thresh = heatmap_low_thresh
        self.high_thresh = heatmap_high_thresh

        # Pre-computed interpolation weights in [0, 1]
        self.sampler = np.linspace(0, 1, self.num_samples)
        self.torch_sampler = torch.linspace(0, 1, self.num_samples)

        # Long line segment suppression configuration
        self.use_candidate_suppression = use_candidate_suppression
        self.nms_dist_tolerance = nms_dist_tolerance

        # Heatmap refinement configuration
        self.use_heatmap_refinement = use_heatmap_refinement
        self.heatmap_refine_cfg = heatmap_refine_cfg
        if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
            raise ValueError("[Error] Missing heatmap refinement config.")

        # Junction refinement configuration
        self.use_junction_refinement = use_junction_refinement
        self.junction_refine_cfg = junction_refine_cfg
        if self.use_junction_refinement and self.junction_refine_cfg is None:
            raise ValueError("[Error] Missing junction refinement config.")

    def convert_inputs(self, inputs, device):
        """ Convert inputs to a float32 torch tensor on ``device``.

        Raises:
            ValueError: If ``inputs`` is neither a numpy array nor a tensor.
        """
        if isinstance(inputs, np.ndarray):
            return torch.tensor(inputs, dtype=torch.float32, device=device)
        if isinstance(inputs, torch.Tensor):
            return inputs.to(device=device, dtype=torch.float32)
        raise ValueError(
            "[Error] Inputs must either be torch tensor or numpy ndarray.")

    def detect(self, junctions, heatmap, device=torch.device("cpu")):
        """ Main function performing line segment detection.

        Parameters:
            junctions: (N, 2) numpy array or torch tensor of (h, w) points.
            heatmap: Line heatmap (numpy array or torch tensor); its first
                two dimensions are taken as (H, W).
            device: Device on which the computation is carried out.

        Returns:
            A tuple ``(line_map_pred, junctions, heatmap)``: the symmetric
            junction adjacency matrix (1 where a segment was detected), the
            junctions as a float32 tensor (replaced by the refined junctions
            when junction refinement runs) and the (possibly refined)
            heatmap as a float32 tensor.

        Raises:
            ValueError: On unsupported input types or sampling method.
        """
        junctions = self.convert_inputs(junctions, device=device)
        heatmap = self.convert_inputs(heatmap, device=device)

        # Optional heatmap refinement
        if self.use_heatmap_refinement:
            cfg = self.heatmap_refine_cfg
            if cfg["mode"] == "global":
                heatmap = self.refine_heatmap(
                    heatmap, cfg["ratio"], cfg["valid_thresh"])
            elif cfg["mode"] == "local":
                heatmap = self.refine_heatmap_local(
                    heatmap, cfg["num_blocks"], cfg["overlap_ratio"],
                    cfg["ratio"], cfg["valid_thresh"])

        # Initialize an empty line map
        num_junctions = junctions.shape[0]
        line_map_pred = torch.zeros([num_junctions, num_junctions],
                                    device=device, dtype=torch.int32)

        # No segment can be formed with fewer than two junctions
        if num_junctions < 2:
            return line_map_pred, junctions, heatmap

        # Every unordered junction pair (upper triangle) is a candidate
        candidate_map = torch.triu(torch.ones(
            [num_junctions, num_junctions], device=device, dtype=torch.int32),
            diagonal=1)

        # Only the two leading dimensions are spatial
        H, W = heatmap.shape[:2]

        # Optionally drop long candidates passing over other junctions
        if self.use_candidate_suppression:
            candidate_map = self.candidate_suppression(junctions,
                                                       candidate_map)

        # (num_candidates, 2) table of (start, end) junction indices
        candidate_index_map = torch.stack(torch.where(candidate_map), dim=-1)
        candidate_junc_start = junctions[candidate_index_map[:, 0], :]
        candidate_junc_end = junctions[candidate_index_map[:, 1], :]

        # Sample points along every candidate segment
        cand_h, cand_w = self._sample_segment_points(
            candidate_junc_start, candidate_junc_end, H, W, device)

        if self.sampling_method == "local_max":
            # Segment length normalized by the heatmap diagonal
            segments_length = torch.sqrt(torch.sum(
                (candidate_junc_start.to(torch.float32) -
                 candidate_junc_end.to(torch.float32)) ** 2, dim=-1))
            normalized_seg_length = (segments_length
                                     / (((H ** 2) + (W ** 2)) ** 0.5))

            num_cand = cand_h.shape[0]
            group_size = self._LOCAL_MAX_GROUP_SIZE
            if num_cand > group_size:
                # Process the candidates chunk by chunk
                sampled_feat = torch.cat([
                    self.detect_local_max(
                        heatmap,
                        cand_h[begin:begin + group_size, :],
                        cand_w[begin:begin + group_size, :],
                        H, W,
                        normalized_seg_length[begin:begin + group_size],
                        device)
                    for begin in range(0, num_cand, group_size)], dim=0)
            else:
                sampled_feat = self.detect_local_max(
                    heatmap, cand_h, cand_w, H, W,
                    normalized_seg_length, device)
        elif self.sampling_method == "bilinear":
            sampled_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
        else:
            raise ValueError("[Error] Unknown sampling method.")

        # A candidate is accepted when its mean response is high enough
        detection_results = (torch.mean(sampled_feat, dim=-1)
                             > self.detect_thresh)

        # ... and, if requested, when enough individual samples are inliers
        if self.inlier_thresh > 0.:
            inlier_ratio = torch.sum(
                sampled_feat > self.detect_thresh,
                dim=-1).to(torch.float32) / self.num_samples
            detection_results = (detection_results
                                 & (inlier_ratio >= self.inlier_thresh))

        # Write the detections symmetrically into the line map
        detected_junc_indexes = candidate_index_map[detection_results, :]
        line_map_pred[detected_junc_indexes[:, 0],
                      detected_junc_indexes[:, 1]] = 1
        line_map_pred[detected_junc_indexes[:, 1],
                      detected_junc_indexes[:, 0]] = 1

        # Optional endpoint refinement
        if self.use_junction_refinement and len(detected_junc_indexes) > 0:
            junctions, line_map_pred = self.refine_junction_perturb(
                junctions, line_map_pred, heatmap, H, W, device)

        return line_map_pred, junctions, heatmap

    def _sample_segment_points(self, junc_start, junc_end, H, W, device):
        """ Sample points between endpoint pairs, clamped to the heatmap.

        Parameters:
            junc_start, junc_end: (N, 2) tensors of (h, w) endpoints.
            H, W: Heatmap size used to clamp the sampled coordinates.
            device: Device the interpolation weights are moved to.

        Returns:
            ``(cand_h, cand_w)``, each of shape (N, num_samples).
        """
        sampler = self.torch_sampler.to(device)[None, ...]
        samples_h = (junc_start[:, 0:1] * sampler
                     + junc_end[:, 0:1] * (1 - sampler))
        samples_w = (junc_start[:, 1:2] * sampler
                     + junc_end[:, 1:2] * (1 - sampler))
        cand_h = torch.clamp(samples_h, min=0, max=H - 1)
        cand_w = torch.clamp(samples_w, min=0, max=W - 1)
        return cand_h, cand_w

    def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
        """ Global heatmap refinement method.

        The heatmap is divided by the mean of the top ``ratio`` fraction of
        the values above ``valid_thresh`` and clamped to [0, 1].
        """
        valid_values = heatmap[heatmap > valid_thresh]
        # Without any valid pixel the normalizer would be the mean of an
        # empty tensor (NaN) and would turn the whole heatmap into NaN.
        if valid_values.numel() == 0:
            return torch.clamp(heatmap, min=0., max=1.)

        sorted_values = torch.sort(valid_values, descending=True)[0]
        # Keep at least one value so that the mean is always defined
        top_len = max(1, math.ceil(sorted_values.shape[0] * ratio))
        top_mean = torch.mean(sorted_values[:top_len])
        return torch.clamp(heatmap / top_mean, min=0., max=1.)

    def refine_heatmap_local(self, heatmap, num_blocks=5, overlap_ratio=0.5,
                             ratio=0.2, valid_thresh=2e-3):
        """ Local heatmap refinement method.

        The (H, W) heatmap is covered by ``num_blocks`` x ``num_blocks``
        overlapping blocks; each block whose maximum exceeds ``valid_thresh``
        is refined with ``refine_heatmap`` and the overlapping results are
        averaged.
        """
        H, W = heatmap.shape
        increase_ratio = 1 - overlap_ratio
        h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
        w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))

        count_map = torch.zeros(heatmap.shape, dtype=torch.int,
                                device=heatmap.device)
        heatmap_output = torch.zeros(heatmap.shape, dtype=torch.float,
                                     device=heatmap.device)

        for h_idx in range(num_blocks):
            for w_idx in range(num_blocks):
                # The last block of a row/column extends to the border
                h_start = round(h_idx * h_block * increase_ratio)
                w_start = round(w_idx * w_block * increase_ratio)
                h_end = h_start + h_block if h_idx < num_blocks - 1 else H
                w_end = w_start + w_block if w_idx < num_blocks - 1 else W

                subheatmap = heatmap[h_start:h_end, w_start:w_end]
                if subheatmap.max() > valid_thresh:
                    subheatmap = self.refine_heatmap(
                        subheatmap, ratio, valid_thresh=valid_thresh)

                # Accumulate the block and how often each pixel was covered
                heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
                count_map[h_start:h_end, w_start:w_end] += 1

        # Average the overlaps; the clamp avoids 0 / 0 should a pixel be
        # covered by no block.
        heatmap_output = torch.clamp(
            heatmap_output / torch.clamp(count_map, min=1), max=1., min=0.)

        return heatmap_output

    def candidate_suppression(self, junctions, candidate_map):
        """ Suppress overlapping long lines in the candidate segments.

        A candidate is removed when a junction other than its own endpoints
        projects onto it and lies within ``nms_dist_tolerance`` of it.
        ``candidate_map`` is modified in place and returned.
        """
        dist_tolerance = self.nms_dist_tolerance

        # Endpoints of every candidate segment
        seg_indexes = torch.where(torch.triu(candidate_map, diagonal=1))
        start_point_idxs = seg_indexes[0]
        end_point_idxs = seg_indexes[1]
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        # Segment lengths and unit direction vectors
        seg_vecs = end_points - start_points
        line_dists = torch.sum(seg_vecs ** 2, dim=-1) ** 0.5
        dir_vecs = seg_vecs / torch.norm(seg_vecs, dim=-1)[..., None]

        # Vectors from each segment start to every junction: (S, J, 2)
        cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)
        cand_vecs_norm = torch.norm(cand_vecs, dim=-1)

        # Dot product of those vectors with the segment direction: (S, J, 1)
        dots = torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])

        # Relative position of the projection along the segment
        proj = dots / line_dists[..., None, None]
        proj_mask = (proj >= 0) & (proj <= 1)

        # Rounding can push the cosine marginally outside [-1, 1], which
        # makes acos return NaN and hides junctions lying exactly on the
        # line; clamp first. (The 0 / 0 of the start point itself stays NaN
        # and is therefore never counted.)
        cos_angles = torch.clamp(dots / cand_vecs_norm[..., None],
                                 min=-1., max=1.)
        cand_angles = torch.acos(cos_angles)
        cand_dists = cand_vecs_norm[..., None] * torch.sin(cand_angles)
        junc_mask = (cand_dists <= dist_tolerance) & proj_mask

        # Count the junctions on each segment, excluding its own endpoints
        num_segs = start_point_idxs.shape[0]
        seg_range = torch.arange(0, num_segs, device=junctions.device)
        junc_counts = torch.sum(junc_mask, dim=[1, 2])
        junc_counts -= junc_mask[..., 0][seg_range,
                                         start_point_idxs].to(torch.int)
        junc_counts -= junc_mask[..., 0][seg_range,
                                         end_point_idxs].to(torch.int)

        # Remove the segments passing over at least one other junction
        final_mask = junc_counts > 0
        candidate_map[start_point_idxs[final_mask],
                      end_point_idxs[final_mask]] = 0

        return candidate_map

    def refine_junction_perturb(self, junctions, line_map_pred,
                                heatmap, H, W, device):
        """ Refine the line endpoints in a similar way as in LSD.

        Both endpoints of every detected segment are shifted over a grid of
        offsets and the variant with the highest mean bilinear heatmap
        response is kept.

        Returns:
            ``(junctions_new, line_map_new)``: the unique refined endpoints
            and the matching adjacency matrix, which has the dtype of
            ``line_map_pred``. The inputs are returned unchanged when
            ``line_map_pred`` contains no segment.
        """
        junction_refine_cfg = self.junction_refine_cfg

        # Build the offset grid
        num_perturbs = junction_refine_cfg["num_perturbs"]
        perturb_interval = junction_refine_cfg["perturb_interval"]
        side_perturbs = (num_perturbs - 1) // 2

        # Integer steps scaled by the interval: a fractional step given
        # directly to arange can produce one extra element by rounding.
        perturb_vec = torch.arange(
            -side_perturbs, side_perturbs + 1,
            device=device, dtype=torch.float32) * perturb_interval
        # All combinations of offsets for the four endpoint coordinates,
        # arranged as (num_combinations, endpoint, coordinate)
        perturb_tensor_flat = torch.cartesian_prod(
            perturb_vec, perturb_vec, perturb_vec, perturb_vec
        ).reshape(-1, 2, 2)

        # Endpoints of the detected segments (upper triangle only)
        detected_seg_indexes = torch.where(
            torch.triu(line_map_pred, diagonal=1))
        start_points = junctions[detected_seg_indexes[0], :]
        end_points = junctions[detected_seg_indexes[1], :]
        line_segments = torch.stack([start_points, end_points], dim=1)

        num_segments = line_segments.shape[0]
        if num_segments == 0:
            return junctions, line_map_pred

        # (num_segments, num_combinations, 2, 2) perturbed segments, kept
        # inside the heatmap
        line_segment_candidates = (line_segments.unsqueeze(dim=1)
                                   + perturb_tensor_flat[None, ...])
        line_segment_candidates[..., 0] = torch.clamp(
            line_segment_candidates[..., 0], min=0, max=H - 1)
        line_segment_candidates[..., 1] = torch.clamp(
            line_segment_candidates[..., 1], min=0, max=W - 1)

        # Keep the best-scoring variant of every segment
        refined_segment_lst = []
        for idx in range(num_segments):
            segment = line_segment_candidates[idx, ...]
            cand_h, cand_w = self._sample_segment_points(
                segment[:, 0, :], segment[:, 1, :], H, W, device)
            segment_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
            best_idx = torch.argmax(torch.mean(segment_feat, dim=-1))
            refined_segment_lst.append(segment[best_idx, ...])

        refined_segments = torch.stack(refined_segment_lst, dim=0)

        # Rebuild the junction list and the line map from the segments
        junctions_new = torch.unique(torch.cat(
            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0),
            dim=0)
        line_map_new = self.segments_to_line_map(
            junctions_new, refined_segments, dtype=line_map_pred.dtype)

        return junctions_new, line_map_new

    def segments_to_line_map(self, junctions, segments,
                             dtype=torch.float32):
        """ Convert the list of segments to line map.

        Parameters:
            junctions: (N, 2) tensor of junctions.
            segments: (S, 2, 2) tensor of segment endpoints; each endpoint
                is looked up in ``junctions`` by exact equality.
            dtype: dtype of the returned line map.

        Returns:
            The symmetric (N, N) line map on the device of ``junctions``.
        """
        device = junctions.device
        num_junctions = junctions.shape[0]
        line_map = torch.zeros([num_junctions, num_junctions],
                               device=device, dtype=dtype)

        for idx in range(segments.shape[0]):
            seg = segments[idx, ...]
            # Rows of `junctions` equal to each endpoint
            idx_junction1 = torch.where(
                (junctions == seg[0, :]).all(dim=1))[0]
            idx_junction2 = torch.where(
                (junctions == seg[1, :]).all(dim=1))[0]

            line_map[idx_junction1, idx_junction2] = 1
            line_map[idx_junction2, idx_junction1] = 1

        return line_map

    def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
        """ Detection by bilinear sampling.

        Parameters:
            heatmap: Heatmap indexed as ``heatmap[h, w]``.
            cand_h, cand_w: Sampling coordinates of identical shape,
                expected to lie inside the heatmap.
            H, W, device: Unused; kept for interface compatibility (the
                bounds are read from ``heatmap`` itself).

        Returns:
            The interpolated heatmap values, same shape as ``cand_h``.
        """
        h_floor = torch.floor(cand_h)
        w_floor = torch.floor(cand_w)
        frac_h = cand_h - h_floor
        frac_w = cand_w - w_floor

        # The upper neighbour is floor + 1 rather than ceil: with ceil both
        # neighbours coincide at integral coordinates, both weights become
        # zero and the sample wrongly evaluates to 0. On the last row/column
        # the neighbour is clamped; its weight is zero there anyway.
        h0 = h_floor.to(torch.long)
        w0 = w_floor.to(torch.long)
        h1 = torch.clamp(h0 + 1, max=heatmap.shape[0] - 1)
        w1 = torch.clamp(w0 + 1, max=heatmap.shape[1] - 1)

        cand_samples_feat = (
            heatmap[h0, w0] * (1 - frac_h) * (1 - frac_w)
            + heatmap[h0, w1] * (1 - frac_h) * frac_w
            + heatmap[h1, w0] * frac_h * (1 - frac_w)
            + heatmap[h1, w1] * frac_h * frac_w)

        return cand_samples_feat

    def detect_local_max(self, heatmap, cand_h, cand_w, H, W,
                         normalized_seg_length, device):
        """ Detection by local maximum search.

        For every sampled point, the maximum heatmap value over the pixels
        of a circular patch lying closer than a length-dependent distance
        threshold is returned.

        Parameters:
            heatmap: Heatmap indexed as ``heatmap[h, w]``.
            cand_h, cand_w: (N, S) sampling coordinates.
            H, W: Heatmap size used to clamp the patch coordinates.
            normalized_seg_length: (N,) segment lengths divided by the
                heatmap diagonal.
            device: Device on which the patch offsets are created.

        Returns:
            (N, S) tensor of local maxima.
        """
        # Nothing to sample: return an empty result of matching width
        if cand_h.shape[0] == 0:
            return torch.empty(0, cand_h.shape[-1],
                               dtype=heatmap.dtype, device=device)

        # Search radius growing linearly with the segment length, shaped to
        # broadcast over (N, S, num_patch_points)
        dist_thresh = (0.5 * (2 ** 0.5)
                       + self.lambda_radius * normalized_seg_length)
        dist_thresh = dist_thresh[:, None, None]

        # Sampled points and their nearest pixels: (N, S, 2)
        cand_points = torch.stack([cand_h, cand_w], dim=-1)
        cand_points_round = torch.round(cand_points)

        # Offsets of a circular patch centred on the origin
        radius = self.local_patch_radius
        patch_size = int(2 * radius + 1)
        axis = torch.arange(patch_size, device=device,
                            dtype=torch.float32) - radius
        patch_points = torch.stack(
            [axis[:, None].expand(patch_size, patch_size),
             axis[None, :].expand(patch_size, patch_size)],
            dim=-1).reshape(-1, 2)
        patch_center_dist = torch.sqrt(torch.sum(patch_points ** 2, dim=-1))
        patch_points = patch_points[patch_center_dist <= radius, :]

        # Patch pixels around every sample and their distance to the sample
        patch_points_shifted = (torch.unsqueeze(cand_points_round, dim=2)
                                + patch_points[None, None, ...])
        patch_dist = torch.sqrt(torch.sum(
            (torch.unsqueeze(cand_points, dim=2)
             - patch_points_shifted) ** 2, dim=-1))
        patch_dist_mask = patch_dist < dist_thresh

        # Clamp to the heatmap before indexing
        points_H = torch.clamp(patch_points_shifted[..., 0], min=0,
                               max=H - 1).to(torch.long)
        points_W = torch.clamp(patch_points_shifted[..., 1], min=0,
                               max=W - 1).to(torch.long)

        # Pixels beyond the distance threshold contribute 0 to the maximum
        sampled_feat = (heatmap[points_H, points_W]
                        * patch_dist_mask.to(torch.float32))
        sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1)

        return sampled_feat_lmax