|
import matplotlib.pyplot as plt |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
from lib.exceptions import EmptyTensorError |
|
|
|
|
|
def preprocess_image(image, preprocessing=None):
    """Cast an image to float32, move channels first and optionally normalise it.

    Args:
        image: array whose last axis is the channel axis (e.g. H x W x C);
            it is cast to float32 and transposed with axes ``[2, 0, 1]``.
        preprocessing: one of
            - ``None``: only cast and transpose;
            - ``'caffe'``: reverse the channel order, then subtract the
              per-channel means ``[103.939, 116.779, 123.68]``;
            - ``'torch'``: divide by 255, then subtract the means
              ``[0.485, 0.456, 0.406]`` and divide by the standard
              deviations ``[0.229, 0.224, 0.225]`` channel-wise.

    Returns:
        The channel-first array (float32 for ``None``; the normalised modes
        combine it with float64 constants).

    Raises:
        ValueError: if ``preprocessing`` is not one of the values above.
    """
    chw = np.transpose(image.astype(np.float32), [2, 0, 1])

    if preprocessing is None:
        return chw

    if preprocessing == 'caffe':
        channel_mean = np.array([103.939, 116.779, 123.68])
        # Reverse the channel axis before removing the per-channel means.
        flipped = chw[::-1, :, :]
        return flipped - channel_mean.reshape([3, 1, 1])

    if preprocessing == 'torch':
        channel_mean = np.array([0.485, 0.456, 0.406])
        channel_std = np.array([0.229, 0.224, 0.225])
        scaled = chw / 255.0
        centred = scaled - channel_mean.reshape([3, 1, 1])
        return centred / channel_std.reshape([3, 1, 1])

    raise ValueError('Unknown preprocessing parameter.')
|
|
|
|
|
def imshow_image(image, preprocessing=None):
    """Undo a ``preprocessing`` normalisation and return a displayable uint8 image.

    This reverses the per-channel transforms for the ``'caffe'`` and
    ``'torch'`` modes, moves the channel axis last and converts to uint8.

    Args:
        image: channel-first array (C x H x W). For ``'caffe'``/``'torch'``
            the per-channel constants are reshaped to ``(3, 1, 1)``, so three
            channels are expected.
        preprocessing: ``None`` (no de-normalisation), ``'caffe'`` (add the
            per-channel means back, then reverse the channel order) or
            ``'torch'`` (multiply by std, add mean, scale by 255).

    Returns:
        Channel-last (H x W x C) uint8 array, rounded and clipped to [0, 255].

    Raises:
        ValueError: if ``preprocessing`` is not one of the values above.
    """
    if preprocessing is None:
        pass
    elif preprocessing == 'caffe':
        mean = np.array([103.939, 116.779, 123.68])
        image = image + mean.reshape([3, 1, 1])
        # Reverse the channel order back (the caffe preprocessing flips it).
        image = image[::-1, :, :]
    elif preprocessing == 'torch':
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1])
        image *= 255.0
    else:
        raise ValueError('Unknown preprocessing parameter.')

    image = np.transpose(image, [1, 2, 0])
    # Clip before the cast: out-of-range floats would otherwise wrap around
    # (or, for negatives, hit an undefined float -> unsigned conversion).
    image = np.clip(np.round(image), 0, 255).astype(np.uint8)
    return image
|
|
|
|
|
def grid_positions(h, w, device, matrix=False):
    """Enumerate the (row, column) coordinates of every cell of an h x w grid.

    Args:
        h: number of rows.
        w: number of columns.
        device: torch device on which the coordinates are created.
        matrix: selects the output layout (see Returns).

    Returns:
        A float tensor. With ``matrix=True`` it has shape (2, h, w): entry
        [0] holds row indices and [1] holds column indices. Otherwise it has
        shape (2, h * w), listing the cells in row-major order.
    """
    row_ids = torch.arange(0, h, device=device).float()
    col_ids = torch.arange(0, w, device=device).float()

    row_grid = row_ids.view(-1, 1).repeat(1, w)
    col_grid = col_ids.view(1, -1).repeat(h, 1)

    if not matrix:
        # Flatten each (h, w) grid row-major and stack into (2, h * w).
        return torch.stack([row_grid.reshape(-1), col_grid.reshape(-1)], dim=0)
    return torch.stack([row_grid, col_grid], dim=0)
|
|
|
|
|
def upscale_positions(pos, scaling_steps=0):
    """Map positions up by ``scaling_steps`` factor-of-two levels.

    Each step applies ``p -> 2 * p + 0.5``. This is the inverse of the step
    used by ``downscale_positions``. The +0.5 is presumably a pixel-centre
    alignment term between resolutions; confirm against the network's
    strides.

    Args:
        pos: positions (scalar, array or tensor) supporting ``*`` and ``+``.
        scaling_steps: number of steps to apply (0 returns ``pos`` unchanged).

    Returns:
        The transformed positions.
    """
    for _step in range(scaling_steps):
        pos = 0.5 + 2 * pos
    return pos
|
|
|
|
|
def downscale_positions(pos, scaling_steps=0):
    """Map positions down by ``scaling_steps`` factor-of-two levels.

    Each step applies ``p -> (p - 0.5) / 2``, undoing one step of
    ``upscale_positions``.

    Args:
        pos: positions (scalar, array or tensor) supporting ``-`` and ``/``.
        scaling_steps: number of steps to apply (0 returns ``pos`` unchanged).

    Returns:
        The transformed positions.
    """
    for _step in range(scaling_steps):
        shifted = pos - 0.5
        pos = shifted / 2
    return pos
|
|
|
|
|
def interpolate_dense_features(pos, dense_features, return_corners=False):
    """Bilinearly sample a dense feature map at sub-pixel positions.

    Positions with any of the four surrounding integer corners outside the
    map are dropped.

    Args:
        pos: tensor of shape (2, N). Row 0 holds the ``i`` (row) coordinates
            and row 1 holds the ``j`` (column) coordinates.
        dense_features: tensor of shape (C, H, W) to sample from.
        return_corners: if True, also return the integer corners used.

    Returns:
        ``[descriptors, pos, ids]``, or ``[descriptors, pos, ids, corners]``
        when ``return_corners`` is True, where for the M kept positions:
            - descriptors: (C, M) interpolated features;
            - pos: (2, M) the kept positions;
            - ids: (M,) indices of the kept positions in the input ``pos``;
            - corners: (4, 2, M) long tensor of (i, j) corners ordered
              top-left, top-right, bottom-left, bottom-right.

    Raises:
        EmptyTensorError: if no position has all four corners inside the map.
    """
    device = pos.device
    ids = torch.arange(0, pos.size(1), device=device)
    _, h, w = dense_features.size()

    i = pos[0, :]
    j = pos[1, :]

    # Integer neighbours: rows i0 (top) / i1 (bottom), columns j0 (left) /
    # j1 (right). For integral coordinates floor == ceil, so corners coincide.
    i0 = torch.floor(i).long()
    j0 = torch.floor(j).long()
    i1 = torch.ceil(i).long()
    j1 = torch.ceil(j).long()

    # Keep only positions whose four corners all lie inside the H x W map.
    valid_corners = (i0 >= 0) & (j0 >= 0) & (i1 < h) & (j1 < w)

    i0 = i0[valid_corners]
    j0 = j0[valid_corners]
    i1 = i1[valid_corners]
    j1 = j1[valid_corners]

    ids = ids[valid_corners]
    if ids.size(0) == 0:
        raise EmptyTensorError

    i = i[ids]
    j = j[ids]

    # The fractional offsets from the top-left corner give the bilinear weights.
    dist_i = i - i0.float()
    dist_j = j - j0.float()
    w_top_left = (1 - dist_i) * (1 - dist_j)
    w_top_right = (1 - dist_i) * dist_j
    w_bottom_left = dist_i * (1 - dist_j)
    w_bottom_right = dist_i * dist_j

    # (M,) weights broadcast against (C, M) gathered features.
    descriptors = (
        w_top_left * dense_features[:, i0, j0] +
        w_top_right * dense_features[:, i0, j1] +
        w_bottom_left * dense_features[:, i1, j0] +
        w_bottom_right * dense_features[:, i1, j1]
    )

    pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)

    if not return_corners:
        return [descriptors, pos, ids]

    corners = torch.stack([
        torch.stack([i0, j0], dim=0),
        torch.stack([i0, j1], dim=0),
        torch.stack([i1, j0], dim=0),
        torch.stack([i1, j1], dim=0),
    ], dim=0)
    return [descriptors, pos, ids, corners]
|
|
|
|
|
def savefig(filepath, fig=None, dpi=None):
    """Save a figure edge-to-edge, with axes, ticks and margins removed.

    Note: this mutates ``fig`` (its subplot layout and every axes) before
    saving.

    Args:
        filepath: destination passed to ``Figure.savefig``.
        fig: figure to save; defaults to the current pyplot figure.
        dpi: resolution forwarded unchanged to ``Figure.savefig``.
    """
    if fig is None:
        fig = plt.gcf()

    # Adjust the figure being saved. plt.subplots_adjust would act on the
    # *current* pyplot figure, which need not be ``fig``.
    fig.subplots_adjust(0, 0, 1, 1, 0, 0)
    for ax in fig.axes:
        ax.axis('off')
        ax.margins(0, 0)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())

    fig.savefig(filepath, pad_inches=0, bbox_inches='tight', dpi=dpi)
|
|