File size: 2,005 Bytes
437b5f6
 
 
 
 
 
 
 
 
 
 
 
4c12b36
 
 
 
437b5f6
 
4c12b36
 
437b5f6
 
4c12b36
 
437b5f6
4c12b36
437b5f6
4c12b36
437b5f6
 
 
 
 
4c12b36
437b5f6
 
 
 
 
 
 
4c12b36
 
437b5f6
 
 
 
4c12b36
437b5f6
 
4c12b36
 
437b5f6
4c12b36
 
437b5f6
 
4c12b36
437b5f6
 
 
 
4c12b36
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F

from nets.sampler import FullSampler


class CosimLoss(nn.Module):
    """Encourage the repeatability maps of two images to agree patch-wise.

    The second repeatability map is resampled into the frame of the first one,
    both maps are cut into overlapping ``N x N`` patches (stride ``N // 2``, no
    padding), every flattened patch is L2-normalised, and the loss is
    ``1 - mean(cosine similarity between corresponding patches)``.
    """

    def __init__(self, N=16):
        super().__init__()
        self.name = f"cosim{N}"
        # Sliding N x N windows with half-patch overlap.
        self.patches = nn.Unfold(N, padding=0, stride=N // 2)

    def extract_patches(self, sal):
        """Cut ``sal`` into flattened, unit-norm patches.

        Args:
            sal: repeatability map; presumably (batch, channels, H, W) — TODO
                confirm against callers.

        Returns:
            Tensor laid out as (batch, num_patches, channels * N * N), each row
            L2-normalised.
        """
        # nn.Unfold yields (batch, C*N*N, num_patches); put patches on dim 1.
        flat = self.patches(sal).transpose(1, 2)
        return F.normalize(flat, p=2, dim=2)

    def forward(self, repeatability, aflow, **kw):
        """Compute ``1 - mean cosine similarity`` between aligned patches.

        Args:
            repeatability: pair ``(sali1, sali2)`` of repeatability maps.
            aflow: 4-D tensor whose second dimension must be 2; presumably the
                absolute flow from image 1 into image 2 — TODO confirm.
            **kw: extra keyword arguments, accepted and ignored.
        """
        _, n_coords, _, _ = aflow.shape
        assert n_coords == 2

        first, second = repeatability
        # Warp the second map onto the first image's pixel grid.
        sampling_grid = FullSampler._aflow_to_grid(aflow)
        # NOTE(review): align_corners is left at PyTorch's default; it should
        # match the coordinate normalisation used by _aflow_to_grid — confirm.
        warped = F.grid_sample(
            second, sampling_grid, mode="bilinear", padding_mode="border"
        )

        patches_a = self.extract_patches(first)
        patches_b = self.extract_patches(warped)
        # Dot product of unit vectors == cosine similarity, one per patch.
        similarity = (patches_a * patches_b).sum(dim=2)
        return 1 - similarity.mean()


class PeakyLoss(nn.Module):
    """Try to make the repeatability locally peaky.

    Mechanism: after a light 3x3 box blur, we maximize, for each pixel, the
    difference between the local max and the local mean over an
    ``(N + 1) x (N + 1)`` window. The loss is ``1 - mean(local_max - local_mean)``.

    Args:
        N: window parameter; must be even so that the ``N + 1`` window is
            centred on each pixel with ``N // 2`` padding.

    Raises:
        ValueError: if ``N`` is odd.
    """

    def __init__(self, N=16):
        super().__init__()
        # Raise instead of ``assert`` so the check survives ``python -O``; an
        # odd N would otherwise build even-sized, off-centre pooling windows.
        if N % 2 != 0:
            raise ValueError(f"N must be even, got N={N}")
        self.name = f"peaky{N}"
        # 3x3 blur that removes very high frequencies before pooling.
        self.preproc = nn.AvgPool2d(3, stride=1, padding=1)
        # With N even, kernel N + 1 and padding N // 2 keep the spatial size.
        self.maxpool = nn.MaxPool2d(N + 1, stride=1, padding=N // 2)
        # NOTE(review): AvgPool2d defaults to count_include_pad=True, so zero
        # padding is counted in the local mean near image borders.
        self.avgpool = nn.AvgPool2d(N + 1, stride=1, padding=N // 2)

    def forward_one(self, sali):
        """Return the peakiness loss for a single repeatability map."""
        sali = self.preproc(sali)  # remove super high frequency
        return 1 - (self.maxpool(sali) - self.avgpool(sali)).mean()

    def forward(self, repeatability, **kw):
        """Average ``forward_one`` over the pair ``(sali1, sali2)``.

        Extra keyword arguments are accepted and ignored.
        """
        sali1, sali2 = repeatability
        return (self.forward_one(sali1) + self.forward_one(sali2)) / 2