"""
Implementation of the line matching methods.
"""

import numpy as np
import cv2
import torch
import torch.nn.functional as F

from ..misc.geometry_utils import keypoints_to_grid


class WunschLineMatcher(object):
    """Class matching two sets of line segments
    with the Needleman-Wunsch algorithm."""

    def __init__(
        self,
        cross_check=True,
        num_samples=10,
        min_dist_pts=8,
        top_k_candidates=10,
        grid_size=8,
        sampling="regular",
        line_score=False,
    ):
        self.cross_check = cross_check
        self.num_samples = num_samples
        self.min_dist_pts = min_dist_pts
        self.top_k_candidates = top_k_candidates
        self.grid_size = grid_size
        # If True, compute the saliency score of each line from its own
        # sampled descriptors; otherwise use a dense score map of the image
        self.line_score = line_score
        self.sampling_mode = sampling
        if sampling not in ["regular", "d2_net", "asl_feat"]:
            raise ValueError("Wrong sampling mode: " + sampling)

    def forward(self, line_seg1, line_seg2, desc1, desc2):
        """
        Find the best matches between two sets of line segments
        and their corresponding descriptors.
        Inputs:
            line_seg1, line_seg2: [N1|N2]x2x2 arrays of segment endpoints.
            desc1, desc2: the corresponding 1xDxHxW descriptor maps.
        Outputs:
            matches: a (N1,) np.array of indices into line_seg2
                     (-1 for unmatched lines).
        """
        img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size)
        img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size)
        device = desc1.device

        # Default case when an image has no lines
        if len(line_seg1) == 0:
            return np.empty(0, dtype=int)
        if len(line_seg2) == 0:
            return -np.ones(len(line_seg1), dtype=int)

        # Sample points regularly along each line
        if self.sampling_mode == "regular":
            line_points1, valid_points1 = self.sample_line_points(line_seg1)
            line_points2, valid_points2 = self.sample_line_points(line_seg2)
        else:
            # Sample the most salient points along each line
            line_points1, valid_points1 = self.sample_salient_points(
                line_seg1, desc1, img_size1, self.sampling_mode
            )
            line_points2, valid_points2 = self.sample_salient_points(
                line_seg2, desc2, img_size2, self.sampling_mode
            )
        line_points1 = torch.tensor(
            line_points1.reshape(-1, 2), dtype=torch.float, device=device
        )
        line_points2 = torch.tensor(
            line_points2.reshape(-1, 2), dtype=torch.float, device=device
        )

        # Extract the descriptors for each sampled point
        grid1 = keypoints_to_grid(line_points1, img_size1)
        grid2 = keypoints_to_grid(line_points2, img_size2)
        desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0)
        desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0)

        # Precompute the pairwise distances between line points
        # and assign a score of -1 to the invalid points
        scores = desc1.t() @ desc2
        scores[~valid_points1.flatten()] = -1
        scores[:, ~valid_points2.flatten()] = -1
        scores = scores.reshape(
            len(line_seg1), self.num_samples, len(line_seg2), self.num_samples
        )
        scores = scores.permute(0, 2, 1, 3)
        # scores.shape = (n_lines1, n_lines2, num_samples, num_samples)

        # Pre-filter the line candidates and find the best match for each line
        matches = self.filter_and_match_lines(scores)

        # [Optionally] filter the matches with mutual nearest neighbor check
        if self.cross_check:
            matches2 = self.filter_and_match_lines(scores.permute(1, 0, 3, 2))
            mutual = matches2[matches] == np.arange(len(line_seg1))
            matches[~mutual] = -1

        return matches
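
    # Minimal usage sketch (the shapes are assumptions for illustration): with
    # grid_size=8, a 1xDx(H/8)x(W/8) descriptor map corresponds to an HxW
    # image, and the line segments are Nx2x2 arrays of endpoints.
    #
    #   matcher = WunschLineMatcher()
    #   matches = matcher.forward(line_seg1, line_seg2, desc1, desc2)
    #   # matches: (N1,) indices into line_seg2, -1 for unmatched lines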

    def d2_net_saliency_score(self, desc):
        """Compute the D2-Net saliency score
        on a 3D or 4D descriptor."""
        is_3d = len(desc.shape) == 3
        b_size = len(desc)
        feat = F.relu(desc)

        # Compute the soft local-max over a 3x3 neighborhood
        exp = torch.exp(feat)
        if is_3d:
            sum_exp = 3 * F.avg_pool1d(exp, kernel_size=3, stride=1, padding=1)
        else:
            sum_exp = 9 * F.avg_pool2d(exp, kernel_size=3, stride=1, padding=1)
        soft_local_max = exp / sum_exp

        # Compute the ratio of each channel to the depth-wise maximum
        depth_wise_max = torch.max(feat, dim=1)[0]
        depth_wise_ratio = feat / depth_wise_max.unsqueeze(1)

        # Total saliency score, normalized to sum to 1 per sample
        score = torch.max(soft_local_max * depth_wise_ratio, dim=1)[0]
        normalization = torch.sum(score.reshape(b_size, -1), dim=1)
        if is_3d:
            normalization = normalization.reshape(b_size, 1)
        else:
            normalization = normalization.reshape(b_size, 1, 1)
        score = score / normalization
        return score
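
    # The two factors above follow D2-Net (Dusmanu et al., CVPR 2019): a
    # softmax-style local peakiness over a 3x3 neighborhood, multiplied by the
    # ratio of each channel to the per-location channel maximum, reduced with
    # a max over channels and normalized to sum to 1 for each sample.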

    def asl_feat_saliency_score(self, desc):
        """Compute the ASLFeat saliency score on a 3D or 4D descriptor."""
        is_3d = len(desc.shape) == 3
        b_size = len(desc)

        # Compute the soft local peakiness w.r.t. a 3x3 neighborhood
        if is_3d:
            local_avg = F.avg_pool1d(desc, kernel_size=3, stride=1, padding=1)
        else:
            local_avg = F.avg_pool2d(desc, kernel_size=3, stride=1, padding=1)
        soft_local_score = F.softplus(desc - local_avg)

        # Compute the depth-wise peakiness w.r.t. the channel mean
        depth_wise_mean = torch.mean(desc, dim=1).unsqueeze(1)
        depth_wise_score = F.softplus(desc - depth_wise_mean)

        # Total saliency score, normalized to sum to 1 per sample
        score = torch.max(soft_local_score * depth_wise_score, dim=1)[0]
        normalization = torch.sum(score.reshape(b_size, -1), dim=1)
        if is_3d:
            normalization = normalization.reshape(b_size, 1)
        else:
            normalization = normalization.reshape(b_size, 1, 1)
        score = score / normalization
        return score
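
    # Same structure as the D2-Net score, but with the softplus peakiness of
    # ASLFeat (Luo et al., CVPR 2020): softplus(x - 3x3 local average) and
    # softplus(x - channel mean) replace the exponential soft local-max, which
    # also avoids exp() overflow on large activations.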

    def sample_salient_points(self, line_seg, desc, img_size, saliency_type="d2_net"):
        """
        Sample the most salient points along each line segment, with a
        minimal distance between each point. Pad the remaining points.
        Inputs:
            line_seg: an Nx2x2 torch.Tensor.
            desc: a NxDxHxW torch.Tensor.
            img_size: the original image size.
            saliency_type: 'd2_net' or 'asl_feat'.
        Outputs:
            line_points: an Nxnum_samplesx2 np.array.
            valid_points: a boolean Nxnum_samples np.array.
        """
        device = desc.device
        if not self.line_score:
            # Compute the saliency score on the whole image at once
            if saliency_type == "d2_net":
                score = self.d2_net_saliency_score(desc)
            else:
                score = self.asl_feat_saliency_score(desc)

        num_lines = len(line_seg)
        line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1)

        # The number of samples depends on the length of the line
        num_samples_lst = np.clip(
            line_lengths // self.min_dist_pts, 2, self.num_samples
        )
        line_points = np.empty((num_lines, self.num_samples, 2), dtype=float)
        valid_points = np.empty((num_lines, self.num_samples), dtype=bool)

        # Starting from a regular grid of candidates, keep the most salient
        # point out of every group of n_samples_per_region candidates
        n_samples_per_region = 4
        for n in np.arange(2, self.num_samples + 1):
            sample_rate = n * n_samples_per_region
            # Consider all lines where we can fit up to n points
            cur_mask = num_samples_lst == n
            cur_line_seg = line_seg[cur_mask]
            cur_num_lines = len(cur_line_seg)
            if cur_num_lines == 0:
                continue
            line_points_x = np.linspace(
                cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], sample_rate, axis=-1
            )
            line_points_y = np.linspace(
                cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], sample_rate, axis=-1
            )
            cur_line_points = np.stack([line_points_x, line_points_y], axis=-1).reshape(
                -1, 2
            )

            # Interpolate the saliency at the candidate points (keep a
            # separate tensor so that cur_line_points stays a np.array for
            # the np.take_along_axis call below)
            cur_points_t = torch.tensor(
                cur_line_points, dtype=torch.float, device=device
            )
            grid_points = keypoints_to_grid(cur_points_t, img_size)

            if self.line_score:
                # The saliency score is computed from the descriptors
                # sampled along the line itself
                line_desc = F.grid_sample(desc, grid_points).squeeze()
                line_desc = line_desc.reshape(-1, cur_num_lines, sample_rate)
                line_desc = line_desc.permute(1, 0, 2)
                if saliency_type == "d2_net":
                    scores = self.d2_net_saliency_score(line_desc)
                else:
                    scores = self.asl_feat_saliency_score(line_desc)
            else:
                scores = F.grid_sample(score.unsqueeze(1), grid_points).squeeze()

            # Take the most salient point in each of the n regions
            scores = scores.reshape(-1, n, n_samples_per_region)
            best = torch.max(scores, dim=2, keepdim=True)[1].cpu().numpy()
            cur_line_points = cur_line_points.reshape(-1, n, n_samples_per_region, 2)
            cur_line_points = np.take_along_axis(
                cur_line_points, best[..., None], axis=2
            )[:, :, 0]

            # Pad the sampled points and record the valid ones
            cur_valid_points = np.ones((cur_num_lines, self.num_samples), dtype=bool)
            cur_valid_points[:, n:] = False
            cur_line_points = np.concatenate(
                [
                    cur_line_points,
                    np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float),
                ],
                axis=1,
            )

            line_points[cur_mask] = cur_line_points
            valid_points[cur_mask] = cur_valid_points

        return line_points, valid_points
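
    # Candidate pool illustration: a line assigned n = 5 samples is first
    # discretized into sample_rate = 5 * 4 = 20 regularly spaced candidates,
    # i.e. 5 consecutive regions of 4 candidates each, and the most salient
    # candidate of each region is kept.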

    def sample_line_points(self, line_seg):
        """
        Regularly sample points along each line segment, with a minimal
        distance between each point. Pad the remaining points.
        Inputs:
            line_seg: an Nx2x2 torch.Tensor.
        Outputs:
            line_points: an Nxnum_samplesx2 np.array.
            valid_points: a boolean Nxnum_samples np.array.
        """
        num_lines = len(line_seg)
        line_lengths = np.linalg.norm(line_seg[:, 0] - line_seg[:, 1], axis=1)

        # Sample points separated by at least min_dist_pts along each line;
        # the number of samples depends on the length of the line
        num_samples_lst = np.clip(
            line_lengths // self.min_dist_pts, 2, self.num_samples
        )
        line_points = np.empty((num_lines, self.num_samples, 2), dtype=float)
        valid_points = np.empty((num_lines, self.num_samples), dtype=bool)
        for n in np.arange(2, self.num_samples + 1):
            # Consider all lines where we can fit up to n points
            cur_mask = num_samples_lst == n
            cur_line_seg = line_seg[cur_mask]
            line_points_x = np.linspace(
                cur_line_seg[:, 0, 0], cur_line_seg[:, 1, 0], n, axis=-1
            )
            line_points_y = np.linspace(
                cur_line_seg[:, 0, 1], cur_line_seg[:, 1, 1], n, axis=-1
            )
            cur_line_points = np.stack([line_points_x, line_points_y], axis=-1)

            # Pad the sampled points and record the valid ones
            cur_num_lines = len(cur_line_seg)
            cur_valid_points = np.ones((cur_num_lines, self.num_samples), dtype=bool)
            cur_valid_points[:, n:] = False
            cur_line_points = np.concatenate(
                [
                    cur_line_points,
                    np.zeros((cur_num_lines, self.num_samples - n, 2), dtype=float),
                ],
                axis=1,
            )

            line_points[cur_mask] = cur_line_points
            valid_points[cur_mask] = cur_valid_points

        return line_points, valid_points
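
    # Worked example (with the defaults min_dist_pts=8 and num_samples=10): a
    # line of length 20 gets n = clip(20 // 8, 2, 10) = 2 samples (its two
    # endpoints), and the remaining 8 slots are zero-padded and flagged as
    # invalid in valid_points.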

    def filter_and_match_lines(self, scores):
        """
        Use the scores to keep the top k best lines, compute the Needleman-
        Wunsch algorithm on each candidate pair, and keep the highest score.
        Inputs:
            scores: a (N, M, n, n) torch.Tensor containing the pairwise scores
                    of the elements to match.
        Outputs:
            matches: a (N) np.array containing the indices of the best match.
        """
        # Pre-filter the pairs and keep the top k best candidate lines
        line_scores1 = scores.max(3)[0]
        valid_scores1 = line_scores1 != -1
        line_scores1 = (line_scores1 * valid_scores1).sum(2) / valid_scores1.sum(2)
        line_scores2 = scores.max(2)[0]
        valid_scores2 = line_scores2 != -1
        line_scores2 = (line_scores2 * valid_scores2).sum(2) / valid_scores2.sum(2)
        line_scores = (line_scores1 + line_scores2) / 2
        topk_lines = torch.argsort(line_scores, dim=1)[:, -self.top_k_candidates :]
        scores, topk_lines = scores.cpu().numpy(), topk_lines.cpu().numpy()
        # topk_lines.shape = (n_lines1, top_k_candidates)
        top_scores = np.take_along_axis(scores, topk_lines[:, :, None, None], axis=1)

        # Consider the reversed line segments as well
        top_scores = np.concatenate([top_scores, top_scores[..., ::-1]], axis=1)

        # Compute the line distances with the Needleman-Wunsch algorithm
        # and retrieve the closest line neighbor
        n_lines1, top2k, n, m = top_scores.shape
        top_scores = top_scores.reshape(n_lines1 * top2k, n, m)
        nw_scores = self.needleman_wunsch(top_scores)
        nw_scores = nw_scores.reshape(n_lines1, top2k)
        matches = np.mod(np.argmax(nw_scores, axis=1), top2k // 2)
        matches = topk_lines[np.arange(n_lines1), matches]
        return matches
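
    # Decoding example: with top_k_candidates=10, each line yields 20 NW
    # scores (10 candidates x 2 orientations). An argmax of 13 decodes as
    # 13 mod 10 = 3, i.e. the 4th candidate matched with reversed orientation.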

    def needleman_wunsch(self, scores):
        """
        Batched implementation of the Needleman-Wunsch algorithm.
        The cost of the InDel operation is set to 0 by subtracting the gap
        penalty from the scores.
        Inputs:
            scores: a (B, N, M) np.array containing the pairwise scores
                    of the elements to match.
        Outputs:
            a (B,) np.array containing the alignment score of each element
            of the batch.
        """
        b, n, m = scores.shape

        # Recalibrate the scores to get a gap score of 0
        gap = 0.1
        nw_scores = scores - gap

        # Run the dynamic programming algorithm
        nw_grid = np.zeros((b, n + 1, m + 1), dtype=float)
        for i in range(n):
            for j in range(m):
                nw_grid[:, i + 1, j + 1] = np.maximum(
                    np.maximum(nw_grid[:, i + 1, j], nw_grid[:, i, j + 1]),
                    nw_grid[:, i, j] + nw_scores[:, i, j],
                )

        return nw_grid[:, -1, -1]
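
    # With the gap trick above, an InDel contributes 0 and a matched pair
    # contributes (score - 0.1), so a pair of points only extends an alignment
    # when its descriptor similarity exceeds the gap value; runs of weakly
    # similar points are skipped rather than forced into the alignment.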

    def get_pairwise_distance(self, line_seg1, line_seg2, desc1, desc2):
        """
        Compute the OPPOSITE of the NW score for pairs of line segments
        and their corresponding descriptors.
        """
        num_lines = len(line_seg1)
        assert num_lines == len(
            line_seg2
        ), "The same number of lines is required in pairwise score."
        img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size)
        img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size)
        device = desc1.device

        # Sample points regularly along each line
        line_points1, valid_points1 = self.sample_line_points(line_seg1)
        line_points2, valid_points2 = self.sample_line_points(line_seg2)
        line_points1 = torch.tensor(
            line_points1.reshape(-1, 2), dtype=torch.float, device=device
        )
        line_points2 = torch.tensor(
            line_points2.reshape(-1, 2), dtype=torch.float, device=device
        )

        # Extract the descriptors for each sampled point
        grid1 = keypoints_to_grid(line_points1, img_size1)
        grid2 = keypoints_to_grid(line_points2, img_size2)
        desc1 = F.normalize(F.grid_sample(desc1, grid1)[0, :, :, 0], dim=0)
        desc1 = desc1.reshape(-1, num_lines, self.num_samples)
        desc2 = F.normalize(F.grid_sample(desc2, grid2)[0, :, :, 0], dim=0)
        desc2 = desc2.reshape(-1, num_lines, self.num_samples)

        # Compute the distance between line points for every pair of lines
        # and assign a score of -1 to the invalid points
        scores = torch.einsum("dns,dnt->nst", desc1, desc2).cpu().numpy()
        scores = scores.reshape(num_lines * self.num_samples, self.num_samples)
        scores[~valid_points1.flatten()] = -1
        scores = scores.reshape(num_lines, self.num_samples, self.num_samples)
        scores = scores.transpose(1, 0, 2).reshape(self.num_samples, -1)
        scores[:, ~valid_points2.flatten()] = -1
        scores = scores.reshape(self.num_samples, num_lines, self.num_samples)
        scores = scores.transpose(1, 0, 2)

        # Compute the NW score for each pair of lines; needleman_wunsch is
        # already batched, so all pairs are processed in a single call
        pairwise_scores = self.needleman_wunsch(scores)
        return -pairwise_scores
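

if __name__ == "__main__":
    # Minimal smoke test of the batched Needleman-Wunsch scoring, as an
    # illustrative sketch only. Note that the relative import at the top of
    # this file means it must be run as a module from its parent package
    # (the exact package path depends on the project layout).
    matcher = WunschLineMatcher()
    toy_scores = np.array([[[0.9, 0.1], [0.2, 0.8]]])  # (B, N, M) = (1, 2, 2)
    # Expected output: [1.5], i.e. both diagonal pairs matched with a gap
    # penalty of 0.1 subtracted from each: (0.9 - 0.1) + (0.8 - 0.1).
    print("NW score:", matcher.needleman_wunsch(toy_scores))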