# Snapshot metadata: revision 404d2af, 7,158 bytes
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseNet (nn.Module):
    """ Takes a list of images as input, and returns for each image:
        - a pixelwise descriptor
        - a pixelwise confidence
    """
    def softmax(self, ux):
        """ Squash raw confidence logits into a single-channel map in [0, 1].

        ux: tensor whose dim 1 has size 1 or 2.
            - 1 channel: squashed with softplus(ux) / (1 + softplus(ux)).
            - 2 channels: 2-way softmax along dim 1, keeping the second channel.
        Returns a tensor with size 1 along dim 1.
        Raises ValueError for any other size along dim 1.
        """
        nchan = ux.shape[1]
        if nchan == 1:
            x = F.softplus(ux)
            return x / (1 + x) # for sure in [0,1], much less plateaus than softmax
        if nchan == 2:
            return F.softmax(ux, dim=1)[:,1:2]
        raise ValueError(f"softmax expects 1 or 2 channels along dim 1, got {nchan}")

    def normalize(self, x, ureliability=None, urepeatability=None):
        """ Build the output dictionary for one forward pass.

        x: raw descriptor tensor, L2-normalized along dim 1
           (presumably the channel axis of an NCHW tensor — TODO confirm).
        ureliability, urepeatability: raw confidence logits, squashed with
           `self.softmax`. Each is optional; its key is omitted when None.
        Returns a dict with 'descriptors' and, when provided,
        'repeatability' and 'reliability'.
        """
        res = dict(descriptors = F.normalize(x, p=2, dim=1))
        if urepeatability is not None:
            res['repeatability'] = self.softmax( urepeatability )
        if ureliability is not None:
            res['reliability'] = self.softmax( ureliability )
        return res

    def forward_one(self, x):
        """ Process a single image; subclasses must implement this. """
        raise NotImplementedError()

    def forward(self, imgs, **kw):
        """ Run `forward_one` on every image and merge the per-image dicts.

        Returns a dict mapping each output key to the list of values produced
        for it (images whose dict lacks the key are skipped), plus `imgs` and
        any extra keyword arguments, passed through unchanged.
        """
        res = [self.forward_one(img) for img in imgs]
        # merge all dictionaries into one; dict.fromkeys keeps first-seen key
        # order, so the merged output order is deterministic (a set's is not)
        keys = dict.fromkeys(k for r in res for k in r)
        res = {k:[r[k] for r in res if k in r] for k in keys}
        return dict(res, imgs=imgs, **kw)
class PatchNet (BaseNet):
    """ Helper class to construct a fully-convolutional network that
        extracts an L2-normalized patch descriptor.
    """
    def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
        """
        inchan: number of channels of the input images.
        dilated: if True, every conv keeps stride 1 and a requested stride
            instead multiplies the dilation used by the following layers;
            if False, the stride is applied by the conv itself.
        dilation: initial dilation multiplier.
        bn: globally enable BatchNorm after convolutions.
        bn_affine: whether BatchNorm layers have learnable affine parameters.
        """
        BaseNet.__init__(self)
        self.inchan = inchan
        self.curchan = inchan      # channel count produced by the last layer added
        self.dilated = dilated
        self.dilation = dilation   # running dilation multiplier for the next conv
        self.bn = bn
        self.bn_affine = bn_affine
        self.ops = nn.ModuleList([])

    def _make_bn(self, outd):
        """ BatchNorm layer for `outd` channels, affine per `self.bn_affine`. """
        return nn.BatchNorm2d(outd, affine=self.bn_affine)

    def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool = 1, pool_type='max'):
        """ Append a conv (+ optional BatchNorm, ReLU and pooling) to `self.ops`.

        outd: output channels.  k: kernel size.
        stride: conv stride; in dilated mode, the factor by which the dilation
            of the *following* layers grows instead.
        dilation: extra dilation multiplier for this layer only.
        bn: add BatchNorm (only if `self.bn` is also enabled).
        relu: add an in-place ReLU.
        k_pool: if > 1, append a pooling layer with this kernel size.
        pool_type: 'max' or 'avg'; only used when k_pool > 1.
        Raises ValueError on an unknown pool_type. The check runs before any
        layer is added, so the network is never left half-built.
        """
        pools = {'avg': nn.AvgPool2d, 'max': nn.MaxPool2d}
        if k_pool > 1 and pool_type not in pools:
            raise ValueError(f"unknown pooling type {pool_type!r}; expected 'avg' or 'max'")
        # as in the original implementation, dilation is applied at the end of layer, so it will have impact only from next layer
        d = self.dilation * dilation
        if self.dilated:
            conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=1)
            self.dilation *= stride
        else:
            conv_params = dict(padding=((k-1)*d)//2, dilation=d, stride=stride)
        self.ops.append( nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params) )
        if bn and self.bn: self.ops.append( self._make_bn(outd) )
        if relu: self.ops.append( nn.ReLU(inplace=True) )
        self.curchan = outd
        if k_pool > 1:
            self.ops.append( pools[pool_type](kernel_size=k_pool) )

    def forward_one(self, x):
        """ Run all layers on `x` and return its L2-normalized descriptors.

        PatchNet has no confidence heads, so the result holds only
        'descriptors' (normalized along dim 1).
        """
        assert self.ops, "You need to add convolutions first"
        for op in self.ops:
            x = op(x)
        return dict(descriptors = F.normalize(x, p=2, dim=1))
class L2_Net (PatchNet):
    """ L2Net-style patch network (CVPR'17): produces a `dim`-D descriptor
        (128-D by default) for all overlapping 32x32 patches.
    """
    def __init__(self, dim=128, **kw ):
        PatchNet.__init__(self, **kw)

        def add_conv(width, **conv_kw):
            # widths are written for the reference dim=128 and rescaled to `dim`
            self._add_conv((width*dim)//128, **conv_kw)

        # (reference width, extra _add_conv options), in network order
        layer_specs = (
            (32, {}),
            (32, {}),
            (64, {'stride': 2}),
            (64, {}),
            (128, {'stride': 2}),
            (128, {}),
            (128, {'k': 7, 'stride': 8, 'bn': False, 'relu': False}),
        )
        for width, conv_kw in layer_specs:
            add_conv(width, **conv_kw)
        self.out_dim = dim
class Quad_L2Net (PatchNet):
    """ Variant of L2_Net in which the final 8x8 convolution is replaced by
        three successive 2x2 convolutions.
    """
    def __init__(self, dim=128, mchan=4, relu22=False, **kw ):
        PatchNet.__init__(self, **kw)
        # (output channels, extra _add_conv options), in network order
        layers = [
            (8*mchan, {}),
            (8*mchan, {}),
            (16*mchan, {'stride': 2}),
            (16*mchan, {}),
            (32*mchan, {'stride': 2}),
            (32*mchan, {}),
            # the former 8x8 conv, split into three stacked 2x2 convs
            (32*mchan, {'k': 2, 'stride': 2, 'relu': relu22}),
            (32*mchan, {'k': 2, 'stride': 2, 'relu': relu22}),
            (dim, {'k': 2, 'stride': 2, 'bn': False, 'relu': False}),
        ]
        for outd, opts in layers:
            self._add_conv(outd, **opts)
        self.out_dim = dim
class Quad_L2Net_ConfCFS (Quad_L2Net):
    """ Quad_L2Net extended with two 1x1-conv heads that output confidence
        maps for repeatability and reliability.
    """
    def __init__(self, **kw ):
        Quad_L2Net.__init__(self, **kw)
        # reliability head: 2 output channels, so BaseNet.softmax applies a
        # 2-way softmax to it
        self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1)
        # repeatability head: a single output channel, so BaseNet.softmax
        # squashes it with softplus rather than softmax. The original authors
        # suspected this was an unnoticed mistake; it is left unchanged.
        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1)

    def forward_one(self, x):
        """ Run the backbone, then both confidence heads on the squared output. """
        assert self.ops, "You need to add convolutions first"
        for layer in self.ops:
            x = layer(x)
        squared = x**2
        # reliability logits are computed before repeatability logits, as before
        return self.normalize(x, self.clf(squared), self.sal(squared))
class Fast_Quad_L2Net (PatchNet):
    """ Faster version of Quad_L2Net: one dilation step is replaced by a pooling
        layer that lowers the feature-map resolution, which reduces inference
        time. The output is upsampled back by the same factor at the end.
        Dilation factors and pooling:
            1,1,1, pool2, 1,1, 2,2, 4, 8, upsample2

        dim: number of descriptor channels (stored as `out_dim`).
        mchan: channel-width multiplier of the intermediate convolutions.
        relu22: whether the two intermediate 2x2 convs are followed by a ReLU.
        downsample_factor: pooling kernel size and final upsampling scale.
        pool_type: 'max' (default) or 'avg'; keyword-only, passed to `_add_conv`.
        kw: forwarded to PatchNet.__init__.
    """
    def __init__(self, dim=128, mchan=4, relu22=False, downsample_factor=2, *, pool_type='max', **kw ):
        PatchNet.__init__(self, **kw)
        self._add_conv( 8*mchan)
        self._add_conv( 8*mchan)
        # no stride here (unlike Quad_L2Net's third conv): the resolution is
        # reduced by a pooling layer (max pooling unless pool_type says otherwise)
        self._add_conv( 16*mchan, k_pool = downsample_factor, pool_type=pool_type)
        self._add_conv( 16*mchan)
        self._add_conv( 32*mchan, stride=2)
        self._add_conv( 32*mchan)
        # replace last 8x8 convolution with 3 2x2 convolutions
        self._add_conv( 32*mchan, k=2, stride=2, relu=relu22)
        self._add_conv( 32*mchan, k=2, stride=2, relu=relu22)
        self._add_conv(dim, k=2, stride=2, bn=False, relu=False)
        # Go back to initial image resolution with upsampling.
        # NOTE(review): pooling floors sizes that are not divisible by
        # downsample_factor, so the upsampled map can then be smaller than the
        # input. Confirm that callers only feed compatible image sizes.
        self.ops.append(torch.nn.Upsample(scale_factor=downsample_factor, mode='bilinear', align_corners=False))
        self.out_dim = dim
class Fast_Quad_L2Net_ConfCFS (Fast_Quad_L2Net):
    """ Fast R2D2 architecture: Fast_Quad_L2Net extended with two 1x1-conv
        heads that output reliability and repeatability confidence maps.
    """
    def __init__(self, **kw ):
        Fast_Quad_L2Net.__init__(self, **kw)
        # reliability classifier: 2 output channels, so BaseNet.softmax
        # applies a 2-way softmax to it
        self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1)
        # repeatability classifier: for some reasons it's a softplus, not a softmax!
        # (single output channel, so BaseNet.softmax takes its softplus branch.)
        # Why? I guess it's a mistake that was left unnoticed in the code for a long time...
        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1)

    def forward_one(self, x):
        """ Run the backbone, then both confidence heads on the squared output. """
        assert self.ops, "You need to add convolutions first"
        for op in self.ops:
            x = op(x)
        # compute the confidence maps; both heads read the same squared tensor
        x_sq = x**2
        ureliability = self.clf(x_sq)
        urepeatability = self.sal(x_sq)
        return self.normalize(x, ureliability, urepeatability)