|
import torch |
|
from torch import nn |
|
from torch.nn.parameter import Parameter |
|
import torchvision.transforms as tvf |
|
import torch.nn.functional as F |
|
import numpy as np |
|
|
|
|
|
def gather_nd(params, indices):
    """Gather slices of ``params`` addressed by the index tuples in ``indices``.

    The last axis of ``indices`` holds ``m`` coordinates that address the
    first ``m`` axes of ``params``; every leading axis of ``indices`` is kept
    as a batch axis of the result.

    Args:
        params: Tensor to read from.
        indices: Integer tensor of shape ``(..., m)`` with ``m <= params.dim()``.

    Returns:
        A contiguous tensor of shape
        ``indices.shape[:-1] + params.shape[m:]``.

    Raises:
        ValueError: If ``m`` is larger than the rank of ``params``.
    """
    index_shape = list(indices.shape)
    m = index_shape[-1]
    n = params.dim()

    if m > n:
        raise ValueError(
            f"the last dimension of indices must less or equal to the rank of params. Got indices:{indices.shape}, params:{params.shape}. {m} > {n}"
        )

    out_shape = index_shape[:-1] + list(params.shape)[m:]

    # np.prod([]) is the float 1.0, which reshape() rejects; forcing an
    # integer dtype keeps 1-D ``indices`` (a single index tuple) working.
    num_samples = int(np.prod(index_shape[:-1], dtype=np.int64))

    # Index with a tuple of long tensors (one per addressed axis) instead of a
    # nested Python list: no host round trip, and no reliance on the
    # deprecated "non-tuple sequence" indexing behaviour.
    flat = indices.reshape(num_samples, m).to(device=params.device, dtype=torch.long)
    output = params[tuple(flat.unbind(dim=1))]
    return output.reshape(out_shape).contiguous()
|
|
|
|
|
|
|
|
|
def interpolate(pos, inputs, nd=True):
    """Bilinearly sample ``inputs`` at the fractional positions in ``pos``.

    Args:
        pos: Tensor whose first axis enumerates sample points and whose second
            axis holds ``(i, j)``; ``i`` indexes axis 0 of ``inputs`` and ``j``
            indexes axis 1.
        inputs: Tensor with at least two axes; the first two are sampled and
            any trailing axes are carried through unchanged.
        nd: When true, a trailing singleton axis is appended to the weights so
            they broadcast against one trailing axis of ``inputs``.

    Returns:
        The weighted sum of the four neighbouring entries for every position.
        Neighbour indices are clamped to the valid range, so positions outside
        the grid resolve to the nearest border entry.
    """
    h = inputs.shape[0]
    w = inputs.shape[1]

    i = pos[:, 0]
    j = pos[:, 1]

    # Only four distinct clamped indices exist (top/bottom rows and left/right
    # columns); compute each once instead of once per corner.
    row_top = torch.clamp(torch.floor(i).long(), 0, h - 1)
    row_bottom = torch.clamp(torch.ceil(i).long(), 0, h - 1)
    col_left = torch.clamp(torch.floor(j).long(), 0, w - 1)
    col_right = torch.clamp(torch.ceil(j).long(), 0, w - 1)

    # Fractional offsets measured from the (clamped) top-left neighbour.
    dist_i = i - row_top.float()
    dist_j = j - col_left.float()

    w_top_left = (1 - dist_i) * (1 - dist_j)
    w_top_right = (1 - dist_i) * dist_j
    w_bottom_left = dist_i * (1 - dist_j)
    w_bottom_right = dist_i * dist_j

    if nd:
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]

    # Advanced indexing on the first two axes fetches each corner directly.
    interpolated_val = (
        w_top_left * inputs[row_top, col_left]
        + w_top_right * inputs[row_top, col_right]
        + w_bottom_left * inputs[row_bottom, col_left]
        + w_bottom_right * inputs[row_bottom, col_right]
    )

    return interpolated_val
|
|
|
|
|
def edge_mask(inputs, n_channel, dilation=1, edge_thld=5):
    """Return a boolean mask that is True where the response is not edge-like.

    Second-order finite differences (``dii``, ``dij``, ``djj``) are computed
    per channel with 3x3 filters, and a location is kept when
    ``det > 0`` and ``tr**2 / det <= (edge_thld + 1)**2 / edge_thld``, where
    ``det = dii * djj - dij**2`` and ``tr = dii + djj``.

    Args:
        inputs: Tensor of shape ``(b, c, h, w)``.
        n_channel: Unused; kept so existing call sites keep working.
        dilation: Dilation (and matching padding) of the 3x3 filters.
        edge_thld: Ratio threshold; must be non-zero.

    Returns:
        Boolean tensor of shape ``(b, c, h, w)``.
    """
    b, c, h, w = inputs.size()

    # Build the filters in the input's dtype and on its device so that
    # non-float32 inputs (e.g. float64) do not hit a conv2d dtype mismatch.
    filter_kwargs = dict(dtype=inputs.dtype, device=inputs.device)
    dii_filter = torch.tensor(
        [[0, 1.0, 0], [0, -2.0, 0], [0, 1.0, 0]], **filter_kwargs
    ).view(1, 1, 3, 3)
    dij_filter = 0.25 * torch.tensor(
        [[1.0, 0, -1.0], [0, 0.0, 0], [-1.0, 0, 1.0]], **filter_kwargs
    ).view(1, 1, 3, 3)
    djj_filter = torch.tensor(
        [[0, 0, 0], [1.0, -2.0, 1.0], [0, 0, 0]], **filter_kwargs
    ).view(1, 1, 3, 3)

    # Fold channels into the batch axis so one single-channel filter is applied
    # to every plane; reshape (not view) also accepts non-contiguous inputs.
    planes = inputs.reshape(-1, 1, h, w)

    def _second_derivative(kernel):
        response = F.conv2d(planes, kernel, padding=dilation, dilation=dilation)
        return response.reshape(b, c, h, w)

    dii = _second_derivative(dii_filter)
    dij = _second_derivative(dij_filter)
    djj = _second_derivative(djj_filter)

    det = dii * djj - dij * dij
    tr = dii + djj
    del dii, dij, djj

    threshold = (edge_thld + 1) ** 2 / edge_thld
    # Where det == 0 the ratio is inf/nan and compares False, and det > 0 is
    # False as well, so those locations are rejected.
    is_not_edge = torch.logical_and(tr * tr / det <= threshold, det > 0)

    return is_not_edge
|
|
|
|
|
|
|
|
|
def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_size=5):
    """Select up to ``k`` keypoints per sample from a score map.

    A location survives when its score exceeds ``score_thld`` and, for each
    enabled filter, when it equals the maximum of its ``nms_size`` window, lies
    at least ``eof_size`` pixels away from every border, and passes
    ``edge_mask`` (run with ``dilation=3``).

    Args:
        score_map: Tensor of shape ``(b, c, h, w)``; only channel 0 is used
            when collecting keypoints.
        k: Maximum number of keypoints kept per sample.
        score_thld: Strict lower bound on the score.
        edge_thld: Edge threshold forwarded to ``edge_mask``; ``<= 0`` disables
            the edge filter.
        nms_size: Window size of the local-maximum test; ``<= 0`` disables it.
        eof_size: Border width to exclude; ``<= 0`` disables it.

    Returns:
        ``(indices, scores)`` where ``indices`` has shape ``(b, n, 2)`` holding
        ``(row, col)`` pairs and ``scores`` has shape ``(b, n)``, both ordered
        by descending score. When samples yield different counts, every sample
        is truncated to the smallest count so the results can be stacked.
    """
    h = score_map.shape[2]
    w = score_map.shape[3]

    mask = score_map > score_thld

    if nms_size > 0:
        local_max = F.max_pool2d(
            score_map, kernel_size=nms_size, stride=1, padding=nms_size // 2
        )
        mask = torch.logical_and(torch.eq(score_map, local_max), mask)

    if eof_size > 0:
        # True only inside the border. If the map is smaller than
        # 2 * eof_size the slice is empty and nothing is kept, instead of
        # failing on a negative-sized tensor.
        interior = torch.zeros_like(mask)
        interior[:, :, eof_size : h - eof_size, eof_size : w - eof_size] = True
        mask = torch.logical_and(interior, mask)

    if edge_thld > 0:
        non_edge_mask = edge_mask(score_map, 1, dilation=3, edge_thld=edge_thld)
        mask = torch.logical_and(non_edge_mask, mask)

    indices = []
    scores = []
    for sample_mask, sample_scores in zip(mask[:, 0], score_map[:, 0]):
        candidates = torch.nonzero(sample_mask)
        # Boolean-mask indexing returns values in the same row-major order as
        # torch.nonzero, so scores line up with ``candidates``.
        candidate_scores = sample_scores[sample_mask]
        order = torch.sort(candidate_scores, descending=True)[1][0:k]
        indices.append(candidates[order])
        scores.append(candidate_scores[order])

    # Samples may keep different numbers of keypoints; trim to the common
    # minimum explicitly rather than catching a failed torch.stack.
    min_num = min(len(sample) for sample in scores)
    indices = torch.stack([sample[:min_num] for sample in indices], dim=0)
    scores = torch.stack([sample[:min_num] for sample in scores], dim=0)

    return indices, scores
|
|
|
|
|
|
|
|
|
def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1):
    """Compute spatial and channel-wise peakiness of a feature map.

    ``inputs`` is first divided by ``moving_instance_max``. ``alpha`` is the
    softplus of the difference between each value and the mean of its
    ``ksize`` x ``ksize`` (dilated, reflect-padded) neighbourhood within the
    same channel; ``beta`` is the softplus of the difference between each
    value and the mean over channels at the same location.

    Args:
        inputs: Tensor of shape ``(batch, C, H, W)``.
        moving_instance_max: Divisor applied to ``inputs`` (scalar or a tensor
            broadcastable against it).
        ksize: Side of the averaging window; must be odd for the output to
            keep the input's spatial size.
        dilation: Dilation of the averaging window.

    Returns:
        ``(alpha, beta)``, each with the same shape as ``inputs``.
    """
    inputs = inputs / moving_instance_max

    _, C, _, _ = inputs.shape

    # A dilated window spans dilation * (ksize - 1) + 1 pixels, so "same"
    # output size needs dilation * (ksize // 2) padding per side. The previous
    # ksize // 2 + (dilation - 1) only agrees with this for ksize == 3 or
    # dilation == 1 and broke the subtraction below otherwise.
    pad_size = dilation * (ksize // 2)

    # One averaging filter per channel (depthwise), in the input's dtype so
    # non-float32 inputs do not trigger a conv2d dtype mismatch.
    kernel = torch.ones(
        [C, 1, ksize, ksize], dtype=inputs.dtype, device=inputs.device
    ) / (ksize * ksize)

    pad_inputs = F.pad(inputs, [pad_size] * 4, mode="reflect")

    avg_spatial_inputs = F.conv2d(
        pad_inputs, kernel, stride=1, dilation=dilation, padding=0, groups=C
    )
    avg_channel_inputs = torch.mean(inputs, dim=1, keepdim=True)

    alpha = F.softplus(inputs - avg_spatial_inputs)
    beta = F.softplus(inputs - avg_channel_inputs)

    return alpha, beta
|
|
|
|
|
class DarkFeat(nn.Module):
    """Convolutional keypoint detector and descriptor.

    The network is a stack of 3x3 convolutions with two stride-2 stages, so
    the final feature map (used as the dense descriptor map) is at 1/4 of the
    input resolution. ``forward`` derives a detection score map from three
    intermediate feature maps, picks keypoints on it and samples descriptors
    at the keypoint locations.

    Weights are loaded from ``model_path`` during construction.
    """

    # Only "model_path", "input_type", "kpt_n", "score_thld", "edge_thld",
    # "nms_size" and "eof_size" are read inside this class; the remaining
    # entries are kept as-is for any external reader of the config.
    default_config = {
        "model_path": "",
        "input_type": "raw-demosaic",
        "kpt_n": 5000,
        "kpt_refinement": True,
        "score_thld": 0.5,
        "edge_thld": 10,
        "multi_scale": False,
        "multi_level": True,
        "nms_size": 3,
        "eof_size": 5,
        "need_norm": True,
        "use_peakiness": True,
    }

    def __init__(
        self,
        model_path="",
        inchan=3,
        dilated=True,
        dilation=1,
        bn=True,
        bn_affine=False,
    ):
        """Build the network and load its weights.

        Args:
            model_path: Path of the checkpoint passed to ``torch.load``.
            inchan: Accepted for interface compatibility. NOTE: the value is
                overridden below from ``default_config["input_type"]``
                (3 for "rgb"/"raw-demosaic", otherwise 1).
            dilated: Stored on the instance; not otherwise used in this class.
            dilation: Base dilation multiplied into every convolution.
            bn: Whether ``_add_conv`` may append batch normalisation.
            bn_affine: ``affine`` flag of every batch-norm layer.
        """
        super().__init__()
        inchan = (
            3
            if self.default_config["input_type"] in ("rgb", "raw-demosaic")
            else 1
        )
        self.config = {**self.default_config}

        self.inchan = inchan
        # Running channel count; _add_conv reads it as the input width of the
        # next convolution and updates it afterwards.
        self.curchan = inchan
        self.dilated = dilated
        self.dilation = dilation
        self.bn = bn
        self.bn_affine = bn_affine
        self.config["model_path"] = model_path

        dim = 128
        mchan = 4

        self.conv0 = self._add_conv(8 * mchan)
        self.conv1 = self._add_conv(8 * mchan, bn=False)
        self.bn1 = self._make_bn(8 * mchan)
        self.conv2 = self._add_conv(16 * mchan, stride=2)
        self.conv3 = self._add_conv(16 * mchan, bn=False)
        self.bn3 = self._make_bn(16 * mchan)
        self.conv4 = self._add_conv(32 * mchan, stride=2)
        self.conv5 = self._add_conv(32 * mchan)

        self.conv6_0 = self._add_conv(32 * mchan)
        self.conv6_1 = self._add_conv(32 * mchan)
        self.conv6_2 = self._add_conv(dim, bn=False, relu=False)
        self.out_dim = dim

        # Per-level divisors handed to peakiness_score; stored as
        # non-trainable parameters so they travel with the state dict.
        self.moving_avg_params = nn.ParameterList(
            [
                Parameter(torch.tensor(1.0), requires_grad=False),
                Parameter(torch.tensor(1.0), requires_grad=False),
                Parameter(torch.tensor(1.0), requires_grad=False),
            ]
        )
        self.clf = nn.Conv2d(dim, 2, kernel_size=1)

        # map_location="cpu" lets a checkpoint saved from a GPU be loaded on a
        # machine without one; load_state_dict then copies the values into the
        # module's own parameters.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(self.config["model_path"], map_location="cpu")

        # The batch-norm layers are built with track_running_stats=False and
        # therefore have no running-statistics entries; drop them from the
        # checkpoint so the strict load below does not reject them.
        skipped = ("running_mean", "running_var", "num_batches_tracked")
        new_state_dict = {
            key: value
            for key, value in state_dict.items()
            if not any(tag in key for tag in skipped)
        }

        self.load_state_dict(new_state_dict)
        print("Loaded DarkFeat model")

    def _make_bn(self, outd):
        """Return a batch-norm layer for ``outd`` channels without running stats."""
        return nn.BatchNorm2d(outd, affine=self.bn_affine, track_running_stats=False)

    def _add_conv(
        self,
        outd,
        k=3,
        stride=1,
        dilation=1,
        bn=True,
        relu=True,
        k_pool=1,
        pool_type="max",
        bias=False,
    ):
        """Build one conv stage and advance ``self.curchan`` to ``outd``.

        The stage is ``Conv2d`` followed, when enabled, by batch norm, ReLU
        and a pooling layer, returned as an ``nn.Sequential``.

        Args:
            outd: Number of output channels.
            k: Convolution kernel size.
            stride: Convolution stride.
            dilation: Multiplied by ``self.dilation`` to get the conv dilation.
            bn: Append batch norm (only if ``self.bn`` is also true).
            relu: Append an in-place ReLU.
            k_pool: Pooling kernel size; pooling is added only when ``> 1``.
            pool_type: ``"avg"`` or ``"max"``.
            bias: Whether the convolution has a bias term.

        Raises:
            ValueError: If pooling is requested with an unknown ``pool_type``.
        """
        pool_layers = {"avg": torch.nn.AvgPool2d, "max": torch.nn.MaxPool2d}
        # Reject a bad pooling type before touching self.curchan; previously
        # this only printed a message and silently built a layer without
        # pooling.
        if k_pool > 1 and pool_type not in pool_layers:
            raise ValueError(f"Error, unknown pooling type {pool_type}...")

        d = self.dilation * dilation
        conv_params = dict(
            padding=((k - 1) * d) // 2, dilation=d, stride=stride, bias=bias
        )

        ops = [nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params)]
        if bn and self.bn:
            ops.append(self._make_bn(outd))
        if relu:
            ops.append(nn.ReLU(inplace=True))
        self.curchan = outd

        if k_pool > 1:
            ops.append(pool_layers[pool_type](kernel_size=k_pool))

        return nn.Sequential(*ops)

    def forward(self, input):
        """Compute keypoints, scores, descriptors for image.

        Args:
            input: Mapping with key ``"image"`` holding a ``(B, C, H, W)``
                tensor. NOTE(review): the batch axis is squeezed away below,
                so this appears to assume ``B == 1`` - confirm with callers.

        Returns:
            Dict with ``"keypoints"`` as ``(column, row)`` pairs,
            ``"scores"`` as a tensor of keypoint scores in descending order,
            and ``"descriptors"`` as a tensor with one L2-normalised
            descriptor per column.

        Raises:
            NotImplementedError: If ``config["input_type"]`` is not one of
                "rgb", "gray", "raw" or "raw-demosaic".
        """
        data = input["image"]
        H, W = data.shape[2:]

        input_type = self.config["input_type"]
        if input_type == "rgb":
            RGB_mean = [0.485, 0.456, 0.406]
            RGB_std = [0.229, 0.224, 0.225]
            norm_RGB = tvf.Normalize(mean=RGB_mean, std=RGB_std)
            data = norm_RGB(data)
        elif input_type == "gray":
            data = torch.mean(data, dim=1, keepdim=True)
            norm_gray0 = tvf.Normalize(mean=data.mean(), std=data.std())
            data = norm_gray0(data)
        elif input_type in ("raw", "raw-demosaic"):
            # Raw inputs are fed to the network unchanged.
            pass
        else:
            raise NotImplementedError()

        x0 = self.conv0(data)
        x1 = self.conv1(x0)
        x1_bn = self.bn1(x1)
        x2 = self.conv2(x1_bn)
        x3 = self.conv3(x2)
        x3_bn = self.bn3(x3)
        x4 = self.conv4(x3_bn)
        x5 = self.conv5(x4)
        x6_0 = self.conv6_0(x5)
        x6_1 = self.conv6_1(x6_0)
        x6_2 = self.conv6_2(x6_1)

        # Weighted sum of per-level detection maps; weights are 1:2:3
        # normalised to sum to one.
        comb_weights = torch.tensor([1.0, 2.0, 3.0], device=data.device)
        comb_weights /= torch.sum(comb_weights)
        ksize = [3, 2, 1]  # used as the dilation for each level below
        det_score_maps = []

        for idx, xx in enumerate([x1, x3, x6_2]):
            alpha, beta = peakiness_score(
                xx, self.moving_avg_params[idx].detach(), ksize=3, dilation=ksize[idx]
            )
            score_vol = alpha * beta
            det_score_map = torch.max(score_vol, dim=1, keepdim=True)[0]
            # Bring every level to the input resolution before combining.
            det_score_map = F.interpolate(
                det_score_map, size=data.shape[2:], mode="bilinear", align_corners=True
            )
            det_score_map = comb_weights[idx] * det_score_map
            det_score_maps.append(det_score_map)

        det_score_map = torch.sum(torch.stack(det_score_maps, dim=0), dim=0)

        desc = x6_2
        score_map = det_score_map
        # Second softmax channel of the 1x1 classifier, upsampled and used to
        # modulate the detection scores.
        conf = F.softmax(self.clf((desc) ** 2), dim=1)[:, 1:2]
        score_map = score_map * F.interpolate(
            conf, size=score_map.shape[2:], mode="bilinear", align_corners=True
        )

        kpt_inds, kpt_score = extract_kpts(
            score_map,
            k=self.config["kpt_n"],
            score_thld=self.config["score_thld"],
            nms_size=self.config["nms_size"],
            eof_size=self.config["eof_size"],
            edge_thld=self.config["edge_thld"],
        )

        # Keypoint indices are in input pixels; the descriptor map is at 1/4
        # resolution (two stride-2 stages), hence the division by 4.
        descs = (
            F.normalize(
                interpolate(kpt_inds.squeeze(0) / 4, desc.squeeze(0).permute(1, 2, 0)),
                p=2,
                dim=-1,
            )
            .detach()
            .cpu()
            .numpy()
        )

        # Swap (row, col) to (col, row). The scale factors are 1 because H and
        # W are read from ``data`` itself.
        # NOTE(review): this expression mixes torch and NumPy objects and is
        # kept verbatim; its result type depends on their interop - confirm
        # before simplifying.
        kpts = np.squeeze(
            torch.stack([kpt_inds[:, :, 1], kpt_inds[:, :, 0]], dim=-1).cpu(), axis=0
        ) * np.array([W / data.shape[3], H / data.shape[2]], dtype=np.float32)
        scores = np.squeeze(kpt_score.detach().cpu().numpy(), axis=0)

        # Keep the kpt_n highest-scoring keypoints, best first.
        idxs = np.negative(scores).argsort()[0 : self.config["kpt_n"]]
        descs = descs[idxs]
        kpts = kpts[idxs]
        scores = scores[idxs]

        return {
            "keypoints": kpts,
            "scores": torch.from_numpy(scores),
            "descriptors": torch.from_numpy(descs.T),
        }
|
|