"""
Implementation of the line segment detection module.
"""
import math

import numpy as np
import torch
class LineSegmentDetectionModule(object):
    """ Module extracting line segments from junctions and line heatmaps.

    Junction coordinates are handled in (row, column) = (h, w) order
    throughout: column 0 is clipped against the heatmap height and column 1
    against its width.
    """

    # Maximum number of candidate segments scored in one local-max call;
    # larger candidate sets are processed chunk by chunk to bound memory.
    _LOCAL_MAX_GROUP_SIZE = 10000

    def __init__(
            self, detect_thresh, num_samples=64, sampling_method="local_max",
            inlier_thresh=0., heatmap_low_thresh=0.15, heatmap_high_thresh=0.2,
            max_local_patch_radius=3, lambda_radius=2.,
            use_candidate_suppression=False, nms_dist_tolerance=3.,
            use_heatmap_refinement=False, heatmap_refine_cfg=None,
            use_junction_refinement=False, junction_refine_cfg=None):
        """
        Parameters:
            detect_thresh: The probability threshold for mean activation
                (0. ~ 1.)
            num_samples: Number of sampling locations along the line
                segments.
            sampling_method: Sampling method on locations ("bilinear" or
                "local_max").
            inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.)
                => 0. means no threshold.
            heatmap_low_thresh: The lowest threshold for the pixel to be
                considered as candidate in junction recovery.
            heatmap_high_thresh: The higher threshold for NMS in junction
                recovery.
            max_local_patch_radius: The max patch to be considered in local
                maximum search.
            lambda_radius: The lambda factor in linear local maximum search
                formulation.
            use_candidate_suppression: Apply candidate suppression to break
                long segments into short sub-segments.
            nms_dist_tolerance: The distance tolerance for nms. Decide
                whether the junctions are on the line.
            use_heatmap_refinement: Use heatmap refinement method or not.
            heatmap_refine_cfg: The configs for heatmap refinement methods.
            use_junction_refinement: Use junction refinement method or not.
            junction_refine_cfg: The configs for junction refinement methods.

        Raises:
            ValueError: If a refinement is enabled without its config.
        """
        # Line detection parameters
        self.detect_thresh = detect_thresh

        # Line sampling parameters
        self.num_samples = num_samples
        self.sampling_method = sampling_method
        self.inlier_thresh = inlier_thresh
        self.local_patch_radius = max_local_patch_radius
        self.lambda_radius = lambda_radius

        # Detecting junctions on the boundary parameters
        self.low_thresh = heatmap_low_thresh
        self.high_thresh = heatmap_high_thresh

        # Pre-compute the linspace sampler
        self.sampler = np.linspace(0, 1, self.num_samples)
        self.torch_sampler = torch.linspace(0, 1, self.num_samples)

        # Long line segment suppression configuration
        self.use_candidate_suppression = use_candidate_suppression
        self.nms_dist_tolerance = nms_dist_tolerance

        # Heatmap refinement configuration
        self.use_heatmap_refinement = use_heatmap_refinement
        self.heatmap_refine_cfg = heatmap_refine_cfg
        if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
            raise ValueError("[Error] Missing heatmap refinement config.")

        # Junction refinement configuration
        self.use_junction_refinement = use_junction_refinement
        self.junction_refine_cfg = junction_refine_cfg
        if self.use_junction_refinement and self.junction_refine_cfg is None:
            raise ValueError("[Error] Missing junction refinement config.")

    def convert_inputs(self, inputs, device):
        """ Convert inputs to a float32 torch tensor on the given device.

        Raises:
            ValueError: If inputs is neither a numpy array nor a tensor.
        """
        if isinstance(inputs, np.ndarray):
            return torch.tensor(inputs, dtype=torch.float32, device=device)
        if isinstance(inputs, torch.Tensor):
            return inputs.to(device=device, dtype=torch.float32)
        raise ValueError(
            "[Error] Inputs must either be torch tensor or numpy ndarray.")

    def detect(self, junctions, heatmap, device=torch.device("cpu")):
        """ Main function performing line segment detection.

        Parameters:
            junctions: N x 2 array/tensor of junctions in (h, w) order.
            heatmap: Line heatmap (numpy array or torch tensor); its first
                two dimensions are used as the image height and width.
            device: Device on which the computation is performed.

        Returns:
            line_map_pred: Symmetric junction adjacency map (1 = segment).
            junctions: The junctions as a tensor (replaced by the refined
                junctions when junction refinement ran).
            heatmap: The (possibly refined) heatmap tensor.

        Raises:
            ValueError: On unsupported input types or sampling method.
        """
        # Convert inputs to torch tensor
        junctions = self.convert_inputs(junctions, device=device)
        heatmap = self.convert_inputs(heatmap, device=device)

        # Perform the heatmap refinement
        if self.use_heatmap_refinement:
            cfg = self.heatmap_refine_cfg
            if cfg["mode"] == "global":
                heatmap = self.refine_heatmap(
                    heatmap, cfg["ratio"], cfg["valid_thresh"])
            elif cfg["mode"] == "local":
                heatmap = self.refine_heatmap_local(
                    heatmap, cfg["num_blocks"], cfg["overlap_ratio"],
                    cfg["ratio"], cfg["valid_thresh"])

        # Initialize empty line map
        num_junctions = junctions.shape[0]
        line_map_pred = torch.zeros([num_junctions, num_junctions],
                                    device=device, dtype=torch.int32)

        # Stop if there are not enough junctions
        if num_junctions < 2:
            return line_map_pred, junctions, heatmap

        # Every unordered junction pair (i < j) starts as a candidate
        candidate_map = torch.triu(torch.ones(
            [num_junctions, num_junctions], device=device, dtype=torch.int32),
            diagonal=1)

        # Fetch the image boundary
        H, W = heatmap.shape[:2]

        # Optionally perform candidate filtering
        if self.use_candidate_suppression:
            candidate_map = self.candidate_suppression(junctions,
                                                       candidate_map)

        # Fetch the candidates as an (num_candidates x 2) index tensor
        candidate_index_map = torch.stack(torch.where(candidate_map), dim=-1)

        # Get the corresponding start and end junctions
        candidate_junc_start = junctions[candidate_index_map[:, 0], :]
        candidate_junc_end = junctions[candidate_index_map[:, 1], :]

        # Get the sampling locations (num_candidates x num_samples)
        cand_h, cand_w = self._sample_segment_points(
            candidate_junc_start, candidate_junc_end, H, W, device)

        if self.sampling_method == "local_max":
            # Segment lengths normalized by the image diagonal
            segments_length = torch.sqrt(torch.sum(
                (candidate_junc_start - candidate_junc_end) ** 2, dim=-1))
            normalized_seg_length = (segments_length
                                     / (((H ** 2) + (W ** 2)) ** 0.5))

            # Perform local max search, chunked to bound peak memory
            num_cand = cand_h.shape[0]
            group_size = self._LOCAL_MAX_GROUP_SIZE
            if num_cand > group_size:
                sampled_feat = torch.cat([
                    self.detect_local_max(
                        heatmap,
                        cand_h[lo:lo + group_size],
                        cand_w[lo:lo + group_size],
                        H, W,
                        normalized_seg_length[lo:lo + group_size],
                        device)
                    for lo in range(0, num_cand, group_size)], dim=0)
            else:
                sampled_feat = self.detect_local_max(
                    heatmap, cand_h, cand_w, H, W,
                    normalized_seg_length, device)
        elif self.sampling_method == "bilinear":
            sampled_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
        else:
            raise ValueError("[Error] Unknown sampling method.")

        # [Simple threshold detection]
        # detection_results is a mask over all candidates
        detection_results = (torch.mean(sampled_feat, dim=-1)
                             > self.detect_thresh)

        # [Inlier threshold detection]
        if self.inlier_thresh > 0.:
            inlier_ratio = torch.sum(
                sampled_feat > self.detect_thresh,
                dim=-1).to(torch.float32) / self.num_samples
            detection_results = (detection_results
                                 & (inlier_ratio >= self.inlier_thresh))

        # Convert detection results back to a symmetric line_map_pred
        detected_junc_indexes = candidate_index_map[detection_results, :]
        line_map_pred[detected_junc_indexes[:, 0],
                      detected_junc_indexes[:, 1]] = 1
        line_map_pred[detected_junc_indexes[:, 1],
                      detected_junc_indexes[:, 0]] = 1

        # Perform junction refinement
        if self.use_junction_refinement and len(detected_junc_indexes) > 0:
            junctions, line_map_pred = self.refine_junction_perturb(
                junctions, line_map_pred, heatmap, H, W, device)

        return line_map_pred, junctions, heatmap

    def _sample_segment_points(self, junc_start, junc_end, H, W, device):
        """ Sample points along each segment, clipped to the image.

        Returns two (num_segments x num_samples) tensors holding the h and
        w coordinates. The sampler weight runs from 0 to 1 and multiplies
        the start junction.
        """
        sampler = self.torch_sampler.to(device)[None, ...]
        samples_h = (junc_start[:, 0:1] * sampler
                     + junc_end[:, 0:1] * (1 - sampler))
        samples_w = (junc_start[:, 1:2] * sampler
                     + junc_end[:, 1:2] * (1 - sampler))
        # Clip to image boundary
        cand_h = torch.clamp(samples_h, min=0, max=H - 1)
        cand_w = torch.clamp(samples_w, min=0, max=W - 1)
        return cand_h, cand_w

    def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
        """ Global heatmap refinement method.

        Divides the heatmap by the mean of the top `ratio` fraction of its
        values above `valid_thresh`, then clamps the result to [0, 1].
        When no value exceeds `valid_thresh` there is nothing to normalize
        by, so the heatmap is only clamped (avoids an all-NaN output).
        """
        valid_values = heatmap[heatmap > valid_thresh]
        if valid_values.numel() == 0:
            return torch.clamp(heatmap, min=0., max=1.)

        sorted_values = torch.sort(valid_values, descending=True)[0]
        # Keep at least one value so the mean is always defined
        top_len = max(1, math.ceil(sorted_values.shape[0] * ratio))
        top_mean = torch.mean(sorted_values[:top_len])
        return torch.clamp(heatmap / top_mean, min=0., max=1.)

    def refine_heatmap_local(self, heatmap, num_blocks=5, overlap_ratio=0.5,
                             ratio=0.2, valid_thresh=2e-3):
        """ Local heatmap refinement method.

        Splits the 2D heatmap into num_blocks x num_blocks overlapping
        blocks, refines each block whose maximum exceeds `valid_thresh`
        with `refine_heatmap`, and averages the overlapping results.
        """
        # Get the shape of the heatmap
        H, W = heatmap.shape
        increase_ratio = 1 - overlap_ratio
        h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
        w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))

        count_map = torch.zeros(heatmap.shape, dtype=torch.int,
                                device=heatmap.device)
        heatmap_output = torch.zeros(heatmap.shape, dtype=torch.float,
                                     device=heatmap.device)

        # Iterate through each block
        for h_idx in range(num_blocks):
            h_start = round(h_idx * h_block * increase_ratio)
            # The last block extends to the image border
            h_end = h_start + h_block if h_idx < num_blocks - 1 else H
            for w_idx in range(num_blocks):
                w_start = round(w_idx * w_block * increase_ratio)
                w_end = w_start + w_block if w_idx < num_blocks - 1 else W

                subheatmap = heatmap[h_start:h_end, w_start:w_end]
                if subheatmap.max() > valid_thresh:
                    subheatmap = self.refine_heatmap(
                        subheatmap, ratio, valid_thresh=valid_thresh)

                # Aggregate it to the final heatmap
                heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
                count_map[h_start:h_end, w_start:w_end] += 1

        heatmap_output = torch.clamp(heatmap_output / count_map,
                                     max=1., min=0.)
        return heatmap_output

    def candidate_suppression(self, junctions, candidate_map):
        """ Suppress overlapping long lines in the candidate segments.

        A candidate segment is removed when some other junction lies within
        `nms_dist_tolerance` of it and projects onto it. `candidate_map` is
        modified in place and also returned.
        """
        # Define the distance tolerance
        dist_tolerance = self.nms_dist_tolerance

        # Fetch all the candidate segments
        start_point_idxs, end_point_idxs = torch.where(
            torch.triu(candidate_map, diagonal=1))
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        # Unit direction of every segment (num_segs x 2)
        seg_vecs = end_points - start_points
        line_dists = torch.norm(seg_vecs, dim=-1)
        dir_vecs = seg_vecs / line_dists[..., None]

        # Vector from each segment start to every junction
        # (num_segs x num_junc x 2)
        cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)

        # Normalized position of the projection along the segment:
        # 0 at the start point, 1 at the end point
        along = torch.sum(cand_vecs * dir_vecs[:, None, :], dim=-1)
        proj = along / line_dists[:, None]
        proj_mask = (proj >= 0) & (proj <= 1)

        # Orthogonal distance via the 2D cross product. This is numerically
        # safer than norm * sin(acos(cos)), which yields NaN when the
        # cosine is rounded slightly above 1 for collinear junctions.
        cand_dists = torch.abs(cand_vecs[..., 0] * dir_vecs[:, None, 1]
                               - cand_vecs[..., 1] * dir_vecs[:, None, 0])
        junc_dist_mask = cand_dists <= dist_tolerance

        # Junctions coinciding with the start point never count
        not_at_start = torch.sum(cand_vecs ** 2, dim=-1) > 0
        junc_mask = junc_dist_mask & proj_mask & not_at_start

        # The segment's own end point does not count either
        num_segs = start_point_idxs.shape[0]
        seg_range = torch.arange(num_segs, device=junctions.device)
        junc_mask[seg_range, end_point_idxs] = False

        # Invalidate every candidate with at least one junction on it
        final_mask = junc_mask.any(dim=1)
        candidate_map[start_point_idxs[final_mask],
                      end_point_idxs[final_mask]] = 0

        return candidate_map

    def refine_junction_perturb(self, junctions, line_map_pred,
                                heatmap, H, W, device):
        """ Refine the line endpoints in a similar way as in LSD.

        Each detected segment is replaced by the endpoint perturbation with
        the highest mean bilinear heatmap response. The returned junctions
        are the unique refined endpoints (junctions not belonging to any
        segment are dropped) together with the matching line map.
        """
        # Fetch refinement parameters
        junction_refine_cfg = self.junction_refine_cfg
        num_perturbs = junction_refine_cfg["num_perturbs"]
        perturb_interval = junction_refine_cfg["perturb_interval"]
        side_perturbs = (num_perturbs - 1) // 2

        # All combinations of offsets applied to the 4 endpoint coordinates
        # => (num_perturbs ** 4) x 2 x 2
        perturb_vec = torch.arange(
            start=-perturb_interval * side_perturbs,
            end=perturb_interval * (side_perturbs + 1),
            step=perturb_interval, device=device)
        perturb_grids = torch.meshgrid(
            perturb_vec, perturb_vec, perturb_vec, perturb_vec)
        perturb_tensor_flat = torch.stack(
            perturb_grids, dim=-1).reshape(-1, 2, 2)

        # Fetch all the detected lines
        start_point_idxs, end_point_idxs = torch.where(
            torch.triu(line_map_pred, diagonal=1))
        num_segments = start_point_idxs.shape[0]
        if num_segments == 0:
            # Nothing to refine: no segment means no refined junction
            return (torch.zeros([0, junctions.shape[1]],
                                dtype=junctions.dtype,
                                device=junctions.device),
                    torch.zeros([0, 0], device=junctions.device))

        # num_segments x 2 (start/end) x 2 (h/w)
        line_segments = torch.stack([junctions[start_point_idxs, :],
                                     junctions[end_point_idxs, :]], dim=1)
        line_segment_candidates = (line_segments.unsqueeze(dim=1)
                                   + perturb_tensor_flat[None, ...])

        # Clip the boundaries
        line_segment_candidates[..., 0] = torch.clamp(
            line_segment_candidates[..., 0], min=0, max=H - 1)
        line_segment_candidates[..., 1] = torch.clamp(
            line_segment_candidates[..., 1], min=0, max=W - 1)

        # Pick the best perturbation of every segment
        refined_segment_lst = []
        for idx in range(num_segments):
            segment = line_segment_candidates[idx, ...]
            cand_h, cand_w = self._sample_segment_points(
                segment[:, 0, :], segment[:, 1, :], H, W, device)

            # Perform bilinear sampling
            segment_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
            segment_results = torch.mean(segment_feat, dim=-1)
            max_idx = torch.argmax(segment_results)
            refined_segment_lst.append(segment[max_idx, ...])

        # Concatenate back to segments
        refined_segments = torch.stack(refined_segment_lst, dim=0)

        # Convert back to junctions and line_map
        junctions_new = torch.cat(
            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0)
        junctions_new = torch.unique(junctions_new, dim=0)
        line_map_new = self.segments_to_line_map(junctions_new,
                                                 refined_segments)

        return junctions_new, line_map_new

    def segments_to_line_map(self, junctions, segments):
        """ Convert the list of segments to line map.

        `segments` is num_segments x 2 x 2; each endpoint is matched
        against `junctions` by exact coordinate equality.
        """
        # Create empty line map
        device = junctions.device
        num_junctions = junctions.shape[0]
        line_map = torch.zeros([num_junctions, num_junctions], device=device)

        # Iterate through every segment
        for idx in range(segments.shape[0]):
            # Get the junctions from a single segment
            seg = segments[idx, ...]
            junction1 = seg[0, :]
            junction2 = seg[1, :]

            # Get index
            idx_junction1 = torch.where(
                (junctions == junction1).sum(dim=1) == 2)[0]
            idx_junction2 = torch.where(
                (junctions == junction2).sum(dim=1) == 2)[0]

            # Label the corresponding entries
            line_map[idx_junction1, idx_junction2] = 1
            line_map[idx_junction2, idx_junction1] = 1

        return line_map

    def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
        """ Detection by bilinear sampling.

        Samples the heatmap at the (cand_h, cand_w) locations, which are
        expected to lie within [0, H - 1] x [0, W - 1].

        The upper neighbour is `floor + 1` (clamped to the image) rather
        than `ceil`: with `ceil`, both neighbours coincide at integer
        coordinates and all four interpolation weights become zero, so
        the sample would be 0 instead of the heatmap value.
        """
        # Lower neighbours and fractional offsets
        h_floor_f = torch.floor(cand_h)
        w_floor_f = torch.floor(cand_w)
        frac_h = cand_h - h_floor_f
        frac_w = cand_w - w_floor_f
        cand_h_floor = h_floor_f.to(torch.long)
        cand_w_floor = w_floor_f.to(torch.long)

        # Upper neighbours; on the last row/column the fraction is zero,
        # so the clamped index carries no weight
        cand_h_ceil = torch.clamp(cand_h_floor + 1, max=H - 1)
        cand_w_ceil = torch.clamp(cand_w_floor + 1, max=W - 1)

        # Perform the bilinear sampling
        cand_samples_feat = (
            heatmap[cand_h_floor, cand_w_floor] * (1 - frac_h) * (1 - frac_w)
            + heatmap[cand_h_floor, cand_w_ceil] * (1 - frac_h) * frac_w
            + heatmap[cand_h_ceil, cand_w_floor] * frac_h * (1 - frac_w)
            + heatmap[cand_h_ceil, cand_w_ceil] * frac_h * frac_w)

        return cand_samples_feat

    def detect_local_max(self, heatmap, cand_h, cand_w, H, W,
                         normalized_seg_length, device):
        """ Detection by local maximum search.

        For every sampling location, returns the maximum heatmap value
        among the pixels of a circular patch around it that are closer
        than a per-segment distance threshold. Pixels beyond the threshold
        contribute 0.
        """
        # Distance threshold per segment, growing with the segment length
        dist_thresh = (0.5 * (2 ** 0.5)
                       + self.lambda_radius * normalized_seg_length)

        # Compute the candidate points (N x num_samples x 2)
        cand_points = torch.stack([cand_h, cand_w], dim=-1)
        cand_points_round = torch.round(cand_points)

        # Offsets of a square patch centered on zero, in row-major order;
        # the patch is (2 * radius + 1) pixels wide, e.g. 7 x 7 for radius 3
        radius = self.local_patch_radius
        patch_size = int(2 * radius + 1)
        axis = (torch.arange(patch_size, device=device, dtype=torch.float32)
                - radius)
        patch_points = torch.stack([axis.repeat_interleave(patch_size),
                                    axis.repeat(patch_size)], dim=-1)

        # Keep only the circular region
        patch_points = patch_points[
            torch.sum(patch_points ** 2, dim=-1) <= radius ** 2, :]

        # Patch pixels around every sample (N x num_samples x P x 2)
        patch_points_shifted = (torch.unsqueeze(cand_points_round, dim=2)
                                + patch_points[None, None, ...])

        # Only pixels close enough to the exact sample location are valid
        patch_dist = torch.sqrt(torch.sum(
            (torch.unsqueeze(cand_points, dim=2)
             - patch_points_shifted) ** 2, dim=-1))
        patch_dist_mask = patch_dist < dist_thresh[:, None, None]

        # Clip the pixel locations to the image
        points_H = torch.clamp(patch_points_shifted[..., 0], min=0,
                               max=H - 1).to(torch.long)
        points_W = torch.clamp(patch_points_shifted[..., 1], min=0,
                               max=W - 1).to(torch.long)

        # Sample the feature (N x num_samples x P), zeroing invalid pixels
        sampled_feat = heatmap[points_H, points_W]
        sampled_feat = sampled_feat * patch_dist_mask.to(torch.float32)

        if sampled_feat.shape[0] == 0:
            # No candidates: keep the sample count, dtype and device of
            # the inputs
            return torch.empty([0, cand_h.shape[1]],
                               dtype=sampled_feat.dtype,
                               device=sampled_feat.device)

        sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1)
        return sampled_feat_lmax