CHSTR's picture
Upload src
265ae36 verified
raw
history blame
2.54 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from ...models.builder import LOSSES
@LOSSES.register_module()
class GradientLoss(nn.Module):
    """GradientLoss.

    Multi-scale gradient-matching loss on log depth.
    Adapted from https://www.cs.cornell.edu/projects/megadepth/

    At each of four scales (subsampling strides 1, 2, 4 and 8 over the last
    two dimensions) the loss takes ``d = log(pred + eps) - log(gt + eps)``,
    compares values of ``d`` two pixels apart vertically and horizontally,
    and keeps only pairs whose two pixels are both valid. The summed absolute
    differences are divided by the number of valid pixels at that scale (all
    pixels when ``valid_mask`` is False), and the per-scale terms are summed.

    Args:
        valid_mask (bool): Whether to ignore invalid ground truth, i.e. pixels
            where ``gt <= 0``. Default: True.
        loss_weight (float): Weight of the loss. Default: 1.0.
        max_depth (float | None): When filtering invalid ground truth, also
            ignore pixels with ``gt > max_depth``. Default: None.
        loss_name (str): Stored as ``self.loss_name``; not read by this class
            (presumably used by the caller to label the loss).
            Default: "loss_grad".
    """

    def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None,
                 loss_name="loss_grad"):
        super().__init__()
        self.valid_mask = valid_mask
        self.loss_weight = loss_weight
        self.max_depth = max_depth
        self.loss_name = loss_name
        # Offset added inside log() so zero depths give a finite value
        # (original note: "avoid grad explode").
        self.eps = 0.001

    def gradientloss(self, input, target):
        """Return the unweighted multi-scale gradient loss.

        Args:
            input (torch.Tensor): Predicted depth. The last two dimensions are
                treated as (H, W); any leading dimensions (e.g. batch and
                channel) are carried along untouched. For 2-D inputs this is
                the same as indexing dimensions 0 and 1.
            target (torch.Tensor): Ground-truth depth, expected to have the
                same shape as ``input``.

        Returns:
            torch.Tensor: Sum over the four scales of the per-scale loss.
        """
        total = 0
        for scale in range(4):
            # Dyadic pyramid: strides 1, 2, 4, 8.
            stride = 2 ** scale
            pred_s = input[..., ::stride, ::stride]
            gt_s = target[..., ::stride, ::stride]
            total = total + self._single_scale_loss(pred_s, gt_s)
        return total

    def _single_scale_loss(self, pred, gt):
        """Gradient-matching loss at one scale, normalized by the valid-pixel count."""
        if self.valid_mask:
            mask = gt > 0
            if self.max_depth is not None:
                mask = mask & (gt <= self.max_depth)
            # Clamp so a scale with no valid pixel contributes 0 (its
            # numerator is 0 as well) instead of 0 / 0 = NaN.
            num_valid = mask.sum().clamp(min=1)
        else:
            mask = torch.ones_like(gt, dtype=torch.bool)
            num_valid = max(pred.numel(), 1)

        log_diff = torch.log(pred + self.eps) - torch.log(gt + self.eps)
        # masked_fill instead of multiplying by the mask: log() of an invalid
        # (negative) target is NaN, and NaN * 0 would still be NaN.
        log_diff = log_diff.masked_fill(~mask, 0.0)

        # Differences between entries two rows / two columns apart; a pair
        # counts only when both of its pixels are valid.
        v_grad = torch.abs(log_diff[..., :-2, :] - log_diff[..., 2:, :])
        v_grad = v_grad * (mask[..., :-2, :] & mask[..., 2:, :])
        h_grad = torch.abs(log_diff[..., :, :-2] - log_diff[..., :, 2:])
        h_grad = h_grad * (mask[..., :, :-2] & mask[..., :, 2:])

        return (h_grad.sum() + v_grad.sum()) / num_valid

    def forward(self, depth_pred, depth_gt):
        """Forward function.

        Returns ``loss_weight`` times :meth:`gradientloss` of the inputs.
        """
        return self.loss_weight * self.gradientloss(depth_pred, depth_gt)