# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F

from nets.sampler import FullSampler


class CosimLoss(nn.Module):
    """ Try to make the repeatability repeatable from one image to the other.

    The second repeatability map is resampled (``F.grid_sample``) onto the
    H x W pixel grid given by ``aflow``. Both maps are then cut into
    overlapping N x N patches (stride N//2), and each patch is L2-normalized.
    The loss is ``1 - mean cosine similarity`` between corresponding patches,
    so it is 0 when every pair of patches is perfectly proportional.
    """
    def __init__(self, N=16):
        """
        Args:
            N: side length of the square patches compared between the two maps.
               The unfold stride is N//2, so neighbouring patches overlap.
        """
        super().__init__()
        self.name = f'cosim{N}'
        self.patches = nn.Unfold(N, padding=0, stride=N//2)

    def extract_patches(self, sal):
        """Return the L2-normalized N x N patches of ``sal``.

        ``sal`` must be a 4-D (B, C, H, W) tensor, as required by nn.Unfold.
        The result has shape (B, L, C*N*N), where L is the number of patches.
        Each patch vector (last dim) has unit L2 norm, so a dot product
        between two such vectors is their cosine similarity.
        """
        # nn.Unfold yields (B, C*N*N, L); put the patch index before the features.
        patches = self.patches(sal).transpose(1, 2)
        patches = F.normalize(patches, p=2, dim=2)  # unit norm per patch
        return patches

    def forward(self, repeatability, aflow, **kw):
        """Compute the cosine-similarity repeatability loss.

        Args:
            repeatability: pair ``(sali1, sali2)`` of 4-D repeatability maps.
                After warping, ``sali2`` must have the same shape as ``sali1``
                for the patch-wise product below.
            aflow: 4-D tensor whose dim 1 has size 2 (checked below). It is
                converted to a sampling grid by ``FullSampler._aflow_to_grid``.
                Presumably it holds, for each pixel of image 1, its (x, y)
                location in image 2 (TODO confirm against FullSampler).
            **kw: ignored; accepted so all losses share one call signature.

        Returns:
            Scalar tensor ``1 - mean cosine similarity`` over all patches and
            batch items.
        """
        _, two, _, _ = aflow.shape
        assert two == 2

        sali1, sali2 = repeatability
        grid = FullSampler._aflow_to_grid(aflow)
        # Sample sali2 at the grid locations; out-of-range locations take
        # border values (padding_mode='border').
        # NOTE(review): align_corners is not specified, so the PyTorch default
        # applies (False since 1.3, with a warning when omitted). Confirm it
        # matches the coordinate normalization in FullSampler._aflow_to_grid.
        sali2 = F.grid_sample(sali2, grid, mode='bilinear', padding_mode='border')

        patches1 = self.extract_patches(sali1)
        patches2 = self.extract_patches(sali2)
        # Patches are unit-norm, so the dot product is the cosine similarity: (B, L).
        cosim = (patches1 * patches2).sum(dim=2)
        return 1 - cosim.mean()


class PeakyLoss(nn.Module):
    """ Try to make the repeatability locally peaky.

    Mechanism: we maximize, for each pixel, the difference between the local
    max and the local mean of the repeatability map. Both are computed over an
    (N+1) x (N+1) window after a light 3 x 3 average smoothing.
    """
    def __init__(self, N=16):
        """
        Args:
            N: window size parameter; the pooling windows are (N+1) x (N+1).
               Must be even.
        """
        super().__init__()
        self.name = f'peaky{N}'
        # With N even, the odd (N+1)-wide windows, padding N//2 and stride 1
        # keep the pooled maps the same H x W as their input.
        assert N % 2 == 0, 'N must be even'
        self.preproc = nn.AvgPool2d(3, stride=1, padding=1)
        self.maxpool = nn.MaxPool2d(N+1, stride=1, padding=N//2)
        self.avgpool = nn.AvgPool2d(N+1, stride=1, padding=N//2)

    def forward_one(self, sali):
        """Return ``1 - mean(local max - local mean)`` for one repeatability map."""
        sali = self.preproc(sali)  # remove very high-frequency content
        return 1 - (self.maxpool(sali) - self.avgpool(sali)).mean()

    def forward(self, repeatability, **kw):
        """Average the peakiness loss over the pair ``(sali1, sali2)``.

        ``**kw`` is ignored; it is accepted so all losses share one call signature.
        """
        sali1, sali2 = repeatability
        return (self.forward_one(sali1) + self.forward_one(sali2)) / 2