|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from lib.exceptions import EmptyTensorError |
|
from lib.utils import interpolate_dense_features, upscale_positions |
|
|
|
|
|
def process_multiscale(image, model, scales=(.5, 1, 2)):
    """Detect and describe keypoints over an image pyramid.

    The image is resized by each factor in ``scales`` (in the given order).
    At every level, dense features are extracted and summed with the
    bilinearly-upsampled accumulated features of the previous level.
    Keypoints are detected on the result; locations already detected at an
    earlier level are suppressed. Detections are refined with the model's
    displacement output, and descriptors are obtained by interpolating the
    dense features.

    Args:
        image: 4-D tensor whose first (batch) dimension must be 1
            (presumably ``(1, C, H, W)`` -- TODO confirm against callers).
        model: object exposing ``dense_feature_extraction`` (with a
            ``num_channels`` attribute), ``detection`` and ``localization``.
        scales: resize factors, one pyramid level per entry. Any iterable
            of numbers works; the default is a tuple so it cannot be mutated
            and shared between calls.

    Returns:
        ``(keypoints, scores, descriptors)`` as NumPy arrays:

        - ``keypoints``: ``(N, 3)``, columns ``(i, j, 1 / scale)``. ``i`` and
          ``j`` come from ``upscale_positions`` (presumably level-image
          coordinates) and are then rescaled by ``h_init / h_level`` and
          ``w_init / w_level`` towards the input image size.
        - ``scores``: ``(N,)``, the dense-feature value of the detecting
          channel at each detection, divided by ``level_index + 1``.
        - ``descriptors``: ``(N, num_channels)``, L2-normalised per keypoint.

    Raises:
        ValueError: if the batch dimension of ``image`` is not 1.
    """
    b, _, h_init, w_init = image.size()
    if b != 1:
        # Explicit check: an ``assert`` would vanish under ``python -O`` and
        # the function would silently process only the first image.
        raise ValueError(
            'process_multiscale expects a batch of one image, got %d' % b
        )
    device = image.device

    # Per-level results, concatenated once at the end. Each list is seeded
    # with an empty tensor so the final torch.cat also works when no level
    # yields any keypoint.
    keypoint_chunks = [torch.zeros([3, 0])]
    descriptor_chunks = [
        torch.zeros([model.dense_feature_extraction.num_channels, 0])
    ]
    score_chunks = [torch.zeros(0)]

    previous_dense_features = None
    banned = None
    for idx, scale in enumerate(scales):
        current_image = F.interpolate(
            image, scale_factor=scale,
            mode='bilinear', align_corners=True
        )
        _, _, h_level, w_level = current_image.size()

        dense_features = model.dense_feature_extraction(current_image)
        del current_image
        _, _, h, w = dense_features.size()

        # Accumulate the (upsampled) features of the previous level.
        if previous_dense_features is not None:
            dense_features += F.interpolate(
                previous_dense_features, size=[h, w],
                mode='bilinear', align_corners=True
            )
        # Re-bind here, not at the end of the loop. A level skipped by the
        # `continue` below must still propagate its features, and must never
        # leave this name unbound. Re-binding also drops the reference to the
        # previous level's tensor as early as possible.
        previous_dense_features = dense_features

        detections = model.detection(dense_features)
        if banned is None:
            banned = torch.max(detections, dim=1)[0].unsqueeze(1)
        else:
            # Suppress locations already detected at a previous level, then
            # add this level's detections to the banned mask.
            banned = F.interpolate(banned.float(), size=[h, w]).bool()
            detections = torch.min(detections, ~banned)
            banned = torch.max(
                torch.max(detections, dim=1)[0].unsqueeze(1), banned
            )
        # Rows of fmap_pos: (channel, i, j) of each detection in batch item 0.
        fmap_pos = torch.nonzero(detections[0].cpu()).t()
        del detections

        displacements = model.localization(dense_features)[0].cpu()
        displacements_i = displacements[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        displacements_j = displacements[
            1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        del displacements

        # Keep only detections whose refinement stays within half a cell
        # in both directions.
        mask = (
            (torch.abs(displacements_i) < 0.5)
            & (torch.abs(displacements_j) < 0.5)
        )
        fmap_pos = fmap_pos[:, mask]
        valid_displacements = torch.stack(
            [displacements_i[mask], displacements_j[mask]], dim=0
        )
        del mask, displacements_i, displacements_j

        fmap_keypoints = fmap_pos[1:, :].float() + valid_displacements
        del valid_displacements

        try:
            raw_descriptors, _, ids = interpolate_dense_features(
                fmap_keypoints.to(device),
                dense_features[0]
            )
        except EmptyTensorError:
            # No usable keypoint at this level; features and the banned mask
            # have already been carried over to the next level.
            continue

        fmap_pos = fmap_pos.to(device)[:, ids]
        fmap_keypoints = fmap_keypoints.to(device)[:, ids]
        del ids

        keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
        del fmap_keypoints

        descriptors = F.normalize(raw_descriptors, dim=0).cpu()
        del raw_descriptors

        # Bring level coordinates back to the input image's resolution.
        keypoints[0, :] *= h_init / h_level
        keypoints[1, :] *= w_init / w_level

        fmap_pos = fmap_pos.cpu()
        keypoints = keypoints.cpu()

        # Third row records the inverse scale of the level.
        keypoints = torch.cat([
            keypoints,
            torch.ones([1, keypoints.size(1)]) / scale,
        ], dim=0)

        scores = dense_features[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ].cpu() / (idx + 1)
        del fmap_pos

        keypoint_chunks.append(keypoints)
        descriptor_chunks.append(descriptors)
        score_chunks.append(scores)
        del keypoints, descriptors, scores, dense_features
    del previous_dense_features, banned

    keypoints = torch.cat(keypoint_chunks, dim=1).t().detach().numpy()
    del keypoint_chunks
    scores = torch.cat(score_chunks, dim=0).detach().numpy()
    del score_chunks
    descriptors = torch.cat(descriptor_chunks, dim=1).t().detach().numpy()
    del descriptor_chunks

    return keypoints, scores, descriptors
|
|