Vincentqyw
update: features and matchers
a80d6bb
raw
history blame
24 kB
"""
Implementation of the line segment detection module.
"""
import math
import numpy as np
import torch
class LineSegmentDetectionModule(object):
    """ Module extracting line segments from junctions and line heatmaps. """
    def __init__(
        self, detect_thresh, num_samples=64, sampling_method="local_max",
        inlier_thresh=0., heatmap_low_thresh=0.15, heatmap_high_thresh=0.2,
        max_local_patch_radius=3, lambda_radius=2.,
        use_candidate_suppression=False, nms_dist_tolerance=3.,
        use_heatmap_refinement=False, heatmap_refine_cfg=None,
        use_junction_refinement=False, junction_refine_cfg=None):
        """
        Parameters:
            detect_thresh: The probability threshold for mean activation (0. ~ 1.)
            num_samples: Number of sampling locations along the line segments.
            sampling_method: Sampling method on locations ("bilinear" or "local_max").
            inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
            heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery.
            heatmap_high_thresh: The higher threshold for NMS in junction recovery.
            max_local_patch_radius: The max patch to be considered in local maximum search.
            lambda_radius: The lambda factor in linear local maximum search formulation
            use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments.
            nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line.
            use_heatmap_refinement: Use heatmap refinement method or not.
            heatmap_refine_cfg: The configs for heatmap refinement methods.
            use_junction_refinement: Use junction refinement method or not.
            junction_refine_cfg: The configs for junction refinement methods.
        """
        # Line detection parameters
        self.detect_thresh = detect_thresh

        # Line sampling parameters
        self.num_samples = num_samples
        self.sampling_method = sampling_method
        self.inlier_thresh = inlier_thresh
        self.local_patch_radius = max_local_patch_radius
        self.lambda_radius = lambda_radius

        # Detecting junctions on the boundary parameters
        self.low_thresh = heatmap_low_thresh
        self.high_thresh = heatmap_high_thresh

        # Pre-compute the linspace sampler (parameter t in [0, 1] along a segment)
        self.sampler = np.linspace(0, 1, self.num_samples)
        self.torch_sampler = torch.linspace(0, 1, self.num_samples)

        # Long line segment suppression configuration
        self.use_candidate_suppression = use_candidate_suppression
        self.nms_dist_tolerance = nms_dist_tolerance

        # Heatmap refinement configuration
        self.use_heatmap_refinement = use_heatmap_refinement
        self.heatmap_refine_cfg = heatmap_refine_cfg
        if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
            raise ValueError("[Error] Missing heatmap refinement config.")

        # Junction refinement configuration
        self.use_junction_refinement = use_junction_refinement
        self.junction_refine_cfg = junction_refine_cfg
        if self.use_junction_refinement and self.junction_refine_cfg is None:
            raise ValueError("[Error] Missing junction refinement config.")

    def convert_inputs(self, inputs, device):
        """ Convert inputs to a float32 torch tensor on the given device. """
        if isinstance(inputs, np.ndarray):
            outputs = torch.tensor(inputs, dtype=torch.float32, device=device)
        elif isinstance(inputs, torch.Tensor):
            outputs = inputs.to(torch.float32).to(device)
        else:
            raise ValueError(
                "[Error] Inputs must either be torch tensor or numpy ndarray.")
        return outputs

    def detect(self, junctions, heatmap, device=torch.device("cpu")):
        """ Main function performing line segment detection.

        Returns (line_map_pred, junctions, heatmap), where line_map_pred is a
        symmetric num_junctions x num_junctions adjacency map of detected lines.
        """
        # Convert inputs to torch tensor
        junctions = self.convert_inputs(junctions, device=device)
        heatmap = self.convert_inputs(heatmap, device=device)

        # Perform the heatmap refinement
        if self.use_heatmap_refinement:
            if self.heatmap_refine_cfg["mode"] == "global":
                heatmap = self.refine_heatmap(
                    heatmap,
                    self.heatmap_refine_cfg["ratio"],
                    self.heatmap_refine_cfg["valid_thresh"]
                )
            elif self.heatmap_refine_cfg["mode"] == "local":
                heatmap = self.refine_heatmap_local(
                    heatmap,
                    self.heatmap_refine_cfg["num_blocks"],
                    self.heatmap_refine_cfg["overlap_ratio"],
                    self.heatmap_refine_cfg["ratio"],
                    self.heatmap_refine_cfg["valid_thresh"]
                )

        # Initialize empty line map
        num_junctions = junctions.shape[0]
        line_map_pred = torch.zeros([num_junctions, num_junctions],
                                    device=device, dtype=torch.int32)

        # Stop if there are not enough junctions
        if num_junctions < 2:
            return line_map_pred, junctions, heatmap

        # Generate the candidate map: all unordered junction pairs (upper triangle)
        candidate_map = torch.triu(torch.ones(
            [num_junctions, num_junctions], device=device, dtype=torch.int32),
            diagonal=1)

        # Fetch the image boundary
        if len(heatmap.shape) > 2:
            H, W, _ = heatmap.shape
        else:
            H, W = heatmap.shape

        # Optionally perform candidate filtering
        if self.use_candidate_suppression:
            candidate_map = self.candidate_suppression(junctions,
                                                       candidate_map)

        # Fetch the candidates as an N x 2 list of (start_idx, end_idx)
        candidate_index_map = torch.where(candidate_map)
        candidate_index_map = torch.cat([candidate_index_map[0][..., None],
                                         candidate_index_map[1][..., None]],
                                        dim=-1)

        # Get the corresponding start and end junctions
        candidate_junc_start = junctions[candidate_index_map[:, 0], :]
        candidate_junc_end = junctions[candidate_index_map[:, 1], :]

        # Get the sampling locations (N x num_samples).
        # NOTE: with weight t on the start point, samples run from the end
        # point (t=0) to the start point (t=1); the set of locations is the same.
        sampler = self.torch_sampler.to(device)[None, ...]
        cand_samples_h = candidate_junc_start[:, 0:1] * sampler + \
                         candidate_junc_end[:, 0:1] * (1 - sampler)
        cand_samples_w = candidate_junc_start[:, 1:2] * sampler + \
                         candidate_junc_end[:, 1:2] * (1 - sampler)

        # Clip to image boundary
        cand_h = torch.clamp(cand_samples_h, min=0, max=H-1)
        cand_w = torch.clamp(cand_samples_w, min=0, max=W-1)

        # Local maximum search
        if self.sampling_method == "local_max":
            # Compute segment lengths normalized by the image diagonal
            segments_length = torch.sqrt(torch.sum(
                (candidate_junc_start.to(torch.float32) -
                 candidate_junc_end.to(torch.float32)) ** 2, dim=-1))
            normalized_seg_length = (segments_length
                                     / (((H ** 2) + (W ** 2)) ** 0.5))

            # Perform local max search in chunks to bound peak memory
            num_cand = cand_h.shape[0]
            group_size = 10000
            if num_cand > group_size:
                num_iter = math.ceil(num_cand / group_size)
                sampled_feat_lst = []
                for iter_idx in range(num_iter):
                    if not iter_idx == num_iter-1:
                        cand_h_ = cand_h[iter_idx * group_size:
                                         (iter_idx+1) * group_size, :]
                        cand_w_ = cand_w[iter_idx * group_size:
                                         (iter_idx+1) * group_size, :]
                        normalized_seg_length_ = normalized_seg_length[
                            iter_idx * group_size: (iter_idx+1) * group_size]
                    else:
                        cand_h_ = cand_h[iter_idx * group_size:, :]
                        cand_w_ = cand_w[iter_idx * group_size:, :]
                        normalized_seg_length_ = normalized_seg_length[
                            iter_idx * group_size:]
                    sampled_feat_ = self.detect_local_max(
                        heatmap, cand_h_, cand_w_, H, W,
                        normalized_seg_length_, device)
                    sampled_feat_lst.append(sampled_feat_)
                sampled_feat = torch.cat(sampled_feat_lst, dim=0)
            else:
                sampled_feat = self.detect_local_max(
                    heatmap, cand_h, cand_w, H, W,
                    normalized_seg_length, device)
        # Bilinear sampling
        elif self.sampling_method == "bilinear":
            # Perform bilinear sampling
            sampled_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
        else:
            raise ValueError("[Error] Unknown sampling method.")

        # [Simple threshold detection]
        # detection_results is a mask over all candidates
        detection_results = (torch.mean(sampled_feat, dim=-1)
                             > self.detect_thresh)

        # [Inlier threshold detection]
        if self.inlier_thresh > 0.:
            inlier_ratio = torch.sum(
                sampled_feat > self.detect_thresh,
                dim=-1).to(torch.float32) / self.num_samples
            detection_results_inlier = inlier_ratio >= self.inlier_thresh
            detection_results = detection_results * detection_results_inlier

        # Convert detection results back to a symmetric line_map_pred
        detected_junc_indexes = candidate_index_map[detection_results, :]
        line_map_pred[detected_junc_indexes[:, 0],
                      detected_junc_indexes[:, 1]] = 1
        line_map_pred[detected_junc_indexes[:, 1],
                      detected_junc_indexes[:, 0]] = 1

        # Perform junction refinement
        if self.use_junction_refinement and len(detected_junc_indexes) > 0:
            junctions, line_map_pred = self.refine_junction_perturb(
                junctions, line_map_pred, heatmap, H, W, device)

        return line_map_pred, junctions, heatmap

    def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
        """ Global heatmap refinement method. """
        # Normalize the heatmap by the mean of the top `ratio` fraction
        # of its valid (> valid_thresh) activations.
        heatmap_values = heatmap[heatmap > valid_thresh]
        sorted_values = torch.sort(heatmap_values, descending=True)[0]
        top_len = math.ceil(sorted_values.shape[0] * ratio)
        top_mean = torch.mean(sorted_values[:top_len])
        heatmap = torch.clamp(heatmap / top_mean, min=0., max=1.)
        return heatmap

    def refine_heatmap_local(self, heatmap, num_blocks=5, overlap_ratio=0.5,
                             ratio=0.2, valid_thresh=2e-3):
        """ Local heatmap refinement method: refine overlapping blocks
        independently, then average the overlapping contributions. """
        # Get the shape of the heatmap
        H, W = heatmap.shape
        increase_ratio = 1 - overlap_ratio
        h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
        w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))

        count_map = torch.zeros(heatmap.shape, dtype=torch.int,
                                device=heatmap.device)
        heatmap_output = torch.zeros(heatmap.shape, dtype=torch.float,
                                     device=heatmap.device)
        # Iterate through each block
        for h_idx in range(num_blocks):
            for w_idx in range(num_blocks):
                # Fetch the heatmap block (last block extends to the border)
                h_start = round(h_idx * h_block * increase_ratio)
                w_start = round(w_idx * w_block * increase_ratio)
                h_end = h_start + h_block if h_idx < num_blocks - 1 else H
                w_end = w_start + w_block if w_idx < num_blocks - 1 else W

                subheatmap = heatmap[h_start:h_end, w_start:w_end]
                if subheatmap.max() > valid_thresh:
                    subheatmap = self.refine_heatmap(
                        subheatmap, ratio, valid_thresh=valid_thresh)

                # Aggregate it to the final heatmap
                heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
                count_map[h_start:h_end, w_start:w_end] += 1
        # Average the overlapping regions
        heatmap_output = torch.clamp(heatmap_output / count_map,
                                     max=1., min=0.)
        return heatmap_output

    def candidate_suppression(self, junctions, candidate_map):
        """ Suppress overlapping long lines in the candidate segments. """
        # Define the distance tolerance
        dist_tolerance = self.nms_dist_tolerance

        # Compute distance between junction pairs
        # (num_junc x 1 x 2) - (1 x num_junc x 2) => num_junc x num_junc map
        line_dist_map = torch.sum((torch.unsqueeze(junctions, dim=1)
                                   - junctions[None, ...]) ** 2, dim=-1) ** 0.5

        # Fetch all the "detected lines"
        seg_indexes = torch.where(torch.triu(candidate_map, diagonal=1))
        start_point_idxs = seg_indexes[0]
        end_point_idxs = seg_indexes[1]
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        # Fetch corresponding entries
        line_dists = line_dist_map[start_point_idxs, end_point_idxs]

        # Unit direction vector of every candidate segment.
        # NOTE(review): assumes junctions are unique, so the norm is nonzero;
        # duplicated junctions would produce NaNs here.
        dir_vecs = ((end_points - start_points)
                    / torch.norm(end_points - start_points,
                                 dim=-1)[..., None])
        # Vectors from each segment start to every junction
        cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)
        cand_vecs_norm = torch.norm(cand_vecs, dim=-1)
        # Check whether the junctions project inside the segment (0 <= t <= 1)
        proj = (torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
                / line_dists[..., None, None])
        # proj is num_segs x num_junction x 1
        proj_mask = (proj >= 0) * (proj <= 1)
        # Orthogonal distance of every junction to the segment line
        cand_angles = torch.acos(
            torch.einsum('bij,bjk->bik', cand_vecs, dir_vecs[..., None])
            / cand_vecs_norm[..., None])
        cand_dists = cand_vecs_norm[..., None] * torch.sin(cand_angles)
        junc_dist_mask = cand_dists <= dist_tolerance
        junc_mask = junc_dist_mask * proj_mask

        # Discount the segment's own endpoints from the on-line junction count
        num_segs = start_point_idxs.shape[0]
        junc_counts = torch.sum(junc_mask, dim=[1, 2])
        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
                                         start_point_idxs].to(torch.int)
        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs),
                                         end_point_idxs].to(torch.int)

        # Suppress segments that pass through at least one other junction
        final_mask = junc_counts > 0
        candidate_map[start_point_idxs[final_mask],
                      end_point_idxs[final_mask]] = 0

        return candidate_map

    def refine_junction_perturb(self, junctions, line_map_pred,
                                heatmap, H, W, device):
        """ Refine the line endpoints in a similar way as in LSD. """
        # Get the config
        junction_refine_cfg = self.junction_refine_cfg

        # Fetch refinement parameters
        num_perturbs = junction_refine_cfg["num_perturbs"]
        perturb_interval = junction_refine_cfg["perturb_interval"]
        side_perturbs = (num_perturbs - 1) // 2

        # Fetch the 2D perturb mat. All four grids share the same perturb_vec,
        # so the (w, h) naming is interchangeable with (h, w) here.
        # NOTE: torch.meshgrid defaults to "ij" indexing (deprecation warning
        # on recent torch, behavior unchanged).
        perturb_vec = torch.arange(
            start=-perturb_interval*side_perturbs,
            end=perturb_interval*(side_perturbs+1),
            step=perturb_interval, device=device)
        w1_grid, h1_grid, w2_grid, h2_grid = torch.meshgrid(
            perturb_vec, perturb_vec, perturb_vec, perturb_vec)
        perturb_tensor = torch.cat([
            w1_grid[..., None], h1_grid[..., None],
            w2_grid[..., None], h2_grid[..., None]], dim=-1)
        perturb_tensor_flat = perturb_tensor.view(-1, 2, 2)

        # Fetch the junctions and line_map
        junctions = junctions.clone()
        line_map = line_map_pred

        # Fetch all the detected lines
        detected_seg_indexes = torch.where(torch.triu(line_map, diagonal=1))
        start_point_idxs = detected_seg_indexes[0]
        end_point_idxs = detected_seg_indexes[1]
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        line_segments = torch.cat([start_points.unsqueeze(dim=1),
                                   end_points.unsqueeze(dim=1)], dim=1)

        line_segment_candidates = (line_segments.unsqueeze(dim=1)
                                   + perturb_tensor_flat[None, ...])
        # Clip the boundaries (coordinates are (h, w))
        line_segment_candidates[..., 0] = torch.clamp(
            line_segment_candidates[..., 0], min=0, max=H - 1)
        line_segment_candidates[..., 1] = torch.clamp(
            line_segment_candidates[..., 1], min=0, max=W - 1)

        # Iterate through all the segments
        refined_segment_lst = []
        num_segments = line_segments.shape[0]
        for idx in range(num_segments):
            segment = line_segment_candidates[idx, ...]
            # Get the corresponding start and end junctions
            candidate_junc_start = segment[:, 0, :]
            candidate_junc_end = segment[:, 1, :]

            # Get the sampling locations (N x num_samples)
            sampler = self.torch_sampler.to(device)[None, ...]
            cand_samples_h = (candidate_junc_start[:, 0:1] * sampler +
                              candidate_junc_end[:, 0:1] * (1 - sampler))
            cand_samples_w = (candidate_junc_start[:, 1:2] * sampler +
                              candidate_junc_end[:, 1:2] * (1 - sampler))

            # Clip to image boundary
            cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1)
            cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1)

            # Perform bilinear sampling and keep the best-scoring perturbation
            segment_feat = self.detect_bilinear(
                heatmap, cand_h, cand_w, H, W, device)
            segment_results = torch.mean(segment_feat, dim=-1)
            max_idx = torch.argmax(segment_results)
            refined_segment_lst.append(segment[max_idx, ...][None, ...])

        # Concatenate back to segments
        refined_segments = torch.cat(refined_segment_lst, dim=0)

        # Convert back to junctions and line_map
        junctions_new = torch.cat(
            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0)
        junctions_new = torch.unique(junctions_new, dim=0)
        line_map_new = self.segments_to_line_map(junctions_new,
                                                 refined_segments)

        return junctions_new, line_map_new

    def segments_to_line_map(self, junctions, segments):
        """ Convert the list of segments to line map. """
        # Create empty line map (int32 to match the map built in detect())
        device = junctions.device
        num_junctions = junctions.shape[0]
        line_map = torch.zeros([num_junctions, num_junctions],
                               device=device, dtype=torch.int32)

        # Iterate through every segment
        for idx in range(segments.shape[0]):
            # Get the junctions from a single segment
            seg = segments[idx, ...]
            junction1 = seg[0, :]
            junction2 = seg[1, :]

            # Get index: a junction matches when both coordinates are equal
            idx_junction1 = torch.where(
                (junctions == junction1).sum(dim=1) == 2)[0]
            idx_junction2 = torch.where(
                (junctions == junction2).sum(dim=1) == 2)[0]

            # Label the corresponding entries
            line_map[idx_junction1, idx_junction2] = 1
            line_map[idx_junction2, idx_junction1] = 1

        return line_map

    def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
        """ Detection by bilinear sampling. """
        # Integer corners surrounding each sampling location. Using
        # floor + 1 (clamped) instead of ceil fixes the degenerate case
        # where a coordinate is exactly an integer: with ceil, floor == ceil
        # and all four interpolation weights vanish, so integer-coordinate
        # samples (e.g. along axis-aligned segments) wrongly scored 0.
        cand_h_floor = torch.floor(cand_h).to(torch.long)
        cand_w_floor = torch.floor(cand_w).to(torch.long)
        cand_h_ceil = torch.clamp(cand_h_floor + 1, max=H - 1)
        cand_w_ceil = torch.clamp(cand_w_floor + 1, max=W - 1)

        # Fractional offsets inside the pixel cell, in [0, 1).
        # At the image border the ceil index is clamped onto the floor index,
        # but its weight (the fraction) is 0 there, so the result is exact.
        frac_h = cand_h - cand_h_floor.to(cand_h.dtype)
        frac_w = cand_w - cand_w_floor.to(cand_w.dtype)

        # Standard bilinear interpolation: weights always sum to 1.
        cand_samples_feat = (
            heatmap[cand_h_floor, cand_w_floor] * (1. - frac_h) * (1. - frac_w)
            + heatmap[cand_h_floor, cand_w_ceil] * (1. - frac_h) * frac_w
            + heatmap[cand_h_ceil, cand_w_floor] * frac_h * (1. - frac_w)
            + heatmap[cand_h_ceil, cand_w_ceil] * frac_h * frac_w)
        return cand_samples_feat

    def detect_local_max(self, heatmap, cand_h, cand_w, H, W,
                         normalized_seg_length, device):
        """ Detection by local maximum search. """
        # Compute the distance threshold: half a pixel diagonal plus a
        # slack growing linearly with the normalized segment length.
        dist_thresh = (0.5 * (2 ** 0.5)
                       + self.lambda_radius * normalized_seg_length)
        # Make it N x num_samples
        dist_thresh = torch.repeat_interleave(dist_thresh[..., None],
                                              self.num_samples, dim=-1)

        # Compute the candidate points
        cand_points = torch.cat([cand_h[..., None], cand_w[..., None]],
                                dim=-1)
        cand_points_round = torch.round(cand_points)  # N x num_samples x 2

        # Construct local square patches of side 2*radius+1
        # (e.g. radius 3 => 7x7 = 49 offsets before the circular crop).
        patch_mask = torch.zeros([int(2 * self.local_patch_radius + 1),
                                  int(2 * self.local_patch_radius + 1)],
                                 device=device)
        patch_center = torch.tensor(
            [[self.local_patch_radius, self.local_patch_radius]],
            device=device, dtype=torch.float32)
        H_patch_points, W_patch_points = torch.where(patch_mask >= 0)
        patch_points = torch.cat([H_patch_points[..., None],
                                  W_patch_points[..., None]], dim=-1)

        # Restrict to the circular region inside the square patch
        patch_center_dist = torch.sqrt(torch.sum(
            (patch_points - patch_center) ** 2, dim=-1))
        patch_points = (patch_points[patch_center_dist
                                     <= self.local_patch_radius, :])
        # Shift [0, 0] to the patch center
        patch_points = patch_points - self.local_patch_radius

        # Construct local patch mask
        patch_points_shifted = (torch.unsqueeze(cand_points_round, dim=2)
                                + patch_points[None, None, ...])
        patch_dist = torch.sqrt(torch.sum((torch.unsqueeze(cand_points, dim=2)
                                           - patch_points_shifted) ** 2,
                                          dim=-1))
        patch_dist_mask = patch_dist < dist_thresh[..., None]

        # Get all points => num_points_center x num_patch_points x 2
        points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0,
                               max=H - 1).to(torch.long)
        points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0,
                               max=W - 1).to(torch.long)
        points = torch.cat([points_H[..., None], points_W[..., None]], dim=-1)

        # Sample the feature (N x num_samples x num_patch_points)
        sampled_feat = heatmap[points[:, :, :, 0], points[:, :, :, 1]]
        # Filtering using the valid mask
        sampled_feat = sampled_feat * patch_dist_mask.to(torch.float32)
        if len(sampled_feat) == 0:
            # No candidates: return an empty (0 x num_samples) tensor on the
            # requested device (the previous hard-coded (0, 64) CPU tensor
            # was wrong for num_samples != 64 and for non-CPU devices).
            sampled_feat_lmax = torch.empty(0, self.num_samples, device=device)
        else:
            sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1)

        return sampled_feat_lmax