File size: 4,560 Bytes
437b5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .score import peakiness_score


class BaseNet(nn.Module):
    """Helper base class for building a fully-convolutional network.

    Subclasses create their layers through :meth:`_add_conv`, which keeps
    track of the current number of channels in ``self.curchan`` so that
    consecutive calls chain together without repeating the input width.

    Args:
        inchan: Number of channels of the network input.
        dilated: Stored on the instance; not consulted by ``_add_conv``
            in this class (strides are applied as real strides).
        dilation: Base dilation multiplied into every convolution's dilation.
        bn: Global switch; when False no batch-norm layer is ever added.
        bn_affine: Passed as ``affine`` to every ``nn.BatchNorm2d`` created
            by :meth:`_make_bn`.
    """

    def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
        super().__init__()
        self.inchan = inchan
        # Channel count produced by the most recently added layer.
        self.curchan = inchan
        self.dilated = dilated
        self.dilation = dilation
        self.bn = bn
        self.bn_affine = bn_affine

    def _make_bn(self, outd):
        """Return a ``BatchNorm2d`` over ``outd`` channels."""
        return nn.BatchNorm2d(outd, affine=self.bn_affine)

    def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True,
                  k_pool=1, pool_type='max', bias=False):
        """Build one conv block and advance ``self.curchan`` to ``outd``.

        The block is ``Conv2d -> [BatchNorm2d] -> [ReLU] -> [pooling]``.

        Args:
            outd: Number of output channels.
            k: Convolution kernel size.
            stride: Convolution stride.
            dilation: Per-layer dilation, multiplied by ``self.dilation``.
            bn: Add batch-norm after the convolution (also requires
                ``self.bn`` to be true).
            relu: Add an in-place ReLU.
            k_pool: Pooling kernel size; pooling is added only when > 1.
            pool_type: ``'max'`` or ``'avg'``; only consulted when
                ``k_pool > 1``.
            bias: Whether the convolution has a bias term.

        Returns:
            An ``nn.Sequential`` holding the layers in the order above.

        Raises:
            ValueError: If ``k_pool > 1`` and ``pool_type`` is not
                ``'max'`` or ``'avg'``.
        """
        pool_layers = {'avg': nn.AvgPool2d, 'max': nn.MaxPool2d}
        # Fail before building anything or touching self.curchan: silently
        # dropping the pooling stage would change the output resolution.
        if k_pool > 1 and pool_type not in pool_layers:
            raise ValueError(
                f"unknown pooling type {pool_type!r}; expected 'avg' or 'max'"
            )

        # NOTE: the stride is applied as a real stride; self.dilation is
        # never updated here, so only the constructor value scales `dilation`.
        d = self.dilation * dilation
        layers = [
            nn.Conv2d(
                self.curchan,
                outd,
                kernel_size=k,
                padding=((k - 1) * d) // 2,
                dilation=d,
                stride=stride,
                bias=bias,
            )
        ]
        if bn and self.bn:
            layers.append(self._make_bn(outd))
        if relu:
            layers.append(nn.ReLU(inplace=True))
        if k_pool > 1:
            layers.append(pool_layers[pool_type](kernel_size=k_pool))

        self.curchan = outd
        return nn.Sequential(*layers)


class Quad_L2Net(BaseNet):
    """L2-Net style backbone whose final 8x8 conv is replaced by three
    successive 3x3 convs, and which also produces a detection score map.

    Args:
        dim: Number of channels of the final feature map.
        mchan: Channel multiplier; the layers use 8, 16 and 32 times
            this value.
        relu22: Accepted for interface compatibility; not used here.
        **kw: Forwarded to ``BaseNet.__init__``.
    """

    def __init__(self, dim=128, mchan=4, relu22=False, **kw):
        super().__init__(**kw)
        narrow, medium, wide = 8 * mchan, 16 * mchan, 32 * mchan

        # conv1 / conv3 skip their built-in BN so that the pre-BN
        # activations can be tapped in forward(); BN is applied separately.
        self.conv0 = self._add_conv(narrow)
        self.conv1 = self._add_conv(narrow, bn=False)
        self.bn1 = self._make_bn(narrow)
        self.conv2 = self._add_conv(medium, stride=2)
        self.conv3 = self._add_conv(medium, bn=False)
        self.bn3 = self._make_bn(medium)
        self.conv4 = self._add_conv(wide, stride=2)
        self.conv5 = self._add_conv(wide)
        # Three 3x3 convolutions stand in for the last 8x8 convolution.
        self.conv6_0 = self._add_conv(wide)
        self.conv6_1 = self._add_conv(wide)
        self.conv6_2 = self._add_conv(dim, bn=False, relu=False)
        self.out_dim = dim

        # One non-trainable scalar per tapped level, initialised to 1 and
        # updated in forward() as a running average of the level's max.
        self.moving_avg_params = nn.ParameterList(
            [Parameter(torch.tensor(1.), requires_grad=False) for _ in range(3)]
        )

    def forward(self, x):
        """Run the backbone and build the combined detection score map.

        Args:
            x: Input batch, expected as [N, C, H, W].

        Returns:
            Tuple ``(x6_2, det_score_map, x1, x3)``: the final feature map,
            the single-channel score map resized to the spatial size of
            ``x``, and the pre-batch-norm outputs of ``conv1`` and ``conv3``.
        """
        x1 = self.conv1(self.conv0(x))
        x3 = self.conv3(self.conv2(self.bn1(x1)))
        x5 = self.conv5(self.conv4(self.bn3(x3)))
        x6_2 = self.conv6_2(self.conv6_1(self.conv6_0(x5)))

        # Per-level mixing weights, normalised to sum to one.
        level_weights = torch.tensor([1., 2., 3.], device=x.device)
        level_weights /= torch.sum(level_weights)
        # Dilation handed to peakiness_score for each tapped level.
        level_dilations = (3, 2, 1)
        out_size = x.shape[2:]

        weighted_maps = []
        for idx, (feat, dil) in enumerate(zip((x1, x3, x6_2), level_dilations)):
            running_max = self.moving_avg_params[idx]
            if self.training:
                # Exponential moving average (0.99 / 0.01) of the batch max.
                batch_max = torch.max(feat)
                running_max.data = running_max * 0.99 + batch_max.detach() * 0.01

            alpha, beta = peakiness_score(feat, running_max.detach(), ksize=3, dilation=dil)

            # Collapse channels by max, then resize to the input resolution.
            level_map = torch.max(alpha * beta, dim=1, keepdim=True)[0]
            level_map = F.interpolate(level_map, size=out_size, mode='bilinear', align_corners=True)
            weighted_maps.append(level_weights[idx] * level_map)

        det_score_map = torch.sum(torch.stack(weighted_maps, dim=0), dim=0)

        return x6_2, det_score_map, x1, x3