|
from einops.einops import rearrange |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from roma.utils.utils import get_gt_warp |
|
import wandb |
|
import roma |
|
import math |
|
|
|
class RobustLosses(nn.Module):
    """Multi-scale loss over predicted warps and certainty logits.

    :meth:`forward` walks over the scales found in ``corresps``. For each
    scale it obtains a ground-truth warp and probability map from
    ``get_gt_warp`` and accumulates a certainty loss (binary cross-entropy
    with logits) plus either a classification loss or a regression loss,
    depending on which predictions are present for that scale.

    Every per-scale loss dict is also logged to ``wandb`` at step
    ``roma.GLOBAL_STEP``.
    """

    def __init__(
        self,
        robust=False,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=4.0,
        local_largest_scale=8,
        smooth_mask=False,
        depth_interpolation_mode="bilinear",
        mask_depth_loss=False,
        relative_depth_error_threshold=0.05,
        alpha=1.,
        c=1e-3,
    ):
        """Store the loss configuration.

        Args:
            ce_weight: Multiplier applied to every certainty loss before it
                is added to the classification/regression loss.
            local_dist: Locality radius used by :meth:`forward`. Either a
                single number applied to every scale, or a subscriptable
                object (e.g. a dict) indexed by scale.
            local_largest_scale: The locality mask is applied to scales that
                are less than or equal to this value.
            alpha: Exponent ``a`` of the regression penalty.
            c: Base width of the regression penalty; multiplied by the scale.

        The remaining arguments (``robust``, ``center_coords``,
        ``scale_normalize``, ``local_loss``, ``smooth_mask``,
        ``depth_interpolation_mode``, ``mask_depth_loss``,
        ``relative_depth_error_threshold``) are stored on the instance but
        are not read by any method of this class.
        """
        super().__init__()
        self.robust = robust
        self.center_coords = center_coords
        self.scale_normalize = scale_normalize
        self.ce_weight = ce_weight
        self.local_loss = local_loss
        self.local_dist = local_dist
        self.local_largest_scale = local_largest_scale
        self.smooth_mask = smooth_mask
        self.depth_interpolation_mode = depth_interpolation_mode
        self.mask_depth_loss = mask_depth_loss
        self.relative_depth_error_threshold = relative_depth_error_threshold
        self.avg_overlap = dict()
        self.alpha = alpha
        self.c = c

    @staticmethod
    def _class_centers(num_classes, device):
        """Return the ``(num_classes, 2)`` grid of class-centre coordinates.

        The classes form a square ``cls_res x cls_res`` grid with
        ``cls_res = round(sqrt(num_classes))``; centres are evenly spaced in
        ``[-1 + 1/cls_res, 1 - 1/cls_res]`` along both axes.
        """
        cls_res = round(math.sqrt(num_classes))
        axis = torch.linspace(
            -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
        )
        # "ij" is the historical default of torch.meshgrid; passing it
        # explicitly keeps the behaviour and avoids the deprecation warning.
        first, second = torch.meshgrid(axis, axis, indexing="ij")
        # The second meshgrid output goes into the first coordinate slot.
        return torch.stack((second, first), dim=-1).reshape(num_classes, 2)

    @staticmethod
    def _nonempty_or_zero(selected, fallback):
        """Return ``selected``, or ``fallback * 0.0`` when nothing was selected.

        Taking the mean of an empty tensor yields NaN; multiplying the
        fallback loss by zero gives a zero-valued tensor instead.
        """
        if selected.numel() == 0:
            return fallback * 0.0
        return selected

    def _local_dist_at(self, scale):
        """Return the locality radius for ``scale``.

        ``self.local_dist`` may be a plain number (used for every scale) or
        any subscriptable object indexed by scale.
        """
        local_dist = self.local_dist
        if hasattr(local_dist, "__getitem__"):
            return local_dist[scale]
        return local_dist

    def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
        """Classification + certainty loss for the ``gm_cls`` prediction.

        The target class of every location is the index of the class centre
        closest to ``x2``. Cross-entropy is kept only where ``prob > 0.99``.

        Args:
            x2: Ground-truth warp (assumed ``(B, H, W, 2)`` -- confirm
                against ``get_gt_warp``).
            prob: Ground-truth probability map, used as the BCE target and
                as the selection mask.
            scale_gm_cls: Class logits of shape ``(B, C, H, W)``.
            gm_certainty: Certainty logits; channel 0 is used.
            scale: Scale identifier, used only in the loss names.

        Returns:
            Dict with ``gm_certainty_loss_{scale}`` and
            ``gm_cls_loss_{scale}`` (both reduced with ``mean``).
        """
        with torch.no_grad():
            num_classes = scale_gm_cls.shape[1]
            G = self._class_centers(num_classes, x2.device)
            GT = (G[None, :, None, None, :] - x2[:, None]).norm(dim=-1).min(dim=1).indices
        # Computed before the fallback below, which depends on it.
        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:, 0], prob)
        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction='none')[prob > 0.99]
        cls_loss = self._nonempty_or_zero(cls_loss, certainty_loss)

        losses = {
            f"gm_certainty_loss_{scale}": certainty_loss.mean(),
            f"gm_cls_loss_{scale}": cls_loss.mean(),
        }
        wandb.log(losses, step=roma.GLOBAL_STEP)
        return losses

    def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
        """Classification + certainty loss for the ``delta_cls`` prediction.

        Class centres are scaled by ``offset_scale`` and added to
        ``flow_pre_delta``; the target class of every location is the
        candidate closest to ``x2``. Cross-entropy is kept only where
        ``prob > 0.99``.

        Args:
            x2: Ground-truth warp (assumed ``(B, H, W, 2)`` -- confirm).
            prob: Ground-truth probability map (BCE target and mask).
            flow_pre_delta: Flow before the delta, channel-last.
            delta_cls: Class logits of shape ``(B, C, H, W)``.
            certainty: Certainty logits; channel 0 is used.
            scale: Scale identifier, used only in the loss names.
            offset_scale: Multiplier applied to the class-centre grid.

        Returns:
            Dict with ``delta_certainty_loss_{scale}`` and
            ``delta_cls_loss_{scale}`` (both reduced with ``mean``).
        """
        with torch.no_grad():
            num_classes = delta_cls.shape[1]
            G = self._class_centers(num_classes, x2.device) * offset_scale
            GT = (
                G[None, :, None, None, :] + flow_pre_delta[:, None] - x2[:, None]
            ).norm(dim=-1).min(dim=1).indices
        # Computed before the fallback below, which depends on it.
        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
        cls_loss = F.cross_entropy(delta_cls, GT, reduction='none')[prob > 0.99]
        cls_loss = self._nonempty_or_zero(cls_loss, certainty_loss)

        losses = {
            f"delta_certainty_loss_{scale}": certainty_loss.mean(),
            f"delta_cls_loss_{scale}": cls_loss.mean(),
        }
        wandb.log(losses, step=roma.GLOBAL_STEP)
        return losses

    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"):
        """Regression + certainty loss for a flow prediction.

        The end-point error ``x`` between ``flow`` and ``x2`` is penalised,
        where ``prob > 0.99``, with ``cs**a * ((x / cs)**2 + 1)**(a / 2)``
        using ``a = self.alpha`` and ``cs = self.c * scale``.

        When ``scale == 1`` the fraction of selected locations whose error
        is below ``0.5 * (2 / 512)`` is logged as ``train_pck_05``.

        Args:
            x2: Ground-truth warp, channel-last.
            prob: Ground-truth probability map (BCE target and mask).
            flow: Predicted flow, channel-first (permuted to channel-last).
            certainty: Certainty logits; channel 0 is used.
            scale: Scale identifier; also scales the penalty width.
            eps: Unused; kept for interface compatibility.
            mode: Prefix used in the loss names.

        Returns:
            Dict with ``{mode}_certainty_loss_{scale}`` and
            ``{mode}_regression_loss_{scale}`` (both reduced with ``mean``).
        """
        epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1)
        if scale == 1:
            pck_05 = (epe[prob > 0.99] < 0.5 * (2 / 512)).float().mean()
            wandb.log({"train_pck_05": pck_05}, step=roma.GLOBAL_STEP)

        ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
        a = self.alpha
        cs = self.c * scale
        x = epe[prob > 0.99]
        reg_loss = cs**a * ((x / cs)**2 + 1**2)**(a / 2)
        reg_loss = self._nonempty_or_zero(reg_loss, ce_loss)

        losses = {
            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
        }
        wandb.log(losses, step=roma.GLOBAL_STEP)
        return losses

    def forward(self, corresps, batch):
        """Accumulate the loss over every scale in ``corresps``.

        Args:
            corresps: Mapping from scale to a dict of predictions. The keys
                ``"certainty"``, ``"flow_pre_delta"`` and ``"flow"`` are
                required; ``"delta_cls"``, ``"offset_scale"``, ``"gm_cls"``,
                ``"gm_certainty"`` and ``"gm_flow"`` are optional. Scales
                are processed in the mapping's iteration order.
            batch: Mapping providing ``"im_A_depth"``, ``"im_B_depth"``,
                ``"T_1to2"``, ``"K1"`` and ``"K2"`` for ``get_gt_warp``.

        Returns:
            The summed loss over all scales.
        """
        scales = list(corresps.keys())
        tot_loss = 0.0
        scale_weights = {1: 1, 2: 1, 4: 1, 8: 1, 16: 1}
        # End-point error of the previously processed scale; None until the
        # first scale has been handled.
        prev_epe = None

        for scale in scales:
            scale_corresps = corresps[scale]
            scale_certainty = scale_corresps["certainty"]
            flow_pre_delta = scale_corresps["flow_pre_delta"]
            delta_cls = scale_corresps.get("delta_cls")
            offset_scale = scale_corresps.get("offset_scale")
            scale_gm_cls = scale_corresps.get("gm_cls")
            scale_gm_certainty = scale_corresps.get("gm_certainty")
            flow = scale_corresps["flow"]
            scale_gm_flow = scale_corresps.get("gm_flow")

            flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
            _, h, w, _ = flow_pre_delta.shape
            gt_warp, gt_prob = get_gt_warp(
                batch["im_A_depth"],
                batch["im_B_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                H=h,
                W=w,
            )
            x2 = gt_warp.float()
            prob = gt_prob

            # Restrict supervision to locations where the previous scale's
            # prediction was already close to the ground truth. Skipped for
            # the first processed scale, which has no previous error map.
            if self.local_largest_scale >= scale and prev_epe is not None:
                resized_epe = F.interpolate(
                    prev_epe[:, None], size=(h, w), mode="nearest-exact"
                )[:, 0]
                prob = prob * (
                    resized_epe < (2 / 512) * (self._local_dist_at(scale) * scale)
                )

            if scale_gm_cls is not None:
                gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
                gm_loss = (
                    self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"]
                    + gm_cls_losses[f"gm_cls_loss_{scale}"]
                )
                tot_loss = tot_loss + scale_weights[scale] * gm_loss
            elif scale_gm_flow is not None:
                gm_flow_losses = self.regression_loss(
                    x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode="gm"
                )
                gm_loss = (
                    self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"]
                    + gm_flow_losses[f"gm_regression_loss_{scale}"]
                )
                tot_loss = tot_loss + scale_weights[scale] * gm_loss

            if delta_cls is not None:
                delta_cls_losses = self.delta_cls_loss(
                    x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale
                )
                delta_cls_loss = (
                    self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"]
                    + delta_cls_losses[f"delta_cls_loss_{scale}"]
                )
                tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
            else:
                delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
                reg_loss = (
                    self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"]
                    + delta_regression_losses[f"delta_regression_loss_{scale}"]
                )
                tot_loss = tot_loss + scale_weights[scale] * reg_loss

            # Detached so the locality mask of the next scale carries no gradient.
            prev_epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1).detach()
        return tot_loss
|
|