Vincentqyw
fix: roma
358ab8f
raw
history blame
7.24 kB
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseNet(nn.Module):
    """Base class: takes a list of images as input, and returns for each image:
    - a pixelwise descriptor
    - a pixelwise confidence

    Subclasses implement ``forward_one`` for a single image; ``forward``
    applies it to every image and merges the per-image result dictionaries.
    """

    def softmax(self, ux):
        """Squash unnormalized confidence scores into ``[0, 1]``.

        Args:
            ux: tensor whose dimension 1 holds either 1 or 2 score channels.

        Returns:
            A tensor with a single channel along dimension 1:
            - 1 input channel: ``softplus(ux) / (1 + softplus(ux))``;
            - 2 input channels: softmax over dimension 1, keeping channel 1.

        Raises:
            ValueError: if dimension 1 of ``ux`` has neither 1 nor 2 channels
                (previously this case silently returned ``None``).
        """
        num_channels = ux.shape[1]
        if num_channels == 1:
            x = F.softplus(ux)
            return x / (1 + x)  # for sure in [0,1], much less plateaus than softmax
        if num_channels == 2:
            # slice 1:2 (rather than index 1) keeps the channel dimension
            return F.softmax(ux, dim=1)[:, 1:2]
        raise ValueError(
            f"softmax expects 1 or 2 channels along dim 1, got {num_channels}"
        )

    def normalize(self, x, ureliability=None, urepeatability=None):
        """Build the output dictionary for one image.

        Args:
            x: raw descriptors; they are L2-normalized along dimension 1.
            ureliability: optional unnormalized reliability scores.
            urepeatability: optional unnormalized repeatability scores.

        Returns:
            A dict with key ``descriptors`` and, for each confidence map that
            was provided, ``repeatability`` / ``reliability`` (each passed
            through ``self.softmax``).
        """
        out = dict(descriptors=F.normalize(x, p=2, dim=1))
        # Confidence maps are optional so that descriptor-only networks can
        # call normalize(x) without raising a TypeError.
        if urepeatability is not None:
            out["repeatability"] = self.softmax(urepeatability)
        if ureliability is not None:
            out["reliability"] = self.softmax(ureliability)
        return out

    def forward_one(self, x):
        """Process a single image; must be implemented by subclasses."""
        raise NotImplementedError()

    def forward(self, imgs, **kw):
        """Run ``forward_one`` on every image and merge the results.

        Args:
            imgs: iterable of images, each passed as-is to ``forward_one``.
            **kw: extra entries copied verbatim into the returned dict.

        Returns:
            A dict mapping every key produced by ``forward_one`` to the list
            of values from the images that produced it (in image order),
            plus ``imgs`` and the entries of ``kw``.
        """
        per_image = [self.forward_one(img) for img in imgs]
        # merge all dictionaries into one; keys keep first-seen order so the
        # result does not depend on string hashing
        merged = {}
        for result in per_image:
            for key, value in result.items():
                merged.setdefault(key, []).append(value)
        return dict(merged, imgs=imgs, **kw)
class PatchNet(BaseNet):
    """Helper class to construct a fully-convolutional network that
    extracts an l2-normalized patch descriptor.

    Layers are appended one at a time with ``_add_conv`` and stored, in
    order, in ``self.ops``.
    """

    def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
        """Create an empty network.

        Args:
            inchan: number of input channels of the first convolution.
            dilated: if True, strides requested in ``_add_conv`` are not
                applied to the convolutions; they multiply the dilation of
                the following layers instead.
            dilation: initial dilation factor.
            bn: globally enable/disable batch normalization layers.
            bn_affine: ``affine`` flag given to every ``BatchNorm2d``.
        """
        BaseNet.__init__(self)
        self.inchan = inchan
        self.curchan = inchan  # channel count of the most recently added layer
        self.dilated = dilated
        self.dilation = dilation
        self.bn = bn
        self.bn_affine = bn_affine
        self.ops = nn.ModuleList([])

    def _make_bn(self, outd):
        """Return a ``BatchNorm2d`` over ``outd`` channels."""
        return nn.BatchNorm2d(outd, affine=self.bn_affine)

    def _add_conv(
        self,
        outd,
        k=3,
        stride=1,
        dilation=1,
        bn=True,
        relu=True,
        k_pool=1,
        pool_type="max",
    ):
        """Append a convolution (plus optional BN / ReLU / pooling) to ``self.ops``.

        Args:
            outd: number of output channels.
            k: convolution kernel size.
            stride: stride of this layer. In dilated mode the convolution
                always uses stride 1 and ``self.dilation`` is multiplied by
                ``stride`` for the layers added afterwards.
            dilation: extra dilation factor, multiplied with ``self.dilation``.
            bn: add a batch-norm layer (only if ``self.bn`` is also True).
            relu: add an in-place ReLU.
            k_pool: if greater than 1, add a pooling layer of that kernel size.
            pool_type: ``"avg"`` or ``"max"``; only used when ``k_pool > 1``.

        Raises:
            ValueError: if pooling is requested with an unknown ``pool_type``.
        """
        # Validate before touching any state: previously an unknown type only
        # printed a message and the network was silently built without pooling.
        if k_pool > 1 and pool_type not in ("avg", "max"):
            raise ValueError(
                f"unknown pooling type {pool_type!r}; expected 'avg' or 'max'"
            )

        # as in the original implementation, dilation is applied at the end of layer, so it will have impact only from next layer
        d = self.dilation * dilation
        if self.dilated:
            conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=1)
            self.dilation *= stride
        else:
            conv_params = dict(padding=((k - 1) * d) // 2, dilation=d, stride=stride)
        self.ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params))
        if bn and self.bn:
            self.ops.append(self._make_bn(outd))
        if relu:
            self.ops.append(nn.ReLU(inplace=True))
        self.curchan = outd
        if k_pool > 1:
            if pool_type == "avg":
                self.ops.append(torch.nn.AvgPool2d(kernel_size=k_pool))
            else:  # "max", guaranteed by the validation above
                self.ops.append(torch.nn.MaxPool2d(kernel_size=k_pool))

    def forward_one(self, x):
        """Apply every layer of ``self.ops`` to ``x``.

        Returns:
            A dict with the single key ``descriptors``: the network output,
            L2-normalized along dimension 1.
        """
        assert self.ops, "You need to add convolutions first"
        for op in self.ops:
            x = op(x)
        # A plain PatchNet computes no confidence maps, so the descriptors are
        # normalized here directly: BaseNet.normalize() takes two additional
        # confidence arguments and calling it with x alone raised a TypeError.
        return dict(descriptors=F.normalize(x, p=2, dim=1))
class L2_Net(PatchNet):
    """Compute a 128D descriptor for all overlapping 32x32 patches.
    From the L2Net paper (CVPR'17).

    Every layer width is scaled by ``dim / 128`` (integer division), so the
    last layer outputs ``dim`` channels.
    """

    def __init__(self, dim=128, **kw):
        """Build the seven-convolution stack.

        Args:
            dim: descriptor dimension; also stored as ``self.out_dim``.
            **kw: forwarded to ``PatchNet.__init__``.
        """
        PatchNet.__init__(self, **kw)

        # (nominal width for dim == 128, extra _add_conv options), in order
        layer_specs = (
            (32, {}),
            (32, {}),
            (64, {"stride": 2}),
            (64, {}),
            (128, {"stride": 2}),
            (128, {}),
            (128, {"k": 7, "stride": 8, "bn": False, "relu": False}),
        )
        for nominal_width, options in layer_specs:
            self._add_conv((nominal_width * dim) // 128, **options)

        self.out_dim = dim
class Quad_L2Net(PatchNet):
    """Same as L2_Net, but replace the final 8x8 conv by 3 successive 2x2 convs."""

    def __init__(self, dim=128, mchan=4, relu22=False, **kw):
        """Build the nine-convolution stack.

        Args:
            dim: number of output channels; also stored as ``self.out_dim``.
            mchan: channel multiplier; layer widths are 8, 16 and 32 times it.
            relu22: whether the first two 2x2 convolutions are followed by a ReLU.
            **kw: forwarded to ``PatchNet.__init__``.
        """
        PatchNet.__init__(self, **kw)
        narrow = 8 * mchan
        medium = 16 * mchan
        wide = 32 * mchan

        # 3x3 trunk: (output channels, stride)
        for width, step in (
            (narrow, 1),
            (narrow, 1),
            (medium, 2),
            (medium, 1),
            (wide, 2),
            (wide, 1),
        ):
            self._add_conv(width, stride=step)

        # replace last 8x8 convolution with 3 2x2 convolutions
        for _ in range(2):
            self._add_conv(wide, k=2, stride=2, relu=relu22)
        self._add_conv(dim, k=2, stride=2, bn=False, relu=False)

        self.out_dim = dim
class Quad_L2Net_ConfCFS(Quad_L2Net):
    """Same as Quad_L2Net, with 2 confidence maps for repeatability and reliability."""

    def __init__(self, **kw):
        """Build the descriptor trunk and add two 1x1 confidence heads.

        Args:
            **kw: forwarded to ``Quad_L2Net.__init__``.
        """
        Quad_L2Net.__init__(self, **kw)
        # reliability classifier (2 output channels)
        self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1)
        # repeatability classifier (1 output channel): for some reason it goes
        # through a softplus, not a softmax. Probably a mistake that was left
        # unnoticed in the code for a long time...
        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1)

    def forward_one(self, x):
        """Compute descriptors and both confidence maps for one image.

        Returns:
            The dict built by ``self.normalize`` from the trunk output and the
            two unnormalized confidence maps.
        """
        assert self.ops, "You need to add convolutions first"
        features = x
        for layer in self.ops:
            features = layer(features)

        # both confidence heads read the element-wise squared features
        reliability_scores = self.clf(features**2)
        repeatability_scores = self.sal(features**2)
        return self.normalize(
            features,
            ureliability=reliability_scores,
            urepeatability=repeatability_scores,
        )
class Fast_Quad_L2Net(PatchNet):
    """Faster version of Quad_L2Net: one pooling layer lowers the image
    resolution inside the network (to reduce inference time) and a final
    bilinear upsampling restores it.

    Dilation factors and pooling:
    1,1,1, pool2, 1,1, 2,2, 4, 8, upsample2
    """

    def __init__(self, dim=128, mchan=4, relu22=False, downsample_factor=2, **kw):
        """Build the convolution stack followed by an upsampling layer.

        Args:
            dim: number of output channels; also stored as ``self.out_dim``.
            mchan: channel multiplier; layer widths are 8, 16 and 32 times it.
            relu22: whether the first two 2x2 convolutions are followed by a ReLU.
            downsample_factor: pooling kernel size, and scale factor of the
                final upsampling.
            **kw: forwarded to ``PatchNet.__init__``.
        """
        PatchNet.__init__(self, **kw)
        narrow = 8 * mchan
        medium = 16 * mchan
        wide = 32 * mchan

        self._add_conv(narrow)
        self._add_conv(narrow)
        # Pooling decreases the image resolution. No pool_type is given, so
        # _add_conv's default ("max") applies.
        self._add_conv(medium, k_pool=downsample_factor)
        self._add_conv(medium)
        self._add_conv(wide, stride=2)
        self._add_conv(wide)

        # replace last 8x8 convolution with 3 2x2 convolutions
        for _ in range(2):
            self._add_conv(wide, k=2, stride=2, relu=relu22)
        self._add_conv(dim, k=2, stride=2, bn=False, relu=False)

        # Go back to initial image resolution with upsampling
        upsampler = torch.nn.Upsample(
            scale_factor=downsample_factor, mode="bilinear", align_corners=False
        )
        self.ops.append(upsampler)

        self.out_dim = dim
class Fast_Quad_L2Net_ConfCFS(Fast_Quad_L2Net):
    """Fast r2d2 architecture: Fast_Quad_L2Net plus two confidence heads."""

    def __init__(self, **kw):
        """Build the descriptor trunk and add two 1x1 confidence heads.

        Args:
            **kw: forwarded to ``Fast_Quad_L2Net.__init__``.
        """
        Fast_Quad_L2Net.__init__(self, **kw)
        # reliability classifier (2 output channels)
        self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1)
        # repeatability classifier (1 output channel): for some reason it goes
        # through a softplus, not a softmax. Probably a mistake that was left
        # unnoticed in the code for a long time...
        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1)

    def forward_one(self, x):
        """Compute descriptors and both confidence maps for one image.

        Returns:
            The dict built by ``self.normalize`` from the trunk output and the
            two unnormalized confidence maps.
        """
        assert self.ops, "You need to add convolutions first"
        features = x
        for layer in self.ops:
            features = layer(features)

        # both confidence heads read the element-wise squared features
        reliability_scores = self.clf(features**2)
        repeatability_scores = self.sal(features**2)
        return self.normalize(
            features,
            ureliability=reliability_scores,
            urepeatability=repeatability_scores,
        )