Vincentqyw
update: features and matchers
404d2af
raw
history blame
6.32 kB
"""
A two-view sparse feature matching pipeline.
This model contains sub-models for each step:
feature extraction, feature matching, outlier filtering, pose estimation.
Each step is optional, and the features or matches can be provided as input.
Default: SuperPoint with nearest neighbor matching.
Convention for the matches: m0[i] is the index of the keypoint in image 1
that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
"""
import numpy as np
import torch
from .. import get_model
from .base_model import BaseModel
def keep_quadrant_kp_subset(keypoints, scores, descs, h, w):
    """Restrict the keypoints to a single, randomly drawn image quadrant.

    The quadrant offsets are drawn with ``np.random`` (horizontal offset
    first, then vertical). A keypoint ``(x, y)`` is kept when
    ``x0 <= x < x0 + w // 2`` and ``y0 <= y < y0 + h // 2``. When ``w`` or
    ``h`` is odd, the last column or row therefore lies outside every
    quadrant.

    Args:
        keypoints: keypoints with (x, y) coordinates in the last dimension.
        scores: per-keypoint scores, indexed like ``keypoints[..., 0]``.
        descs: descriptors whose keypoint axis is dim 2, i.e.
            ``descs.permute(0, 2, 1)`` is indexed like ``keypoints``.
        h: image height, in the same units as the y coordinates.
        w: image width, in the same units as the x coordinates.

    Returns:
        ``(keypoints, scores, descs)`` restricted to the kept keypoints, each
        with a new leading dimension of size 1. NOTE(review): the boolean
        mask flattens all leading dims, so this appears to assume a batch
        size of 1 — confirm with callers.
    """
    half_h, half_w = h // 2, w // 2
    # Horizontal offset is drawn before the vertical one.
    x_start = np.random.choice([0, half_w])
    y_start = np.random.choice([0, half_h])
    x, y = keypoints[..., 0], keypoints[..., 1]
    in_x = (x >= x_start) & (x < x_start + half_w)
    in_y = (y >= y_start) & (y < y_start + half_h)
    inside = in_x & in_y
    # Put the descriptors' keypoint axis next to the leading axis so the
    # same mask applies, then restore the (dim, num_kept) layout.
    kept_descs = descs.permute(0, 2, 1)[inside].transpose(0, 1)
    return (keypoints[inside].unsqueeze(0),
            scores[inside].unsqueeze(0),
            kept_descs.unsqueeze(0))
def keep_random_kp_subset(keypoints, scores, descs, num_selected):
    """Subsample keypoints uniformly at random, without replacement.

    Args:
        keypoints: keypoints with the keypoint axis at dim 1.
        scores: scores with the keypoint axis at dim 1.
        descs: descriptors with the keypoint axis at dim 2.
        num_selected: how many keypoints to keep. If it exceeds the number
            available, all keypoints are kept (in shuffled order).

    Returns:
        ``(keypoints, scores, descs)`` restricted to the same randomly chosen
        indices (drawn once with ``torch.randperm`` and shared across dim 0).
    """
    chosen = torch.randperm(keypoints.shape[1])[:num_selected]
    return (keypoints[:, chosen],
            scores[:, chosen],
            descs[:, :, chosen])
def keep_best_kp_subset(keypoints, scores, descs, num_selected):
    """Keep the ``num_selected`` highest-scoring keypoints.

    Args:
        keypoints: keypoints with the keypoint axis at dim 1 and the
            coordinates in the last dim (rank-3 tensor).
        scores: scores with the keypoint axis at dim 1.
        descs: descriptors with the keypoint axis at dim 2.
        num_selected: number of keypoints to keep (expected >= 0). A value
            of 0 keeps none; a value larger than the number of keypoints
            keeps all of them.

    Returns:
        ``(keypoints, scores, descs)`` restricted to the selected keypoints,
        ordered by increasing score.
    """
    num_kp = scores.shape[1]
    # Slice from an explicit start index: ``[:, -num_selected:]`` would keep
    # every keypoint when num_selected == 0, because ``-0 == 0``.
    start = max(num_kp - num_selected, 0)
    # Ascending sort, so the best keypoints sit at the end.
    sorted_indices = torch.sort(scores, dim=1)[1]
    selected_kp = sorted_indices[:, start:]
    # gather needs the index expanded over the non-gathered dims; use the
    # real coordinate width so full keypoint rows are kept.
    keypoints = torch.gather(
        keypoints, 1,
        selected_kp[:, :, None].repeat(1, 1, keypoints.shape[-1]))
    scores = torch.gather(scores, 1, selected_kp)
    descs = torch.gather(
        descs, 2, selected_kp[:, None].repeat(1, descs.shape[1], 1))
    return keypoints, scores, descs
class TwoViewPipeline(BaseModel):
    """Two-view matching pipeline assembled from optional sub-models.

    The components listed in ``components`` are instantiated in ``_init``
    only when their ``name`` is set in the configuration. When a detector,
    descriptor or matcher is not configured, its outputs must be supplied in
    the input data instead, and the corresponding keys are appended to
    ``required_data_keys``.
    """

    default_conf = {
        'extractor': {
            'name': 'superpoint',
            'trainable': False,
        },
        'use_lines': False,
        'use_points': True,
        'randomize_num_kp': False,
        'detector': {'name': None},
        'descriptor': {'name': None},
        'matcher': {'name': 'nearest_neighbor_matcher'},
        'filter': {'name': None},
        'solver': {'name': None},
        'ground_truth': {
            'from_pose_depth': False,
            'from_homography': False,
            'th_positive': 3,
            'th_negative': 5,
            'reward_positive': 1,
            'reward_negative': -0.25,
            'is_likelihood_soft': True,
            'p_random_occluders': 0,
            'n_line_sampled_pts': 50,
            'line_perp_dist_th': 5,
            'overlap_th': 0.2,
            'min_visibility_th': 0.5
        },
    }
    required_data_keys = ['image0', 'image1']
    strict_conf = False  # need to pass new confs to children models
    components = [
        'extractor', 'detector', 'descriptor', 'matcher', 'filter', 'solver']

    def _init(self, conf):
        """Instantiate each configured sub-model through ``get_model``.

        An ``extractor`` replaces the ``detector``/``descriptor`` pair.
        Without an extractor, a missing detector requires keypoints as input
        and a missing descriptor requires descriptors as input. Without a
        matcher, ``matches0`` must be provided as input.

        Args:
            conf: the pipeline configuration; each component's ``name``
                selects the model class passed its own sub-configuration.
        """
        # NOTE: ``self.required_data_keys`` initially resolves to the
        # class-level list. ``+=`` would extend that shared list in place
        # (list.__iadd__), leaking the extra keys into the class and every
        # later instance, so a new list is assigned instead.
        if conf.extractor.name:
            self.extractor = get_model(conf.extractor.name)(conf.extractor)
        else:
            if self.conf.detector.name:
                self.detector = get_model(conf.detector.name)(conf.detector)
            else:
                self.required_data_keys = (
                    self.required_data_keys + ['keypoints0', 'keypoints1'])
            if self.conf.descriptor.name:
                self.descriptor = get_model(conf.descriptor.name)(
                    conf.descriptor)
            else:
                self.required_data_keys = (
                    self.required_data_keys
                    + ['descriptors0', 'descriptors1'])
        if conf.matcher.name:
            self.matcher = get_model(conf.matcher.name)(conf.matcher)
        else:
            self.required_data_keys = (
                self.required_data_keys + ['matches0'])
        if conf.filter.name:
            self.filter = get_model(conf.filter.name)(conf.filter)
        if conf.solver.name:
            self.solver = get_model(conf.solver.name)(conf.solver)

    def _forward(self, data):
        """Run all configured stages on an image pair.

        Per-view inputs are the keys ending in ``'0'`` or ``'1'``. Each view
        is processed separately with that suffix stripped, and the per-view
        predictions are re-suffixed. The matcher, filter and solver each
        receive the input data merged with all predictions so far
        (predictions win on key clashes); their outputs are merged into the
        returned prediction dictionary.
        """
        def process_siamese(data, i):
            # Select the inputs of view ``i`` and strip their suffix.
            data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i}
            if self.conf.extractor.name:
                pred_i = self.extractor(data_i)
            else:
                pred_i = {}
                if self.conf.detector.name:
                    pred_i = self.detector(data_i)
                else:
                    # No detector: forward precomputed features from the
                    # input, when present.
                    for k in ['keypoints', 'keypoint_scores', 'descriptors',
                              'lines', 'line_scores', 'line_descriptors',
                              'valid_lines']:
                        if k in data_i:
                            pred_i[k] = data_i[k]
                if self.conf.descriptor.name:
                    pred_i = {
                        **pred_i, **self.descriptor({**data_i, **pred_i})}
            return pred_i

        pred0 = process_siamese(data, '0')
        pred1 = process_siamese(data, '1')

        pred = {**{k + '0': v for k, v in pred0.items()},
                **{k + '1': v for k, v in pred1.items()}}

        if self.conf.matcher.name:
            pred = {**pred, **self.matcher({**data, **pred})}
        if self.conf.filter.name:
            pred = {**pred, **self.filter({**data, **pred})}
        if self.conf.solver.name:
            pred = {**pred, **self.solver({**data, **pred})}
        return pred

    def loss(self, pred, data):
        """Sum the losses of all configured components.

        Components whose ``loss`` raises ``NotImplementedError`` are
        skipped. Each remaining component's loss dictionary must contain a
        ``'total'`` entry. The individual entries are merged (later
        components win on key clashes), and ``'total'`` is replaced by the
        sum over components (``0`` when no component contributes).
        """
        losses = {}
        total = 0
        for k in self.components:
            if self.conf[k].name:
                try:
                    losses_ = getattr(self, k).loss(pred, {**pred, **data})
                except NotImplementedError:
                    continue
                losses = {**losses, **losses_}
                total = losses_['total'] + total
        return {**losses, 'total': total}

    def metrics(self, pred, data):
        """Merge the metrics of all configured components.

        Components whose ``metrics`` raises ``NotImplementedError`` are
        skipped; later components win on key clashes.
        """
        metrics = {}
        for k in self.components:
            if self.conf[k].name:
                try:
                    metrics_ = getattr(self, k).metrics(
                        pred, {**pred, **data})
                except NotImplementedError:
                    continue
                metrics = {**metrics, **metrics_}
        return metrics