Spaces:
Running
Running
File size: 7,144 Bytes
b075789 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from einops.einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from romatch.utils.utils import get_gt_warp
import wandb
import romatch
import math
# This differs slightly from the regular romatch loss because the correspondences here are significantly worse.
# The confidence (certainty) loss is quite tricky here. //Johan
class RobustLosses(nn.Module):
    """Multi-scale matching loss: robust flow regression plus certainty BCE.

    For every scale in the prediction dict, the predicted flow is regressed
    against the ground-truth warp with a generalized Charbonnier-style robust
    loss, and the certainty logits are trained with binary cross-entropy.
    Optionally a coarse global-matcher ("gm") term is added. It is either a
    dual-softmax NLL on a correlation volume or a second regression loss on a
    coarse flow. Every term is logged to wandb at ``romatch.GLOBAL_STEP``.
    """

    def __init__(
        self,
        robust=False,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=None,
        smooth_mask=False,
        depth_interpolation_mode="bilinear",
        mask_depth_loss=False,
        relative_depth_error_threshold=0.05,
        alpha=1.,
        c=1e-3,
        epe_mask_prob_th=None,
        cert_only_on_consistent_depth=False,
    ):
        """Store the loss configuration.

        Args:
            ce_weight: Weight of the certainty BCE term relative to the
                regression term.
            local_dist: Optional dict ``scale -> radius``. At listed scales,
                supervision is restricted to pixels whose EPE is below
                ``(2 / 512) * radius * scale``. Defaults to an empty dict
                (a fresh dict per instance).
            alpha: Exponent of the robust loss. Either a float, or a dict
                ``scale -> float``.
            c: Base width of the robust loss. It is multiplied by ``scale``.
            epe_mask_prob_th: If not None, the certainty target is zeroed
                wherever ``epe >= scale * epe_mask_prob_th``.
            cert_only_on_consistent_depth: If True, the certainty BCE is
                computed only over pixels with GT ``prob > 0``.
            robust, center_coords, scale_normalize, local_loss, smooth_mask,
            depth_interpolation_mode, mask_depth_loss,
            relative_depth_error_threshold: Stored as attributes but not read
                by the methods of this class. They may be read elsewhere;
                verify before removing.
        """
        super().__init__()
        if local_dist is None:
            local_dist = {}  # avoid sharing one mutable default across instances
        self.robust = robust  # measured in pixels
        self.center_coords = center_coords
        self.scale_normalize = scale_normalize
        self.ce_weight = ce_weight
        self.local_loss = local_loss
        self.local_dist = local_dist
        self.smooth_mask = smooth_mask
        self.depth_interpolation_mode = depth_interpolation_mode
        self.mask_depth_loss = mask_depth_loss
        self.relative_depth_error_threshold = relative_depth_error_threshold
        self.avg_overlap = {}
        self.alpha = alpha
        self.c = c
        self.epe_mask_prob_th = epe_mask_prob_th
        self.cert_only_on_consistent_depth = cert_only_on_consistent_depth

    def corr_volume_loss(self, mnn: torch.Tensor, corr_volume: torch.Tensor, scale):
        """Dual-softmax NLL on the correlation volume at GT mutual nearest neighbours.

        Args:
            mnn: ``(N, 3)`` integer indices ``(batch, idx_A, idx_B)`` into the
                flattened volume, as produced by ``torch.nonzero`` in
                ``forward``. N may be 0.
            corr_volume: 5-D correlation volume ``(b, h_A, w_A, h_B, w_B)``.
            scale: Scale key. Used only to name the logged/returned entry.

        Returns:
            Dict with the single key ``gm_corr_volume_loss_{scale}``.
        """
        b, h_a, w_a, h_b, w_b = corr_volume.shape
        inv_temp = 10
        logits = inv_temp * corr_volume.reshape(b, h_a * w_a, h_b * w_b)
        # Normalise over A positions (dim 1) and over B positions (dim 2).
        nll = -logits.log_softmax(dim=1) - logits.log_softmax(dim=2)
        if mnn.shape[0] == 0:
            # No GT matches: the mean over zero elements would be NaN. Return a
            # zero that stays connected to the graph, as regression_loss does.
            corr_volume_loss = corr_volume.sum() * 0.0
        else:
            corr_volume_loss = nll[mnn[:, 0], mnn[:, 1], mnn[:, 2]].mean()
        losses = {
            f"gm_corr_volume_loss_{scale}": corr_volume_loss,
        }
        wandb.log(losses, step=romatch.GLOBAL_STEP)
        return losses

    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"):
        """Robust EPE regression loss plus certainty BCE for one predicted flow.

        Args:
            x2: GT warp. It is compared against ``flow.permute(0, 2, 3, 1)``,
                so it is presumably ``(b, h, w, 2)`` (TODO confirm).
            prob: Per-pixel GT weight with the same shape as the EPE map.
                Pixels with ``prob > 0.99`` are regressed.
            flow: Predicted flow with the coordinate channels in dim 1.
            certainty: Certainty logits. Channel 0 is used.
            scale: Current scale. It selects ``local_dist``/``alpha`` entries
                and scales the thresholds.
            eps: Unused. Kept for interface compatibility.
            mode: Prefix of the returned/logged keys (e.g. "delta" or "gm").

        Returns:
            Dict with ``{mode}_certainty_loss_{scale}`` and
            ``{mode}_regression_loss_{scale}``.
        """
        epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1)
        if scale in self.local_dist:
            # NOTE(review): 2 / 512 looks like "one pixel at 512 px in [-1, 1] coords"; confirm.
            prob = prob * (epe < (2 / 512) * (self.local_dist[scale] * scale)).float()
        if scale == 1:
            pck_05 = (epe[prob > 0.99] < 0.5 * (2 / 512)).float().mean()
            wandb.log({"train_pck_05": pck_05}, step=romatch.GLOBAL_STEP)
        if self.epe_mask_prob_th is not None:
            # If the prediction is too far from the GT, the certainty target should be 0.
            gt_cert = prob * (epe < scale * self.epe_mask_prob_th)
        else:
            gt_cert = prob
        cert_logits = certainty[:, 0]
        if self.cert_only_on_consistent_depth:
            consistent = prob > 0
            if consistent.any():
                ce_loss = F.binary_cross_entropy_with_logits(
                    cert_logits[consistent], gt_cert[consistent]
                )
            else:
                # BCE over an empty selection is NaN; use a graph-connected zero.
                ce_loss = cert_logits.sum() * 0.0
        else:
            ce_loss = F.binary_cross_entropy_with_logits(cert_logits, gt_cert)
        a = self.alpha[scale] if isinstance(self.alpha, dict) else self.alpha
        cs = self.c * scale
        x = epe[prob > 0.99]
        if x.numel() == 0:
            # Prevent issues where prob is 0 everywhere (mean of nothing is NaN).
            reg_loss = ce_loss * 0.0
        else:
            # Generalized Charbonnier: cs^a * ((x / cs)^2 + 1)^(a / 2).
            reg_loss = cs**a * ((x / cs) ** 2 + 1) ** (a / 2)
        losses = {
            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
        }
        wandb.log(losses, step=romatch.GLOBAL_STEP)
        return losses

    @staticmethod
    def _gt_mutual_nearest_neighbours(gt_warp, gt_warp_back, h, w, max_dist=0.01):
        """Return ``(N, 3)`` indices ``(batch, idx_A, idx_B)`` of GT mutual nearest-neighbour pixels."""
        # Pixel-centre grid in [-1, 1]; shape (h, w, 2) with (x, y) in the last dim.
        grid = torch.stack(
            torch.meshgrid(
                torch.linspace(-1 + 1 / w, 1 - 1 / w, w),
                torch.linspace(-1 + 1 / h, 1 - 1 / h, h),
                indexing="xy",
            ),
            dim=-1,
        ).to(gt_warp.device)
        flat_grid = grid.reshape(-1, h * w, 2)
        with torch.no_grad():
            # D_B[b, i, j]: distance from the warp of A-pixel i to B-pixel j.
            D_B = torch.cdist(gt_warp.float().reshape(-1, h * w, 2), flat_grid)
            # D_A[b, i, j]: distance from A-pixel i to the back-warp of B-pixel j.
            D_A = torch.cdist(flat_grid, gt_warp_back.float().reshape(-1, h * w, 2))
            is_mnn = (
                (D_B == D_B.min(dim=-1, keepdim=True).values)
                & (D_A == D_A.min(dim=-2, keepdim=True).values)
                & (D_B < max_dist)
                & (D_A < max_dist)
            )
            return torch.nonzero(is_mnn)

    def forward(self, corresps, batch):
        """Sum the matching losses over all scales.

        Args:
            corresps: Dict ``scale -> predictions``. Each prediction dict must
                contain "certainty" and "flow". It may contain
                "flow_pre_delta", "corr_volume", "gm_certainty" and "gm_flow".
            batch: Dict with "im_A_depth", "im_B_depth", "T_1to2", "K1" and
                "K2", passed to ``get_gt_warp``.

        Returns:
            Total loss: the sum over scales of the gm term (if any) plus the
            delta term. Returns 0.0 when ``corresps`` is empty.
        """
        tot_loss = 0.0
        for scale, scale_corresps in corresps.items():
            scale_certainty = scale_corresps["certainty"]
            flow = scale_corresps["flow"]
            flow_pre_delta = scale_corresps.get("flow_pre_delta")
            scale_gm_corr_volume = scale_corresps.get("corr_volume")
            scale_gm_certainty = scale_corresps.get("gm_certainty")
            scale_gm_flow = scale_corresps.get("gm_flow")

            # Supervise at the spatial size of flow_pre_delta when present,
            # otherwise at that of the certainty map (dims 2 and 3 of a 4-D tensor).
            shape_ref = flow_pre_delta if flow_pre_delta is not None else scale_certainty
            _, _, h, w = shape_ref.shape

            gt_warp, gt_prob = get_gt_warp(
                batch["im_A_depth"],
                batch["im_B_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                H=h,
                W=w,
            )
            x2 = gt_warp.float()
            prob = gt_prob

            if scale_gm_corr_volume is not None:
                gt_warp_back, _ = get_gt_warp(
                    batch["im_B_depth"],
                    batch["im_A_depth"],
                    batch["T_1to2"].inverse(),
                    batch["K2"],
                    batch["K1"],
                    H=h,
                    W=w,
                )
                inds = self._gt_mutual_nearest_neighbours(gt_warp, gt_warp_back, h, w)
                gm_cls_losses = self.corr_volume_loss(inds, scale_gm_corr_volume, scale)
                tot_loss = tot_loss + gm_cls_losses[f"gm_corr_volume_loss_{scale}"]
            elif scale_gm_flow is not None:
                gm_flow_losses = self.regression_loss(
                    x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode="gm"
                )
                gm_loss = (
                    self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"]
                    + gm_flow_losses[f"gm_regression_loss_{scale}"]
                )
                tot_loss = tot_loss + gm_loss

            delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
            reg_loss = (
                self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"]
                + delta_regression_losses[f"delta_regression_loss_{scale}"]
            )
            tot_loss = tot_loss + reg_loss
        return tot_loss
|