Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch | |
import torch.nn as nn | |
from ...models.builder import LOSSES | |
class SigLoss(nn.Module):
    """SigLoss.

    This follows `AdaBins <https://arxiv.org/abs/2011.14141>`_.

    With ``g = log(pred + eps) - log(gt + eps)`` taken over the selected
    pixels, the loss is ``sqrt(var(g) + 0.15 * mean(g) ** 2)``. While the
    optional warm-up stage is active only the ``0.15 * mean(g) ** 2`` term
    is used.

    Args:
        valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True.
        loss_weight (float): Weight of the loss. Default: 1.0.
        max_depth (int | None): When filtering invalid gt, also drop pixels
            whose gt exceeds this threshold. Only used when ``valid_mask``
            is True. Default: None.
        warm_up (bool): A simple warm up stage to help convergence.
            Default: False.
        warm_iter (int): Number of ``sigloss`` calls that use the warm-up
            formulation. Default: 100.
        loss_name (str): Stored on the module as ``loss_name``.
            Default: "sigloss".
    """

    def __init__(
        self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss"
    ):
        super().__init__()
        self.valid_mask = valid_mask
        self.loss_weight = loss_weight
        self.max_depth = max_depth
        self.loss_name = loss_name

        self.eps = 0.001  # avoid grad explode

        # HACK: a hack implementation for warmup sigloss
        self.warm_up = warm_up
        self.warm_iter = warm_iter
        self.warm_up_counter = 0

    def sigloss(self, input, target):
        """Compute the unweighted loss between a prediction and its gt.

        Args:
            input (torch.Tensor): Predicted depth.
            target (torch.Tensor): Ground-truth depth; when ``valid_mask`` is
                enabled it is used to build a boolean mask that indexes both
                tensors.

        Returns:
            torch.Tensor: Scalar loss. It is a zero that stays attached to
            the autograd graph when no pixel is selected.
        """
        if self.valid_mask:
            keep = target > 0
            if self.max_depth is not None:
                keep = torch.logical_and(keep, target <= self.max_depth)
            input = input[keep]
            target = target[keep]

        # The counter advances on every call made during warm-up, including
        # calls that end up with no valid pixel.
        in_warm_up = self.warm_up and self.warm_up_counter < self.warm_iter
        if in_warm_up:
            self.warm_up_counter += 1

        # mean() over an empty selection is NaN and would poison every
        # gradient; return a zero that still depends on the prediction so
        # backward() works and yields zero gradients.
        if input.numel() == 0:
            return input.sum() * 0.0

        g = torch.log(input + self.eps) - torch.log(target + self.eps)
        mean_term = 0.15 * torch.pow(torch.mean(g), 2)
        if in_warm_up:
            return torch.sqrt(mean_term)

        # The default (unbiased) variance of a single sample is NaN; a lone
        # pixel has no spread, so use zero instead.
        spread = torch.var(g) if g.numel() > 1 else g.new_zeros(())
        return torch.sqrt(spread + mean_term)

    def forward(self, depth_pred, depth_gt):
        """Forward function.

        Returns ``loss_weight`` times ``sigloss(depth_pred, depth_gt)``.
        """
        return self.loss_weight * self.sigloss(depth_pred, depth_gt)