|
|
|
|
|
|
|
|
|
import pdb |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from nets.ap_loss import APLoss |
|
|
|
|
|
class PixelAPLoss(nn.Module):
    """Computes the pixel-wise AP loss.

    Given two images and ground-truth optical flow, computes the AP per pixel
    and averages the per-pixel loss ``1 - AP`` over the pixels selected by
    the sampler's mask.

    ``forward`` delegates sampling to ``sampler(descriptors, reliability,
    aflow)``, which must return ``(scores, gt, msk, qconf)``. Based on the
    original documentation, the inputs are presumably:

        descriptors: pair of (B, C, H, W) pixel-wise feature maps extracted
                     from img1 and img2 -- TODO confirm against the sampler
        aflow: (B, 2, H, W) absolute flow: aflow[..., y1, x1] = x2, y2
    """

    def __init__(self, sampler, nq=20):
        """
        Args:
            sampler: callable returning ``(scores, gt, msk, qconf)`` when
                called with ``(descriptors, reliability, aflow)``.
            nq: forwarded as the first argument of ``APLoss`` (presumably the
                number of quantization bins -- see APLoss).
        """
        nn.Module.__init__(self)
        # min=0 / max=1 presumably bound the expected score range, and
        # euc=False selects APLoss's non-Euclidean mode; see APLoss.
        self.aploss = APLoss(nq, min=0, max=1, euc=False)
        self.name = "pixAP"
        self.sampler = sampler

    def loss_from_ap(self, ap, rel):
        """Return the per-pixel loss ``1 - ap``.

        ``rel`` (the sampler's ``qconf``) is ignored here; it exists so that
        subclasses can weight the loss by a per-pixel confidence.
        """
        return 1 - ap

    def forward(self, descriptors, aflow, **kw):
        """Compute the mean pixel-wise AP loss.

        Args:
            descriptors: passed unchanged as the sampler's first argument.
            aflow: passed unchanged as the sampler's third argument.
            **kw: only ``reliability`` is read (default None); it is passed
                as the sampler's second argument.

        Returns:
            The int ``0`` when the sampler yields no queries. Otherwise a
            scalar tensor holding the mean per-pixel loss over ``msk``. It is
            an exact zero tensor that is still attached to the autograd graph
            when ``msk`` selects no pixel.
        """
        scores, gt, msk, qconf = self.sampler(descriptors, kw.get("reliability"), aflow)

        # Each element of qconf corresponds to one AP query (row of scores/gt).
        n = qconf.numel()
        if n == 0:
            # Nothing sampled: keep the historical integer return value.
            return 0

        scores, gt = scores.view(n, -1), gt.view(n, -1)
        # Reshape the per-query AP to the mask layout so both align pixel-wise.
        ap = self.aploss(scores, gt).view(msk.shape)

        pixel_loss = self.loss_from_ap(ap, qconf)

        selected = pixel_loss[msk]
        if selected.numel() == 0:
            # mean() of an empty tensor is NaN and would poison training;
            # the empty sum is an exact 0 that still supports backward().
            return selected.sum()
        return selected.mean()
|
|
|
|
|
class ReliabilityLoss(PixelAPLoss):
    """Pixel-wise AP loss that also trains a per-pixel reliability score.

    Identical to PixelAPLoss except for the per-pixel loss, which blends the
    AP with a constant ``base`` according to the confidence ``rel``:

        loss = 1 - ap * rel - (1 - rel) * base

    A pixel therefore lowers its loss by reporting low confidence whenever
    its AP is below ``base``, and high confidence when its AP is above it.
    """

    def __init__(self, sampler, base=0.5, **kw):
        """
        Args:
            sampler: forwarded to PixelAPLoss.
            base: AP value credited to the unconfident share ``1 - rel`` of
                each pixel; must satisfy ``0 <= base < 1``.
            **kw: extra keyword arguments for PixelAPLoss (e.g. ``nq``).
        """
        PixelAPLoss.__init__(self, sampler, **kw)
        assert 0 <= base < 1
        self.name = "reliability"
        self.base = base

    def loss_from_ap(self, ap, rel):
        """Return the reliability-weighted per-pixel loss."""
        confident_part = ap * rel
        fallback_part = (1 - rel) * self.base
        return 1 - confident_part - fallback_part
|
|