Realcat
fix: extractors
631fac6
raw
history blame
4.18 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from .exceptions import EmptyTensorError
from .utils import interpolate_dense_features, upscale_positions
def process_multiscale(image, model, scales=(.5, 1, 2)):
    """Extract keypoints, scores and descriptors over several image scales.

    The image is resized once per entry of ``scales`` (in the given order).
    At each level the dense feature map is computed, summed with the resized
    feature map carried over from the previous level, and fed to the model's
    detection and localization heads.  Positions flagged at an earlier level
    are masked out ("banned") so they are not detected again.

    Args:
        image: 4-D tensor; its first (batch) dimension must be 1.
        model: object providing ``dense_feature_extraction`` (callable, with
            a ``num_channels`` attribute), ``detection`` and ``localization``.
        scales: iterable of scale factors passed to ``F.interpolate``.

    Returns:
        A tuple ``(keypoints, scores, descriptors)`` of NumPy arrays:

        * ``keypoints``: shape ``(N, 3)``; the first two columns are the
          upscaled positions multiplied by ``h_init / h_level`` and
          ``w_init / w_level`` respectively, the third is ``1 / scale``.
        * ``scores``: shape ``(N,)``; the dense feature value at the detected
          position divided by ``level_index + 1``.
        * ``descriptors``: shape ``(N, num_channels)``; normalized with
          ``F.normalize`` along the channel dimension.

    Raises:
        AssertionError: if the batch dimension of ``image`` is not 1.
    """
    b, _, h_init, w_init = image.size()
    device = image.device
    assert b == 1, 'process_multiscale expects a batch of exactly one image'

    all_keypoints = torch.zeros([3, 0])
    all_descriptors = torch.zeros([
        model.dense_feature_extraction.num_channels, 0
    ])
    all_scores = torch.zeros(0)

    previous_dense_features = None
    banned = None
    for idx, scale in enumerate(scales):
        current_image = F.interpolate(
            image, scale_factor=scale,
            mode='bilinear', align_corners=True
        )
        _, _, h_level, w_level = current_image.size()

        dense_features = model.dense_feature_extraction(current_image)
        del current_image  # release the resized image as early as possible

        _, _, h, w = dense_features.size()

        # Sum the feature maps.
        if previous_dense_features is not None:
            dense_features += F.interpolate(
                previous_dense_features, size=[h, w],
                mode='bilinear', align_corners=True
            )
            # Drop the reference to free memory, but keep the name bound:
            # unbinding it with `del` made the early `continue` below leave
            # the variable undefined for the next iteration.
            previous_dense_features = None

        # Recover detections.
        detections = model.detection(dense_features)
        if banned is not None:
            # Resize the mask of already-detected positions to this level and
            # suppress detections that fall on it.
            banned = F.interpolate(banned.float(), size=[h, w]).bool()
            detections = torch.min(detections, ~banned)
            banned = torch.max(
                torch.max(detections, dim=1)[0].unsqueeze(1), banned
            )
        else:
            banned = torch.max(detections, dim=1)[0].unsqueeze(1)
        # Each column of fmap_pos indexes one detection in detections[0].
        fmap_pos = torch.nonzero(detections[0].cpu()).t()
        del detections

        # Recover displacements.
        displacements = model.localization(dense_features)[0].cpu()
        displacements_i = displacements[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        displacements_j = displacements[
            1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        del displacements

        # Keep only detections whose refinement stays within half a cell in
        # both directions.
        mask = (
            (torch.abs(displacements_i) < 0.5)
            & (torch.abs(displacements_j) < 0.5)
        )
        fmap_pos = fmap_pos[:, mask]
        valid_displacements = torch.stack([
            displacements_i[mask],
            displacements_j[mask]
        ], dim=0)

        # Rows 1 and 2 of fmap_pos are the spatial indices; row 0 is the
        # leading (channel) index of the detection.
        fmap_keypoints = fmap_pos[1:, :].float() + valid_displacements

        try:
            raw_descriptors, _, ids = interpolate_dense_features(
                fmap_keypoints.to(device),
                dense_features[0]
            )
        except EmptyTensorError:
            # Nothing usable at this level.  Still hand the feature map to
            # the next level so the accumulation continues and the name
            # stays bound.
            previous_dense_features = dense_features
            continue

        fmap_pos = fmap_pos.to(device)[:, ids]
        fmap_keypoints = fmap_keypoints.to(device)[:, ids]
        del ids

        keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
        del fmap_keypoints

        descriptors = F.normalize(raw_descriptors, dim=0).cpu()
        del raw_descriptors

        # Rescale from the resized level back to the original image size.
        keypoints[0, :] *= h_init / h_level
        keypoints[1, :] *= w_init / w_level

        fmap_pos = fmap_pos.cpu()
        keypoints = keypoints.cpu()

        # Append 1 / scale as a third row for every keypoint.
        keypoints = torch.cat([
            keypoints,
            torch.ones([1, keypoints.size(1)]) * 1 / scale,
        ], dim=0)

        # Later levels are down-weighted by their 1-based index.
        scores = dense_features[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ].cpu() / (idx + 1)

        all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
        all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
        all_scores = torch.cat([all_scores, scores], dim=0)

        previous_dense_features = dense_features
        del dense_features

    keypoints = all_keypoints.t().detach().numpy()
    scores = all_scores.detach().numpy()
    descriptors = all_descriptors.t().detach().numpy()
    return keypoints, scores, descriptors