# Source: vidimatch/third_party/r2d2/nets/repeatability_loss.py
# (revision 404d2af by Vincentqyw: "update: features and matchers")
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.sampler import FullSampler
class CosimLoss(nn.Module):
    """Make repeatability consistent ("repeatable") from one image to the other.

    The second repeatability map is resampled at the positions given by the
    flow ``aflow``, so it can be compared patch-by-patch with the first map.
    Both maps are cut into ``N x N`` patches taken every ``N // 2`` pixels.
    Each patch is flattened and scaled to unit L2 norm. The loss is
    ``1 - mean cosine similarity`` over all corresponding patch pairs.
    """

    def __init__(self, N=16):
        nn.Module.__init__(self)
        self.name = f'cosim{N}'
        # N x N sliding windows, no padding, stride N // 2 between windows.
        self.patches = nn.Unfold(N, padding=0, stride=N // 2)

    def extract_patches(self, sal):
        """Return the unit-norm patches of ``sal``, one flattened patch per row.

        ``nn.Unfold`` produces ``(B, C*N*N, L)``. Swapping the last two axes
        gives ``(B, L, C*N*N)``, and each row is then L2-normalised.
        NOTE(review): assumes ``sal`` is a 4-D ``(B, C, H, W)`` tensor, as
        ``nn.Unfold`` requires -- confirm with callers.
        """
        per_position = self.patches(sal)
        rows = per_position.transpose(1, 2)
        return F.normalize(rows, p=2, dim=2)

    def forward(self, repeatability, aflow, **kw):
        """Compute ``1 - mean cosine similarity`` between matching patches.

        Args:
            repeatability: pair ``(sali1, sali2)`` of repeatability maps.
            aflow: 4-D flow tensor whose second axis must have size 2; it is
                converted to a sampling grid by ``FullSampler._aflow_to_grid``.
            **kw: ignored (both losses in this file accept extra keywords).

        Returns:
            Scalar tensor ``1 - mean(cosine similarity)``.
        """
        _, two, _, _ = aflow.shape
        assert two == 2

        first, second = repeatability
        # Bilinearly resample the second map at the flow positions.
        # NOTE(review): align_corners is left at PyTorch's default; confirm it
        # matches the coordinate convention used by _aflow_to_grid.
        grid = FullSampler._aflow_to_grid(aflow)
        warped = F.grid_sample(second, grid, mode='bilinear', padding_mode='border')

        rows_a = self.extract_patches(first)
        rows_b = self.extract_patches(warped)
        # Dot product of two unit vectors == their cosine similarity.
        similarity = (rows_a * rows_b).sum(dim=-1)
        return 1 - similarity.mean()
class PeakyLoss(nn.Module):
    """Make each repeatability map locally peaky.

    Each map first gets a 3x3 mean blur. Then, for every pixel, the loss
    takes the gap between the local maximum and the local mean over an
    ``(N+1) x (N+1)`` window centred on that pixel. The loss is
    ``1 - mean(gap)``, so minimising it maximises that gap.
    """

    def __init__(self, N=16):
        nn.Module.__init__(self)
        self.name = f'peaky{N}'
        # An even N makes the window N+1 odd, so padding N//2 keeps H, W unchanged.
        assert N % 2 == 0, 'N must be pair'
        window = N + 1
        half = N // 2
        # 3x3 mean filter, stride 1, padding 1: same spatial size as the input.
        self.preproc = nn.AvgPool2d(3, stride=1, padding=1)
        self.maxpool = nn.MaxPool2d(window, stride=1, padding=half)
        # NOTE(review): AvgPool2d counts zero padding in border averages by
        # default (count_include_pad=True), which lowers the local mean near
        # image edges -- confirm this is intended.
        self.avgpool = nn.AvgPool2d(window, stride=1, padding=half)

    def forward_one(self, sali):
        """Return ``1 - mean(local_max - local_mean)`` for a single map."""
        smoothed = self.preproc(sali)  # damp super-high-frequency content
        gap = self.maxpool(smoothed) - self.avgpool(smoothed)
        return 1 - gap.mean()

    def forward(self, repeatability, **kw):
        """Average ``forward_one`` over the pair ``(sali1, sali2)``.

        ``**kw`` is accepted and ignored.
        """
        first, second = repeatability
        total = self.forward_one(first) + self.forward_one(second)
        return total / 2