Spaces:
Running
Running
""" | |
Inference model of SuperPoint, a feature detector and descriptor. | |
Described in: | |
SuperPoint: Self-Supervised Interest Point Detection and Description, | |
Daniel DeTone, Tomasz Malisiewicz, Andrew Rabinovich, CVPRW 2018. | |
Original code: github.com/MagicLeapResearch/SuperPointPretrainedNetwork | |
""" | |
import torch | |
from torch import nn | |
from .. import GLUESTICK_ROOT | |
from ..models.base_model import BaseModel | |
def simple_nms(scores, radius):
    """Perform non maximum suppression on the heatmap using max-pooling.

    A pixel is kept if it equals the maximum of its (2 * radius + 1) square
    window. Two further rounds then recover pixels that lie outside the
    window of every already-kept maximum and that become local maxima once
    those windows are zeroed out. This method does not suppress contiguous
    points that have the same score (ties are all kept).

    Args:
        scores: the score heatmap of size `(B, H, W)`.
        radius: an integer scalar, the radius of the NMS window.
    Returns:
        A tensor shaped like `scores` where kept maxima retain their score
        and every other entry is zero.
    """
    def max_pool(x):
        # Stride 1 and padding `radius` keep the spatial size unchanged.
        return torch.nn.functional.max_pool2d(
            x, kernel_size=radius * 2 + 1, stride=1, padding=radius)

    zeros = torch.zeros_like(scores)
    max_mask = scores == max_pool(scores)
    for _ in range(2):
        # Region covered by the windows of the maxima found so far.
        supp_mask = max_pool(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        # Only accept new maxima outside the already-suppressed region.
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)
def remove_borders(keypoints, scores, b, h, w):
    """Discard keypoints closer than `b` pixels to the image border.

    Args:
        keypoints: 2-D tensor whose column 0 is checked against the height
            `h` and column 1 against the width `w`.
        scores: tensor aligned with the rows of `keypoints`.
        b: border width in pixels.
        h, w: image height and width.
    Returns:
        `(keypoints, scores)` restricted to rows with
        `b <= keypoints[:, 0] < h - b` and `b <= keypoints[:, 1] < w - b`.
    """
    rows = keypoints[:, 0]
    cols = keypoints[:, 1]
    inside = (rows >= b) & (rows < h - b) & (cols >= b) & (cols < w - b)
    return keypoints[inside], scores[inside]
def top_k_keypoints(keypoints, scores, k):
    """Keep the `k` keypoints with the highest score.

    When `k` is smaller than the number of keypoints, the result is sorted
    by decreasing score. Otherwise both inputs are returned unchanged
    (same objects, original order).

    Args:
        keypoints: tensor indexed along dim 0, one row per keypoint.
        scores: 1-D tensor of scores aligned with `keypoints`.
        k: maximum number of keypoints to keep.
    Returns:
        `(keypoints, scores)` after selection.
    """
    if k < len(keypoints):
        scores, order = torch.topk(scores, k, dim=0, sorted=True)
        keypoints = keypoints[order]
    return keypoints, scores
def sample_descriptors(keypoints, descriptors, s):
    """Bilinearly interpolate a dense descriptor map at keypoint locations.

    Args:
        keypoints: pixel coordinates with a last dimension of 2 ordered as
            (x, y). They are reshaped to `(B, 1, N, 2)` for sampling.
        descriptors: dense descriptor map of shape `(B, C, H, W)`.
        s: ratio between the image resolution and the descriptor-map
            resolution (the SuperPoint model passes 8).
    Returns:
        Tensor of shape `(B, C, N)`, L2-normalised along the channel axis.
    """
    batch, channels, height, width = descriptors.shape
    # Shift to cell-centre coordinates, then scale to [0, 1] and finally to
    # the [-1, 1] range that grid_sample expects.
    shifted = keypoints - s / 2 + 0.5
    extent = torch.tensor(
        [(width * s - s / 2 - 0.5), (height * s - s / 2 - 0.5)]
    ).to(shifted)[None]
    grid = shifted / extent
    grid = grid * 2 - 1
    # align_corners only exists (and must be set explicitly) from torch 1.3.
    sample_kwargs = {'align_corners': True} if torch.__version__ >= '1.3' else {}
    sampled = torch.nn.functional.grid_sample(
        descriptors, grid.view(batch, 1, -1, 2), mode='bilinear',
        **sample_kwargs)
    flat = sampled.reshape(batch, channels, -1)
    return torch.nn.functional.normalize(flat, p=2, dim=1)
class SuperPoint(BaseModel):
    """SuperPoint keypoint detector and descriptor (inference only).

    A shared convolutional encoder reduces the grayscale input resolution by
    8 (three 2x2 max-pools). The optional detector head predicts 65 channels
    per 8x8 cell. After a softmax the last channel is dropped and the other
    64 are unfolded into per-pixel scores. The optional descriptor head
    predicts a dense, L2-normalised descriptor map. In sparse mode, keypoints
    are extracted by NMS, thresholding, border removal and optional top-k
    selection, and descriptors are bilinearly sampled at those keypoints.
    """
    default_conf = {
        'has_detector': True,  # build the keypoint-score head
        'has_descriptor': True,  # build the descriptor head
        'descriptor_dim': 256,  # output channels of the descriptor head
        # Inference
        'return_all': False,  # sparse mode: also return the dense maps
        'sparse_outputs': True,  # extract keypoints instead of dense output
        'nms_radius': 4,  # radius passed to simple_nms
        'detection_threshold': 0.005,  # minimum score of a keypoint
        # -1: keep all; 0: skip detection, return empty keypoints; >0: top-k
        'max_num_keypoints': -1,
        # pad with random zero-score keypoints up to max_num_keypoints
        'force_num_keypoints': False,
        'remove_borders': 4,  # border width (pixels) where keypoints are dropped
    }
    required_data_keys = ['image']

    def _init(self, conf):
        """Build the encoder and the enabled heads, then load the weights.

        Weights are read from
        `GLUESTICK_ROOT/resources/weights/superpoint_v1.pth` with
        `strict=False`, so missing or unexpected keys (e.g. those of a
        disabled head) are silently ignored.
        """
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
        if conf.has_detector:
            self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
            # 65 = 64 pixels of an 8x8 cell + 1 channel dropped after softmax
            self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
        if conf.has_descriptor:
            self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
            self.convDb = nn.Conv2d(
                c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0)
        path = GLUESTICK_ROOT / 'resources' / 'weights' / 'superpoint_v1.pth'
        # Map to CPU so a checkpoint holding CUDA tensors also loads on
        # CPU-only machines. load_state_dict copies the values into this
        # module's parameters anyway, wherever they live.
        state_dict = torch.load(str(path), map_location='cpu')
        self.load_state_dict(state_dict, strict=False)

    def _forward(self, data):
        """Run SuperPoint on `data['image']`.

        Args:
            data: dict with key 'image', a 4-D tensor `(B, C, H, W)`. A
                3-channel input is converted to grayscale. Presumably C is
                1 otherwise; the value range expected by the weights is not
                visible here (TODO confirm).
        Returns:
            A dict whose content depends on the configuration:
            - `max_num_keypoints == 0`: empty 'keypoints' `(B, 0, 2)`,
              'keypoint_scores' `(B, 0)` and 'descriptors'
              `(B, descriptor_dim, 0)`, plus 'all_descriptors' (the dense
              map, or None without a descriptor head).
            - `sparse_outputs`: per-image 'keypoints' `(N_i, 2)` in (x, y)
              float coordinates, 'keypoint_scores' `(N_i,)` and
              'descriptors' `(descriptor_dim, N_i)`. These are stacked into
              batch tensors when B == 1 or `force_num_keypoints` is set, and
              kept as sequences otherwise. With `return_all`, the dict also
              holds 'all_descriptors' and 'dense_score'.
            - otherwise: dense 'keypoint_scores' `(B, 8*h, 8*w)` and/or
              dense 'descriptors', where (h, w) is the encoder resolution.
        """
        image = data['image']
        if image.shape[1] == 3:  # RGB -> grayscale with standard luma weights
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)

        # Shared Encoder
        x = self.relu(self.conv1a(image))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))

        pred = {}
        if self.conf.has_detector and self.conf.max_num_keypoints != 0:
            # Compute the dense keypoint scores
            cPa = self.relu(self.convPa(x))
            scores = self.convPb(cPa)
            scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
            b, c, h, w = scores.shape
            # Unfold the 64 channels of each cell into its 8x8 pixels.
            scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
            scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
            pred['keypoint_scores'] = dense_scores = scores
        if self.conf.has_descriptor:
            # Compute the dense descriptors
            cDa = self.relu(self.convDa(x))
            all_desc = self.convDb(cDa)
            all_desc = torch.nn.functional.normalize(all_desc, p=2, dim=1)
            pred['descriptors'] = all_desc

        if self.conf.max_num_keypoints == 0:  # Predict dense descriptors only
            b_size = len(image)
            device = image.device
            return {
                'keypoints': torch.empty(b_size, 0, 2, device=device),
                'keypoint_scores': torch.empty(b_size, 0, device=device),
                'descriptors': torch.empty(
                    b_size, self.conf.descriptor_dim, 0, device=device),
                # None when the model has no descriptor head.
                'all_descriptors': pred.get('descriptors'),
            }

        if self.conf.sparse_outputs:
            assert self.conf.has_detector and self.conf.has_descriptor

            scores = simple_nms(scores, self.conf.nms_radius)

            # Extract keypoints as (row, col) indices above the threshold
            keypoints = [
                torch.nonzero(s > self.conf.detection_threshold)
                for s in scores]
            scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]

            # Discard keypoints near the image borders
            keypoints, scores = list(zip(*[
                remove_borders(k, s, self.conf.remove_borders, h * 8, w * 8)
                for k, s in zip(keypoints, scores)]))

            # Keep the k keypoints with highest score
            if self.conf.max_num_keypoints > 0:
                keypoints, scores = list(zip(*[
                    top_k_keypoints(k, s, self.conf.max_num_keypoints)
                    for k, s in zip(keypoints, scores)]))

            # Convert (h, w) to (x, y)
            keypoints = [torch.flip(k, [1]).float() for k in keypoints]

            if self.conf.force_num_keypoints:
                _, _, img_h, img_w = data['image'].shape
                assert self.conf.max_num_keypoints > 0
                scores = list(scores)
                # Pad each image with uniformly random keypoints of score 0
                # so that every image has exactly max_num_keypoints.
                for i, (k, s) in enumerate(zip(keypoints, scores)):
                    missing = self.conf.max_num_keypoints - len(k)
                    if missing > 0:
                        new_k = torch.rand(missing, 2).to(k)
                        new_k = new_k * k.new_tensor([[img_w - 1, img_h - 1]])
                        new_s = torch.zeros(missing).to(s)
                        keypoints[i] = torch.cat([k, new_k], 0)
                        scores[i] = torch.cat([s, new_s], 0)

            # Extract descriptors
            desc = [sample_descriptors(k[None], d[None], 8)[0]
                    for k, d in zip(keypoints, all_desc)]

            # Stacking is only possible when all images share one count.
            if (len(keypoints) == 1) or self.conf.force_num_keypoints:
                keypoints = torch.stack(keypoints, 0)
                scores = torch.stack(scores, 0)
                desc = torch.stack(desc, 0)

            pred = {
                'keypoints': keypoints,
                'keypoint_scores': scores,
                'descriptors': desc,
            }

            if self.conf.return_all:
                pred['all_descriptors'] = all_desc
                pred['dense_score'] = dense_scores
            else:
                # Drop the last reference to the dense descriptor map and
                # release cached allocator blocks.
                del all_desc
                torch.cuda.empty_cache()

        return pred

    def loss(self, pred, data):
        raise NotImplementedError

    def metrics(self, pred, data):
        raise NotImplementedError