|
|
|
|
|
|
|
|
|
import pdb |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class BaseNet(nn.Module):
    """Base class for networks producing pixelwise descriptors and confidences.

    Takes a list of images as input, and returns for each image:
        - a pixelwise descriptor
        - a pixelwise confidence

    Subclasses must implement :meth:`forward_one`, which processes a single
    image and returns a dict of named outputs.
    """

    def softmax(self, ux):
        """Squash unnormalized confidence scores into a single-channel map.

        Args:
            ux: tensor whose dimension 1 holds either 1 or 2 channels.

        Returns:
            A tensor with a single channel along dimension 1:
            ``softplus(ux) / (1 + softplus(ux))`` for a 1-channel input, or
            the second channel of a softmax over dimension 1 for a
            2-channel input.

        Raises:
            ValueError: if dimension 1 of ``ux`` is neither 1 nor 2.
        """
        channels = ux.shape[1]
        if channels == 1:
            x = F.softplus(ux)
            # softplus is non-negative, so this ratio stays within [0, 1).
            return x / (1 + x)
        if channels == 2:
            # Keep the slice (1:2) so the channel dimension is preserved.
            return F.softmax(ux, dim=1)[:, 1:2]
        # Previously this fell through and silently returned None.
        raise ValueError(
            f"softmax expects 1 or 2 channels along dim 1, got {channels}"
        )

    def normalize(self, x, ureliability=None, urepeatability=None):
        """Build the output dict for one image.

        Args:
            x: raw descriptor tensor; it is L2-normalized along dimension 1.
            ureliability: optional unnormalized reliability scores.
            urepeatability: optional unnormalized repeatability scores.

        Returns:
            A dict with key ``"descriptors"`` and, for each confidence tensor
            that was provided, ``"repeatability"`` / ``"reliability"`` holding
            the result of :meth:`softmax`. Keys for omitted (``None``)
            confidences are left out, which :meth:`forward` tolerates.
        """
        out = dict(descriptors=F.normalize(x, p=2, dim=1))
        if urepeatability is not None:
            out["repeatability"] = self.softmax(urepeatability)
        if ureliability is not None:
            out["reliability"] = self.softmax(ureliability)
        return out

    def forward_one(self, x):
        """Process a single image; must be overridden by subclasses."""
        raise NotImplementedError()

    def forward(self, imgs, **kw):
        """Run :meth:`forward_one` on every image and merge the results.

        Args:
            imgs: iterable of inputs, each passed as-is to ``forward_one``.
            **kw: extra entries copied verbatim into the returned dict.

        Returns:
            A dict mapping each output name to the list of per-image values
            (images lacking a given key are skipped for that key), plus
            ``imgs`` and every entry of ``kw``.
        """
        per_image = [self.forward_one(img) for img in imgs]

        # Collect keys in first-seen order (a set would make the key order
        # of the merged dict vary from run to run).
        keys = dict.fromkeys(k for r in per_image for k in r)
        merged = {k: [r[k] for r in per_image if k in r] for k in keys}
        return dict(merged, imgs=imgs, **kw)
|
|
|
|
|
class PatchNet(BaseNet):
    """Helper class to construct a fully-convolutional network that
    extracts an l2-normalized patch descriptor.

    Layers are appended one at a time with :meth:`_add_conv`; the running
    channel count and dilation are tracked on the instance.
    """

    def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
        """Initialize an empty network.

        Args:
            inchan: number of input channels of the first convolution.
            dilated: if True, strides are replaced by growing dilation.
            dilation: initial dilation factor.
            bn: global switch enabling batch-norm layers.
            bn_affine: ``affine`` flag passed to each batch-norm layer.
        """
        BaseNet.__init__(self)
        self.inchan = inchan
        self.curchan = inchan  # input channels of the next conv to be added
        self.dilated = dilated
        self.dilation = dilation
        self.bn = bn
        self.bn_affine = bn_affine
        self.ops = nn.ModuleList([])

    def _make_bn(self, outd):
        """Return a 2D batch-norm layer over ``outd`` channels."""
        return nn.BatchNorm2d(outd, affine=self.bn_affine)

    def _add_conv(
        self,
        outd,
        k=3,
        stride=1,
        dilation=1,
        bn=True,
        relu=True,
        k_pool=1,
        pool_type="max",
    ):
        """Append a conv layer (plus optional BN, ReLU and pooling) to ``ops``.

        Args:
            outd: number of output channels.
            k: convolution kernel size.
            stride: stride; in dilated mode the conv keeps stride 1 and the
                running dilation is multiplied by this value instead.
            dilation: extra dilation multiplied with the running dilation.
            bn: add batch-norm after the conv (only if ``self.bn`` is set too).
            relu: add an in-place ReLU.
            k_pool: pooling kernel size; pooling is added only when > 1.
            pool_type: ``"max"`` or ``"avg"``.

        Raises:
            ValueError: if pooling is requested with an unknown ``pool_type``.
        """
        # Validate before touching any state so a bad call cannot leave the
        # network half-built (this used to only print and skip the pooling).
        if k_pool > 1 and pool_type not in ("avg", "max"):
            raise ValueError(f"Error, unknown pooling type {pool_type}...")

        d = self.dilation * dilation
        if self.dilated:
            conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=1)
            # Later layers compensate for the missing stride via dilation.
            self.dilation *= stride
        else:
            conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=stride)
        self.ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params))
        if bn and self.bn:
            self.ops.append(self._make_bn(outd))
        if relu:
            self.ops.append(nn.ReLU(inplace=True))
        self.curchan = outd

        if k_pool > 1:
            if pool_type == "avg":
                self.ops.append(torch.nn.AvgPool2d(kernel_size=k_pool))
            else:  # "max", guaranteed by the validation above
                self.ops.append(torch.nn.MaxPool2d(kernel_size=k_pool))

    def forward_one(self, x):
        """Apply all layers to ``x`` and return the normalized descriptors.

        Returns:
            A dict with the single key ``"descriptors"`` holding the output
            L2-normalized along dimension 1.
        """
        assert self.ops, "You need to add convolutions first"
        for op in self.ops:
            x = op(x)
        # This network has no confidence heads, so the descriptor entry is
        # built directly rather than through the inherited normalize(), whose
        # confidence arguments may be required.
        return dict(descriptors=F.normalize(x, p=2, dim=1))
|
|
|
|
|
class L2_Net(PatchNet):
    """Compute a 128D descriptor for all overlapping 32x32 patches.

    From the L2Net paper (CVPR'17). Channel widths are given for the
    default ``dim=128`` and are rescaled proportionally for other values.
    """

    def __init__(self, dim=128, **kw):
        """Build the seven-conv stack.

        Args:
            dim: reference width; each layer gets ``(n * dim) // 128`` channels.
            **kw: forwarded to ``PatchNet.__init__``.
        """
        super().__init__(**kw)

        # (nominal width at dim=128, extra _add_conv options), in layer order.
        layer_specs = (
            (32, {}),
            (32, {}),
            (64, {"stride": 2}),
            (64, {}),
            (128, {"stride": 2}),
            (128, {}),
            (128, {"k": 7, "stride": 8, "bn": False, "relu": False}),
        )
        for width, options in layer_specs:
            self._add_conv((width * dim) // 128, **options)

        self.out_dim = dim
|
|
|
|
|
class Quad_L2Net(PatchNet):
    """Same as L2_Net, but the final large conv is replaced by 3 successive 2x2 convs."""

    def __init__(self, dim=128, mchan=4, relu22=False, **kw):
        """Build the network.

        Args:
            dim: number of output descriptor channels.
            mchan: channel multiplier for the intermediate layers.
            relu22: whether the first two 2x2 convs are followed by a ReLU.
            **kw: forwarded to ``PatchNet.__init__``.
        """
        super().__init__(**kw)

        # 3x3 trunk: (channel multiple, stride) for each layer.
        trunk = ((8, 1), (8, 1), (16, 2), (16, 1), (32, 2), (32, 1))
        for multiple, step in trunk:
            self._add_conv(multiple * mchan, stride=step)

        # Two intermediate 2x2 convs, then the 2x2 output conv (no BN/ReLU).
        for _ in range(2):
            self._add_conv(32 * mchan, k=2, stride=2, relu=relu22)
        self._add_conv(dim, k=2, stride=2, bn=False, relu=False)

        self.out_dim = dim
|
|
|
|
|
class Quad_L2Net_ConfCFS(Quad_L2Net):
    """Quad_L2Net extended with 2 confidence maps: repeatability and reliability."""

    def __init__(self, **kw):
        """Build the backbone and add the two 1x1 confidence heads.

        Args:
            **kw: forwarded to ``Quad_L2Net.__init__``.
        """
        super().__init__(**kw)
        feat_dim = self.out_dim

        # Reliability head: 2 output channels.
        self.clf = nn.Conv2d(feat_dim, 2, kernel_size=1)

        # Repeatability head: 1 output channel.
        self.sal = nn.Conv2d(feat_dim, 1, kernel_size=1)

    def forward_one(self, x):
        """Run the backbone on ``x`` and return descriptors plus confidences.

        Both heads are fed the element-wise square of the backbone output.
        """
        assert self.ops, "You need to add convolutions first"
        feats = x
        for layer in self.ops:
            feats = layer(feats)

        reliability_scores = self.clf(feats**2)
        repeatability_scores = self.sal(feats**2)
        return self.normalize(
            feats,
            ureliability=reliability_scores,
            urepeatability=repeatability_scores,
        )
|
|
|
|
|
class Fast_Quad_L2Net(PatchNet):
    """Faster version of Quad L2 net: one pooling layer lowers the image
    resolution for most of the network (reducing inference time), and a final
    bilinear upsampling scales the output back up by the same factor.

    Dilation factors and pooling (with default arguments):
        1,1,1, pool2, 1,1, 2,2, 4, 8, upsample2
    """

    def __init__(self, dim=128, mchan=4, relu22=False, downsample_factor=2, **kw):
        """Build the network.

        Args:
            dim: number of output descriptor channels.
            mchan: channel multiplier for the intermediate layers.
            relu22: whether the first two 2x2 convs are followed by a ReLU.
            downsample_factor: pooling kernel size after the third conv, and
                scale factor of the final upsampling.
            **kw: forwarded to ``PatchNet.__init__``.
        """
        super().__init__(**kw)
        narrow, medium, wide = 8 * mchan, 16 * mchan, 32 * mchan

        self._add_conv(narrow)
        self._add_conv(narrow)
        # Pooling is attached right after this conv (default type: max).
        self._add_conv(medium, k_pool=downsample_factor)
        self._add_conv(medium)
        self._add_conv(wide, stride=2)
        self._add_conv(wide)

        # Two intermediate 2x2 convs, then the 2x2 output conv (no BN/ReLU).
        for _ in range(2):
            self._add_conv(wide, k=2, stride=2, relu=relu22)
        self._add_conv(dim, k=2, stride=2, bn=False, relu=False)

        # Scale the output back up by the pooling factor.
        restore = torch.nn.Upsample(
            scale_factor=downsample_factor, mode="bilinear", align_corners=False
        )
        self.ops.append(restore)

        self.out_dim = dim
|
|
|
|
|
class Fast_Quad_L2Net_ConfCFS(Fast_Quad_L2Net):
    """Fast r2d2 architecture: Fast_Quad_L2Net plus two confidence heads."""

    def __init__(self, **kw):
        """Build the backbone and add the two 1x1 confidence heads.

        Args:
            **kw: forwarded to ``Fast_Quad_L2Net.__init__``.
        """
        super().__init__(**kw)
        feat_dim = self.out_dim

        # Reliability head: 2 output channels.
        self.clf = nn.Conv2d(feat_dim, 2, kernel_size=1)

        # Repeatability head: 1 output channel.
        self.sal = nn.Conv2d(feat_dim, 1, kernel_size=1)

    def forward_one(self, x):
        """Run the backbone on ``x`` and return descriptors plus confidences.

        Both heads are fed the element-wise square of the backbone output.
        """
        assert self.ops, "You need to add convolutions first"
        feats = x
        for layer in self.ops:
            feats = layer(feats)

        reliability_scores = self.clf(feats**2)
        repeatability_scores = self.sal(feats**2)
        return self.normalize(
            feats,
            ureliability=reliability_scores,
            urepeatability=repeatability_scores,
        )
|
|