Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| from pytlsd import lsd | |
| from sklearn.cluster import DBSCAN | |
| from .base_model import BaseModel | |
| from .superpoint import SuperPoint, sample_descriptors | |
| from ..geometry import warp_lines_torch | |
| def lines_to_wireframe(lines, line_scores, all_descs, conf): | |
| """Given a set of lines, their score and dense descriptors, | |
| merge close-by endpoints and compute a wireframe defined by | |
| its junctions and connectivity. | |
| Returns: | |
| junctions: list of [num_junc, 2] tensors listing all wireframe junctions | |
| junc_scores: list of [num_junc] tensors with the junction score | |
| junc_descs: list of [dim, num_junc] tensors with the junction descriptors | |
| connectivity: list of [num_junc, num_junc] bool arrays with True when 2 junctions are connected | |
| new_lines: the new set of [b_size, num_lines, 2, 2] lines | |
| lines_junc_idx: a [b_size, num_lines, 2] tensor with the indices of the junctions of each endpoint | |
| num_true_junctions: a list of the number of valid junctions for each image in the batch, | |
| i.e. before filling with random ones | |
| """ | |
| b_size, _, _, _ = all_descs.shape | |
| device = lines.device | |
| endpoints = lines.reshape(b_size, -1, 2) | |
| ( | |
| junctions, | |
| junc_scores, | |
| junc_descs, | |
| connectivity, | |
| new_lines, | |
| lines_junc_idx, | |
| num_true_junctions, | |
| ) = ([], [], [], [], [], [], []) | |
| for bs in range(b_size): | |
| # Cluster the junctions that are close-by | |
| db = DBSCAN(eps=conf.nms_radius, min_samples=1).fit(endpoints[bs].cpu().numpy()) | |
| clusters = db.labels_ | |
| n_clusters = len(set(clusters)) | |
| num_true_junctions.append(n_clusters) | |
| # Compute the average junction and score for each cluster | |
| clusters = torch.tensor(clusters, dtype=torch.long, device=device) | |
| new_junc = torch.zeros(n_clusters, 2, dtype=torch.float, device=device) | |
| new_junc.scatter_reduce_( | |
| 0, | |
| clusters[:, None].repeat(1, 2), | |
| endpoints[bs], | |
| reduce="mean", | |
| include_self=False, | |
| ) | |
| junctions.append(new_junc) | |
| new_scores = torch.zeros(n_clusters, dtype=torch.float, device=device) | |
| new_scores.scatter_reduce_( | |
| 0, | |
| clusters, | |
| torch.repeat_interleave(line_scores[bs], 2), | |
| reduce="mean", | |
| include_self=False, | |
| ) | |
| junc_scores.append(new_scores) | |
| # Compute the new lines | |
| new_lines.append(junctions[-1][clusters].reshape(-1, 2, 2)) | |
| lines_junc_idx.append(clusters.reshape(-1, 2)) | |
| # Compute the junction connectivity | |
| junc_connect = torch.eye(n_clusters, dtype=torch.bool, device=device) | |
| pairs = clusters.reshape(-1, 2) # these pairs are connected by a line | |
| junc_connect[pairs[:, 0], pairs[:, 1]] = True | |
| junc_connect[pairs[:, 1], pairs[:, 0]] = True | |
| connectivity.append(junc_connect) | |
| # Interpolate the new junction descriptors | |
| junc_descs.append( | |
| sample_descriptors(junctions[-1][None], all_descs[bs : (bs + 1)], 8)[0] | |
| ) | |
| new_lines = torch.stack(new_lines, dim=0) | |
| lines_junc_idx = torch.stack(lines_junc_idx, dim=0) | |
| return ( | |
| junctions, | |
| junc_scores, | |
| junc_descs, | |
| connectivity, | |
| new_lines, | |
| lines_junc_idx, | |
| num_true_junctions, | |
| ) | |
| class SPWireframeDescriptor(BaseModel): | |
| default_conf = { | |
| "sp_params": { | |
| "has_detector": True, | |
| "has_descriptor": True, | |
| "descriptor_dim": 256, | |
| "trainable": False, | |
| # Inference | |
| "return_all": True, | |
| "sparse_outputs": True, | |
| "nms_radius": 4, | |
| "detection_threshold": 0.005, | |
| "max_num_keypoints": 1000, | |
| "force_num_keypoints": True, | |
| "remove_borders": 4, | |
| }, | |
| "wireframe_params": { | |
| "merge_points": True, | |
| "merge_line_endpoints": True, | |
| "nms_radius": 3, | |
| "max_n_junctions": 500, | |
| }, | |
| "max_n_lines": 250, | |
| "min_length": 15, | |
| } | |
| required_data_keys = ["image"] | |
| def _init(self, conf): | |
| self.conf = conf | |
| self.sp = SuperPoint(conf.sp_params) | |
| def detect_lsd_lines(self, x, max_n_lines=None): | |
| if max_n_lines is None: | |
| max_n_lines = self.conf.max_n_lines | |
| lines, scores, valid_lines = [], [], [] | |
| for b in range(len(x)): | |
| # For each image on batch | |
| img = (x[b].squeeze().cpu().numpy() * 255).astype(np.uint8) | |
| if max_n_lines is None: | |
| b_segs = lsd(img) | |
| else: | |
| for s in [0.3, 0.4, 0.5, 0.7, 0.8, 1.0]: | |
| b_segs = lsd(img, scale=s) | |
| if len(b_segs) >= max_n_lines: | |
| break | |
| segs_length = np.linalg.norm(b_segs[:, 2:4] - b_segs[:, 0:2], axis=1) | |
| # Remove short lines | |
| b_segs = b_segs[segs_length >= self.conf.min_length] | |
| segs_length = segs_length[segs_length >= self.conf.min_length] | |
| b_scores = b_segs[:, -1] * np.sqrt(segs_length) | |
| # Take the most relevant segments with | |
| indices = np.argsort(-b_scores) | |
| if max_n_lines is not None: | |
| indices = indices[:max_n_lines] | |
| lines.append(torch.from_numpy(b_segs[indices, :4].reshape(-1, 2, 2))) | |
| scores.append(torch.from_numpy(b_scores[indices])) | |
| valid_lines.append(torch.ones_like(scores[-1], dtype=torch.bool)) | |
| lines = torch.stack(lines).to(x) | |
| scores = torch.stack(scores).to(x) | |
| valid_lines = torch.stack(valid_lines).to(x.device) | |
| return lines, scores, valid_lines | |
| def _forward(self, data): | |
| b_size, _, h, w = data["image"].shape | |
| device = data["image"].device | |
| if not self.conf.sp_params.force_num_keypoints: | |
| assert b_size == 1, "Only batch size of 1 accepted for non padded inputs" | |
| # Line detection | |
| if "lines" not in data or "line_scores" not in data: | |
| if "original_img" in data: | |
| # Detect more lines, because when projecting them to the image most of them will be discarded | |
| lines, line_scores, valid_lines = self.detect_lsd_lines( | |
| data["original_img"], self.conf.max_n_lines * 3 | |
| ) | |
| # Apply the same transformation that is applied in homography_adaptation | |
| lines, valid_lines2 = warp_lines_torch( | |
| lines, data["H"], False, data["image"].shape[-2:] | |
| ) | |
| valid_lines = valid_lines & valid_lines2 | |
| lines[~valid_lines] = -1 | |
| line_scores[~valid_lines] = 0 | |
| # Re-sort the line segments to pick the ones that are inside the image and have bigger score | |
| sorted_scores, sorting_indices = torch.sort( | |
| line_scores, dim=-1, descending=True | |
| ) | |
| line_scores = sorted_scores[:, : self.conf.max_n_lines] | |
| sorting_indices = sorting_indices[:, : self.conf.max_n_lines] | |
| lines = torch.take_along_dim(lines, sorting_indices[..., None, None], 1) | |
| valid_lines = torch.take_along_dim(valid_lines, sorting_indices, 1) | |
| else: | |
| lines, line_scores, valid_lines = self.detect_lsd_lines(data["image"]) | |
| else: | |
| lines, line_scores, valid_lines = ( | |
| data["lines"], | |
| data["line_scores"], | |
| data["valid_lines"], | |
| ) | |
| if line_scores.shape[-1] != 0: | |
| line_scores /= ( | |
| line_scores.new_tensor(1e-8) + line_scores.max(dim=1).values[:, None] | |
| ) | |
| # SuperPoint prediction | |
| pred = self.sp(data) | |
| # Remove keypoints that are too close to line endpoints | |
| if self.conf.wireframe_params.merge_points: | |
| kp = pred["keypoints"] | |
| line_endpts = lines.reshape(b_size, -1, 2) | |
| dist_pt_lines = torch.norm(kp[:, :, None] - line_endpts[:, None], dim=-1) | |
| # For each keypoint, mark it as valid or to remove | |
| pts_to_remove = torch.any( | |
| dist_pt_lines < self.conf.sp_params.nms_radius, dim=2 | |
| ) | |
| # Simply remove them (we assume batch_size = 1 here) | |
| assert len(kp) == 1 | |
| pred["keypoints"] = pred["keypoints"][0][~pts_to_remove[0]][None] | |
| pred["keypoint_scores"] = pred["keypoint_scores"][0][~pts_to_remove[0]][ | |
| None | |
| ] | |
| pred["descriptors"] = pred["descriptors"][0].T[~pts_to_remove[0]].T[None] | |
| # Connect the lines together to form a wireframe | |
| orig_lines = lines.clone() | |
| if self.conf.wireframe_params.merge_line_endpoints and len(lines[0]) > 0: | |
| # Merge first close-by endpoints to connect lines | |
| ( | |
| line_points, | |
| line_pts_scores, | |
| line_descs, | |
| line_association, | |
| lines, | |
| lines_junc_idx, | |
| num_true_junctions, | |
| ) = lines_to_wireframe( | |
| lines, | |
| line_scores, | |
| pred["all_descriptors"], | |
| conf=self.conf.wireframe_params, | |
| ) | |
| # Add the keypoints to the junctions and fill the rest with random keypoints | |
| (all_points, all_scores, all_descs, pl_associativity) = [], [], [], [] | |
| for bs in range(b_size): | |
| all_points.append( | |
| torch.cat([line_points[bs], pred["keypoints"][bs]], dim=0) | |
| ) | |
| all_scores.append( | |
| torch.cat([line_pts_scores[bs], pred["keypoint_scores"][bs]], dim=0) | |
| ) | |
| all_descs.append( | |
| torch.cat([line_descs[bs], pred["descriptors"][bs]], dim=1) | |
| ) | |
| associativity = torch.eye( | |
| len(all_points[-1]), dtype=torch.bool, device=device | |
| ) | |
| associativity[ | |
| : num_true_junctions[bs], : num_true_junctions[bs] | |
| ] = line_association[bs][ | |
| : num_true_junctions[bs], : num_true_junctions[bs] | |
| ] | |
| pl_associativity.append(associativity) | |
| all_points = torch.stack(all_points, dim=0) | |
| all_scores = torch.stack(all_scores, dim=0) | |
| all_descs = torch.stack(all_descs, dim=0) | |
| pl_associativity = torch.stack(pl_associativity, dim=0) | |
| else: | |
| # Lines are independent | |
| all_points = torch.cat( | |
| [lines.reshape(b_size, -1, 2), pred["keypoints"]], dim=1 | |
| ) | |
| n_pts = all_points.shape[1] | |
| num_lines = lines.shape[1] | |
| num_true_junctions = [num_lines * 2] * b_size | |
| all_scores = torch.cat( | |
| [ | |
| torch.repeat_interleave(line_scores, 2, dim=1), | |
| pred["keypoint_scores"], | |
| ], | |
| dim=1, | |
| ) | |
| pred["line_descriptors"] = self.endpoints_pooling( | |
| lines, pred["all_descriptors"], (h, w) | |
| ) | |
| all_descs = torch.cat( | |
| [ | |
| pred["line_descriptors"].reshape( | |
| b_size, self.conf.sp_params.descriptor_dim, -1 | |
| ), | |
| pred["descriptors"], | |
| ], | |
| dim=2, | |
| ) | |
| pl_associativity = torch.eye(n_pts, dtype=torch.bool, device=device)[ | |
| None | |
| ].repeat(b_size, 1, 1) | |
| lines_junc_idx = ( | |
| torch.arange(num_lines * 2, device=device) | |
| .reshape(1, -1, 2) | |
| .repeat(b_size, 1, 1) | |
| ) | |
| del pred["all_descriptors"] # Remove dense descriptors to save memory | |
| torch.cuda.empty_cache() | |
| return { | |
| "keypoints": all_points, | |
| "keypoint_scores": all_scores, | |
| "descriptors": all_descs, | |
| "pl_associativity": pl_associativity, | |
| "num_junctions": torch.tensor(num_true_junctions), | |
| "lines": lines, | |
| "orig_lines": orig_lines, | |
| "lines_junc_idx": lines_junc_idx, | |
| "line_scores": line_scores, | |
| "valid_lines": valid_lines, | |
| } | |
| def endpoints_pooling(segs, all_descriptors, img_shape): | |
| assert segs.ndim == 4 and segs.shape[-2:] == (2, 2) | |
| filter_shape = all_descriptors.shape[-2:] | |
| scale_x = filter_shape[1] / img_shape[1] | |
| scale_y = filter_shape[0] / img_shape[0] | |
| scaled_segs = torch.round( | |
| segs * torch.tensor([scale_x, scale_y]).to(segs) | |
| ).long() | |
| scaled_segs[..., 0] = torch.clip(scaled_segs[..., 0], 0, filter_shape[1] - 1) | |
| scaled_segs[..., 1] = torch.clip(scaled_segs[..., 1], 0, filter_shape[0] - 1) | |
| line_descriptors = [ | |
| all_descriptors[ | |
| None, | |
| b, | |
| ..., | |
| torch.squeeze(b_segs[..., 1]), | |
| torch.squeeze(b_segs[..., 0]), | |
| ] | |
| for b, b_segs in enumerate(scaled_segs) | |
| ] | |
| line_descriptors = torch.cat(line_descriptors) | |
| return line_descriptors # Shape (1, 256, 308, 2) | |
| def loss(self, pred, data): | |
| raise NotImplementedError | |
| def metrics(self, pred, data): | |
| return {} | |