# LN3Diff_I23D/nsr/losses/sdfstudio_losses.py
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Collection of Losses.
"""
import torch
import torch.nn.functional as F
from torch import nn
from torchtyping import TensorType
from torch.autograd import Variable
import numpy as np
from math import exp
# from nerfstudio.cameras.rays import RaySamples
# from nerfstudio.field_components.field_heads import FieldHeadNames
L1Loss = nn.L1Loss
MSELoss = nn.MSELoss
LOSSES = {"L1": L1Loss, "MSE": MSELoss}
EPS = 1.0e-7
def outer(
t0_starts: TensorType[..., "num_samples_0"],
t0_ends: TensorType[..., "num_samples_0"],
t1_starts: TensorType[..., "num_samples_1"],
t1_ends: TensorType[..., "num_samples_1"],
y1: TensorType[..., "num_samples_1"],
) -> TensorType[..., "num_samples_0"]:
"""Faster version of
https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L117
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L64
    Args:
        t0_starts: start edges of the query intervals
        t0_ends: end edges of the query intervals
        t1_starts: start edges of the reference intervals
        t1_ends: end edges of the reference intervals
        y1: weights attached to the reference intervals
    """
cy1 = torch.cat([torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1)
idx_lo = torch.searchsorted(t1_starts.contiguous(), t0_starts.contiguous(), side="right") - 1
idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1)
idx_hi = torch.searchsorted(t1_ends.contiguous(), t0_ends.contiguous(), side="right")
idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1)
cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1)
cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1)
y0_outer = cy1_hi - cy1_lo
return y0_outer
def lossfun_outer(
t: TensorType[..., "num_samples+1"],
w: TensorType[..., "num_samples"],
t_env: TensorType[..., "num_samples+1"],
w_env: TensorType[..., "num_samples"],
):
"""
https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L136
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L80
Args:
t: interval edges
w: weights
        t_env: interval edges of the upper-bound enveloping histogram
        w_env: weights of the envelope that should upper-bound the inner (t, w) histogram
"""
w_outer = outer(t[..., :-1], t[..., 1:], t_env[..., :-1], t_env[..., 1:], w_env)
return torch.clip(w - w_outer, min=0) ** 2 / (w + EPS)
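
# --- hedged usage sketch (not from the original repo; shapes illustrative) ---
# When the envelope is the histogram itself, the outer bound holds everywhere
# and the penalty vanishes:
#
#   t = torch.linspace(0.0, 1.0, steps=9).expand(2, 9)  # (num_rays, num_samples + 1)
#   w = torch.full((2, 8), 0.125)                       # (num_rays, num_samples)
#   assert lossfun_outer(t, w, t, w).abs().max() < 1e-6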
def ray_samples_to_sdist(ray_samples):
"""Convert ray samples to s space"""
starts = ray_samples.spacing_starts
ends = ray_samples.spacing_ends
sdist = torch.cat([starts[..., 0], ends[..., -1:, 0]], dim=-1) # (num_rays, num_samples + 1)
return sdist
def interlevel_loss(weights_list, ray_samples_list):
"""Calculates the proposal loss in the MipNeRF-360 paper.
https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/model.py#L515
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/train_utils.py#L133
"""
c = ray_samples_to_sdist(ray_samples_list[-1]).detach()
w = weights_list[-1][..., 0].detach()
loss_interlevel = 0.0
for ray_samples, weights in zip(ray_samples_list[:-1], weights_list[:-1]):
sdist = ray_samples_to_sdist(ray_samples)
cp = sdist # (num_rays, num_samples + 1)
wp = weights[..., 0] # (num_rays, num_samples)
loss_interlevel += torch.mean(lossfun_outer(c, w, cp, wp))
return loss_interlevel
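
# --- hedged usage sketch (illustrative; not part of the original repo) ---
# `ray_samples_list` holds one entry per proposal level plus the final level;
# this file only reads `spacing_starts`/`spacing_ends` from each entry, so a
# minimal stand-in is enough to exercise the loss:
#
#   class _FakeSamples:
#       def __init__(self, edges):                        # edges: (num_rays, S + 1)
#           self.spacing_starts = edges[..., :-1, None]   # (num_rays, S, 1)
#           self.spacing_ends = edges[..., 1:, None]      # (num_rays, S, 1)
#
#   edges = torch.linspace(0.0, 1.0, steps=33).expand(4, 33)
#   samples = _FakeSamples(edges)
#   weights = torch.softmax(torch.randn(4, 32), dim=-1)[..., None]  # (4, 32, 1)
#   loss = interlevel_loss([weights, weights], [samples, samples])  # ~0 here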
## zip-NeRF losses
def blur_stepfun(x, y, r):
    """Blur a step function (edges ``x``, per-interval values ``y``) with a box
    filter of half-width ``r``, as in zip-NeRF. Returns the sorted edges of the
    blurred function and its piecewise-linear values at those edges."""
    x_c = torch.cat([x - r, x + r], dim=-1)
    x_r, x_idx = torch.sort(x_c, dim=-1)
    zeros = torch.zeros_like(y[:, :1])
    # slope impulses: the jump in y at each original edge, spread over width 2r
    y_1 = (torch.cat([y, zeros], dim=-1) - torch.cat([zeros, y], dim=-1)) / (2 * r)
    x_idx = x_idx[:, :-1]
    y_2 = torch.cat([y_1, -y_1], dim=-1)[
        torch.arange(x_idx.shape[0]).reshape(-1, 1).expand(x_idx.shape).to(x_idx.device), x_idx
    ]
    # double cumulative sum integrates the slope changes: first to slopes, then to values
    y_r = torch.cumsum((x_r[:, 1:] - x_r[:, :-1]) * torch.cumsum(y_2, dim=-1), dim=-1)
    y_r = torch.cat([zeros, y_r], dim=-1)
    return x_r, y_r
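
# --- hedged usage sketch (illustrative) ---
# Blurring a unit box with half-width 0.05 turns it into a trapezoid with
# ramps of width 2 * r on each side:
#
#   x = torch.tensor([[0.0, 0.4, 0.6, 1.0]])  # edges
#   y = torch.tensor([[0.0, 1.0, 0.0]])       # per-interval values
#   x_r, y_r = blur_stepfun(x, y, r=0.05)     # x_r, y_r: (1, 8)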
def interlevel_loss_zip(weights_list, ray_samples_list):
"""Calculates the proposal loss in the Zip-NeRF paper."""
c = ray_samples_to_sdist(ray_samples_list[-1]).detach()
w = weights_list[-1][..., 0].detach()
# 1. normalize
w_normalize = w / (c[:, 1:] - c[:, :-1])
loss_interlevel = 0.0
    # the blur radii [0.03, 0.003] assume exactly two proposal levels, as in zip-NeRF
    for ray_samples, weights, r in zip(ray_samples_list[:-1], weights_list[:-1], [0.03, 0.003]):
        # 2. step blur with a different r per level
x_r, y_r = blur_stepfun(c, w_normalize, r)
y_r = torch.clip(y_r, min=0)
assert (y_r >= 0.0).all()
# 3. accumulate
y_cum = torch.cumsum((y_r[:, 1:] + y_r[:, :-1]) * 0.5 * (x_r[:, 1:] - x_r[:, :-1]), dim=-1)
y_cum = torch.cat([torch.zeros_like(y_cum[:, :1]), y_cum], dim=-1)
        # 4. loss
sdist = ray_samples_to_sdist(ray_samples)
cp = sdist # (num_rays, num_samples + 1)
wp = weights[..., 0] # (num_rays, num_samples)
# resample
inds = torch.searchsorted(x_r, cp, side="right")
below = torch.clamp(inds - 1, 0, x_r.shape[-1] - 1)
above = torch.clamp(inds, 0, x_r.shape[-1] - 1)
cdf_g0 = torch.gather(x_r, -1, below)
bins_g0 = torch.gather(y_cum, -1, below)
cdf_g1 = torch.gather(x_r, -1, above)
bins_g1 = torch.gather(y_cum, -1, above)
t = torch.clip(torch.nan_to_num((cp - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
bins = bins_g0 + t * (bins_g1 - bins_g0)
w_gt = bins[:, 1:] - bins[:, :-1]
        # TODO: this may be numerically unstable when wp is very small
loss_interlevel += torch.mean(torch.clip(w_gt - wp, min=0) ** 2 / (wp + 1e-5))
return loss_interlevel
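
# note: the stand-in sketch shown after `interlevel_loss` above also works here,
# with the caveat that exactly two proposal levels (plus the final level) must
# be passed so that they pair with the two blur radii [0.03, 0.003]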
# Verified
def lossfun_distortion(t, w):
"""
https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L142
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L266
"""
ut = (t[..., 1:] + t[..., :-1]) / 2
dut = torch.abs(ut[..., :, None] - ut[..., None, :])
loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1)
loss_intra = torch.sum(w**2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3
return loss_inter + loss_intra
def distortion_loss(weights_list, ray_samples_list):
"""From mipnerf360"""
c = ray_samples_to_sdist(ray_samples_list[-1])
w = weights_list[-1][..., 0]
loss = torch.mean(lossfun_distortion(c, w))
return loss
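
# --- hedged usage sketch (illustrative) ---
# Distortion is minimized by weights concentrated in a single narrow interval;
# spread-out weights pay both the inter- and intra-interval terms:
#
#   t = torch.linspace(0.0, 1.0, steps=65).expand(4, 65)   # (num_rays, S + 1)
#   w_spread = torch.full((4, 64), 1.0 / 64)
#   w_peaked = torch.zeros(4, 64); w_peaked[:, 0] = 1.0
#   assert lossfun_distortion(t, w_peaked).mean() < lossfun_distortion(t, w_spread).mean()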
# def nerfstudio_distortion_loss(
# ray_samples: RaySamples,
# densities: TensorType["bs":..., "num_samples", 1] = None,
# weights: TensorType["bs":..., "num_samples", 1] = None,
# ) -> TensorType["bs":..., 1]:
# """Ray based distortion loss proposed in MipNeRF-360. Returns distortion Loss.
# .. math::
# \\mathcal{L}(\\mathbf{s}, \\mathbf{w}) =\\iint\\limits_{-\\infty}^{\\,\\,\\,\\infty}
# \\mathbf{w}_\\mathbf{s}(u)\\mathbf{w}_\\mathbf{s}(v)|u - v|\\,d_{u}\\,d_{v}
# where :math:`\\mathbf{w}_\\mathbf{s}(u)=\\sum_i w_i \\mathbb{1}_{[\\mathbf{s}_i, \\mathbf{s}_{i+1})}(u)`
# is the weight at location :math:`u` between bin locations :math:`s_i` and :math:`s_{i+1}`.
# Args:
# ray_samples: Ray samples to compute loss over
# densities: Predicted sample densities
# weights: Predicted weights from densities and sample locations
# """
# if torch.is_tensor(densities):
# assert not torch.is_tensor(weights), "Cannot use both densities and weights"
# # Compute the weight at each sample location
# weights = ray_samples.get_weights(densities)
# if torch.is_tensor(weights):
# assert not torch.is_tensor(densities), "Cannot use both densities and weights"
# starts = ray_samples.spacing_starts
# ends = ray_samples.spacing_ends
# assert starts is not None and ends is not None, "Ray samples must have spacing starts and ends"
# midpoints = (starts + ends) / 2.0 # (..., num_samples, 1)
# loss = (
# weights * weights[..., None, :, 0] * torch.abs(midpoints - midpoints[..., None, :, 0])
# ) # (..., num_samples, num_samples)
# loss = torch.sum(loss, dim=(-1, -2))[..., None] # (..., num_samples)
# loss = loss + 1 / 3.0 * torch.sum(weights**2 * (ends - starts), dim=-2)
# return loss
def orientation_loss(
weights: TensorType["bs":..., "num_samples", 1],
normals: TensorType["bs":..., "num_samples", 3],
viewdirs: TensorType["bs":..., 3],
):
"""Orientation loss proposed in Ref-NeRF.
    Loss that encourages all visible normals to face toward the camera.
"""
w = weights
n = normals
v = viewdirs
n_dot_v = (n * v[..., None, :]).sum(axis=-1)
return (w[..., 0] * torch.fmin(torch.zeros_like(n_dot_v), n_dot_v) ** 2).sum(dim=-1)
def pred_normal_loss(
weights: TensorType["bs":..., "num_samples", 1],
normals: TensorType["bs":..., "num_samples", 3],
pred_normals: TensorType["bs":..., "num_samples", 3],
):
"""Loss between normals calculated from density and normals from prediction network."""
return (weights[..., 0] * (1.0 - torch.sum(normals * pred_normals, dim=-1))).sum(dim=-1)
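
# --- hedged usage sketch (illustrative) ---
#   w = torch.rand(4, 64, 1)                              # rendering weights
#   n = F.normalize(torch.randn(4, 64, 3), dim=-1)        # per-sample normals
#   v = F.normalize(torch.randn(4, 3), dim=-1)            # ray view directions
#   per_ray = orientation_loss(w, n, v)                   # (4,) per-ray penalty
#   assert pred_normal_loss(w, n, n).abs().max() < 1e-4   # exact match costs nothing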
def monosdf_normal_loss(normal_pred: torch.Tensor, normal_gt: torch.Tensor):
"""normal consistency loss as monosdf
Args:
normal_pred (torch.Tensor): volume rendered normal
normal_gt (torch.Tensor): monocular normal
"""
normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1)
normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1)
l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean()
cos = (1.0 - torch.sum(normal_pred * normal_gt, dim=-1)).mean()
return l1 + cos
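
# --- hedged usage sketch (illustrative) ---
# Inputs need not be pre-normalized; identical directions give zero loss:
#
#   n = torch.randn(1024, 3)
#   assert monosdf_normal_loss(n, n) < 1e-5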
# copy from MiDaS
def compute_scale_and_shift(prediction, target, mask):
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
a_00 = torch.sum(mask * prediction * prediction, (1, 2))
a_01 = torch.sum(mask * prediction, (1, 2))
a_11 = torch.sum(mask, (1, 2))
# right hand side: b = [b_0, b_1]
b_0 = torch.sum(mask * prediction * target, (1, 2))
b_1 = torch.sum(mask * target, (1, 2))
# solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
x_0 = torch.zeros_like(b_0)
x_1 = torch.zeros_like(b_1)
det = a_00 * a_11 - a_01 * a_01
valid = det.nonzero()
x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
return x_0, x_1
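
# --- hedged usage sketch (illustrative) ---
# The closed-form least-squares solve recovers a per-image affine mapping:
#
#   pred = torch.rand(2, 8, 8)
#   target = 2.0 * pred + 3.0
#   mask = torch.ones_like(pred)
#   scale, shift = compute_scale_and_shift(pred, target, mask)
#   # scale ≈ (2.0, 2.0), shift ≈ (3.0, 3.0)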
def reduction_batch_based(image_loss, M):
# average of all valid pixels of the batch
# avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
divisor = torch.sum(M)
if divisor == 0:
return 0
else:
return torch.sum(image_loss) / divisor
def reduction_image_based(image_loss, M):
# mean of average of valid pixels of an image
# avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
valid = M.nonzero()
image_loss[valid] = image_loss[valid] / M[valid]
return torch.mean(image_loss)
def mse_loss(prediction, target, mask, reduction=reduction_batch_based):
M = torch.sum(mask, (1, 2))
res = prediction - target
image_loss = torch.sum(mask * res * res, (1, 2))
return reduction(image_loss, 2 * M)
def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
M = torch.sum(mask, (1, 2))
diff = prediction - target
diff = torch.mul(mask, diff)
grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
grad_x = torch.mul(mask_x, grad_x)
grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
grad_y = torch.mul(mask_y, grad_y)
image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
return reduction(image_loss, M)
class MiDaSMSELoss(nn.Module):
def __init__(self, reduction="batch-based"):
super().__init__()
if reduction == "batch-based":
self.__reduction = reduction_batch_based
else:
self.__reduction = reduction_image_based
def forward(self, prediction, target, mask):
return mse_loss(prediction, target, mask, reduction=self.__reduction)
class GradientLoss(nn.Module):
def __init__(self, scales=4, reduction="batch-based"):
super().__init__()
if reduction == "batch-based":
self.__reduction = reduction_batch_based
else:
self.__reduction = reduction_image_based
self.__scales = scales
def forward(self, prediction, target, mask):
total = 0
for scale in range(self.__scales):
step = pow(2, scale)
total += gradient_loss(
prediction[:, ::step, ::step],
target[:, ::step, ::step],
mask[:, ::step, ::step],
reduction=self.__reduction,
)
return total
class ScaleAndShiftInvariantLoss(nn.Module):
def __init__(self, alpha=0.5, scales=4, reduction="batch-based"):
super().__init__()
self.__data_loss = MiDaSMSELoss(reduction=reduction)
self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
self.__alpha = alpha
self.__prediction_ssi = None
def forward(self, prediction, target, mask):
scale, shift = compute_scale_and_shift(prediction, target, mask)
self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
total = self.__data_loss(self.__prediction_ssi, target, mask)
if self.__alpha > 0:
total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)
return total
def __get_prediction_ssi(self):
return self.__prediction_ssi
prediction_ssi = property(__get_prediction_ssi)
# end copy
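
# --- hedged usage sketch (illustrative) ---
# Typical use with a monocular depth prior that is only defined up to an
# affine transform:
#
#   depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=4)
#   rendered = torch.rand(2, 32, 32)          # rendered depth (B, H, W)
#   mono = 0.5 * rendered + 0.1               # pseudo ground-truth depth
#   mask = torch.ones_like(rendered)
#   loss = depth_loss(rendered, mono, mask)   # ~0: the affine gap is absorbed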
# copy from https://github.com/svip-lab/Indoor-SfMLearner/blob/0d682b7ce292484e5e3e2161fc9fc07e2f5ca8d1/layers.py#L218
class SSIM(nn.Module):
"""Layer to compute the SSIM loss between a pair of images"""
def __init__(self, patch_size):
super(SSIM, self).__init__()
self.mu_x_pool = nn.AvgPool2d(patch_size, 1)
self.mu_y_pool = nn.AvgPool2d(patch_size, 1)
self.sig_x_pool = nn.AvgPool2d(patch_size, 1)
self.sig_y_pool = nn.AvgPool2d(patch_size, 1)
self.sig_xy_pool = nn.AvgPool2d(patch_size, 1)
self.refl = nn.ReflectionPad2d(patch_size // 2)
self.C1 = 0.01**2
self.C2 = 0.03**2
def forward(self, x, y):
x = self.refl(x)
y = self.refl(y)
mu_x = self.mu_x_pool(x)
mu_y = self.mu_y_pool(y)
sigma_x = self.sig_x_pool(x**2) - mu_x**2
sigma_y = self.sig_y_pool(y**2) - mu_y**2
sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
SSIM_d = (mu_x**2 + mu_y**2 + self.C1) * (sigma_x + sigma_y + self.C2)
return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
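
# --- hedged usage sketch (illustrative) ---
#   ssim = SSIM(patch_size=11)
#   a = torch.rand(1, 3, 64, 64)
#   b = torch.rand(1, 3, 64, 64)
#   d_same = ssim(a, a)   # ~0 everywhere (identical images)
#   d_diff = ssim(a, b)   # per-pixel dissimilarity in [0, 1], shape (1, 3, 64, 64)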
# TODO test different losses
class NCC(nn.Module):
"""Layer to compute the normalization cross correlation (NCC) of patches"""
def __init__(self, patch_size: int = 11, min_patch_variance: float = 0.01):
super(NCC, self).__init__()
self.patch_size = patch_size
self.min_patch_variance = min_patch_variance
def forward(self, x, y):
        # TODO: if we use grayscale images, convert right after loading to save computation
        # convert to grayscale
x = torch.mean(x, dim=1)
y = torch.mean(y, dim=1)
x_mean = torch.mean(x, dim=(1, 2), keepdim=True)
y_mean = torch.mean(y, dim=(1, 2), keepdim=True)
x_normalized = x - x_mean
y_normalized = y - y_mean
norm = torch.sum(x_normalized * y_normalized, dim=(1, 2))
var = torch.square(x_normalized).sum(dim=(1, 2)) * torch.square(y_normalized).sum(dim=(1, 2))
denom = torch.sqrt(var + 1e-6)
ncc = norm / (denom + 1e-6)
        # ignore patches with low variance
not_valid = (torch.square(x_normalized).sum(dim=(1, 2)) < self.min_patch_variance) | (
torch.square(y_normalized).sum(dim=(1, 2)) < self.min_patch_variance
)
ncc[not_valid] = 1.0
        score = 1 - ncc.clip(-1.0, 1.0)  # in [0, 2]; smaller is better
return score[:, None, None, None]
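
# --- hedged usage sketch (illustrative) ---
#   ncc = NCC(patch_size=11, min_patch_variance=0.01)
#   p = torch.rand(8, 3, 11, 11)   # (batch, rgb, patch, patch)
#   score = ncc(p, p)              # (8, 1, 1, 1); ~0 for identical, textured patches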
class MultiViewLoss(nn.Module):
"""compute multi-view consistency loss"""
def __init__(self, patch_size: int = 11, topk: int = 4, min_patch_variance: float = 0.01):
super(MultiViewLoss, self).__init__()
self.patch_size = patch_size
self.topk = topk
self.min_patch_variance = min_patch_variance
# TODO make metric configurable
# self.ssim = SSIM(patch_size=patch_size)
# self.ncc = NCC(patch_size=patch_size)
self.ssim = NCC(patch_size=patch_size, min_patch_variance=min_patch_variance)
self.iter = 0
def forward(self, patches: torch.Tensor, valid: torch.Tensor):
"""take the mim
Args:
patches (torch.Tensor): _description_
valid (torch.Tensor): _description_
Returns:
_type_: _description_
"""
num_imgs, num_rays, _, num_channels = patches.shape
if num_rays <= 0:
return torch.tensor(0.0).to(patches.device)
ref_patches = (
patches[:1, ...]
.reshape(1, num_rays, self.patch_size, self.patch_size, num_channels)
.expand(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels)
.reshape(-1, self.patch_size, self.patch_size, num_channels)
.permute(0, 3, 1, 2)
) # [N_src*N_rays, 3, patch_size, patch_size]
src_patches = (
patches[1:, ...]
.reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels)
.reshape(-1, self.patch_size, self.patch_size, num_channels)
.permute(0, 3, 1, 2)
) # [N_src*N_rays, 3, patch_size, patch_size]
# apply same reshape to the valid mask
src_patches_valid = (
valid[1:, ...]
.reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, 1)
.reshape(-1, self.patch_size, self.patch_size, 1)
.permute(0, 3, 1, 2)
) # [N_src*N_rays, 1, patch_size, patch_size]
ssim = self.ssim(ref_patches.detach(), src_patches)
ssim = torch.mean(ssim, dim=(1, 2, 3))
ssim = ssim.reshape(num_imgs - 1, num_rays)
# ignore invalid patch by setting ssim error to very large value
ssim_valid = (
src_patches_valid.reshape(-1, self.patch_size * self.patch_size).all(dim=-1).reshape(num_imgs - 1, num_rays)
)
        # we should mask the error after selecting the top-k values; otherwise we might select far-away patches that happen to fall inside the image
# ssim[torch.logical_not(ssim_valid)] = 1.1 # max ssim_error is 1
min_ssim, idx = torch.topk(ssim, k=self.topk, largest=False, dim=0, sorted=True)
min_ssim_valid = ssim_valid[idx, torch.arange(num_rays)[None].expand_as(idx)]
        # TODO: how should this value be set for better visualization?
        min_ssim[torch.logical_not(min_ssim_valid)] = 0.0  # max ssim error is 1
if False:
# visualization of topK error computations
import cv2
import numpy as np
vis_patch_num = num_rays
K = min(100, vis_patch_num)
image = (
patches[:, :vis_patch_num, :, :]
.reshape(-1, vis_patch_num, self.patch_size, self.patch_size, 3)
.permute(1, 2, 0, 3, 4)
.reshape(vis_patch_num * self.patch_size, -1, 3)
)
src_patches_reshaped = src_patches.reshape(
num_imgs - 1, num_rays, 3, self.patch_size, self.patch_size
).permute(1, 0, 3, 4, 2)
idx = idx.permute(1, 0)
selected_patch = (
src_patches_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx]
.permute(0, 2, 1, 3, 4)
.reshape(num_rays, self.patch_size, self.topk * self.patch_size, 3)[:vis_patch_num]
.reshape(-1, self.topk * self.patch_size, 3)
)
# apply same reshape to the valid mask
src_patches_valid_reshaped = src_patches_valid.reshape(
num_imgs - 1, num_rays, 1, self.patch_size, self.patch_size
).permute(1, 0, 3, 4, 2)
selected_patch_valid = (
src_patches_valid_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx]
.permute(0, 2, 1, 3, 4)
.reshape(num_rays, self.patch_size, self.topk * self.patch_size, 1)[:vis_patch_num]
.reshape(-1, self.topk * self.patch_size, 1)
)
# valid to image
selected_patch_valid = selected_patch_valid.expand_as(selected_patch).float()
# breakpoint()
image = torch.cat([selected_patch_valid, selected_patch, image], dim=1)
# select top rays with highest errors
image = image.reshape(num_rays, self.patch_size, -1, 3)
_, idx2 = torch.topk(
torch.sum(min_ssim, dim=0) / (min_ssim_valid.float().sum(dim=0) + 1e-6),
k=K,
largest=True,
dim=0,
sorted=True,
)
image = image[idx2].reshape(K * self.patch_size, -1, 3)
cv2.imwrite(f"vis/{self.iter}.png", (image.detach().cpu().numpy() * 255).astype(np.uint8)[..., ::-1])
self.iter += 1
if self.iter == 9:
breakpoint()
return torch.sum(min_ssim) / (min_ssim_valid.float().sum() + 1e-6)
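
# --- hedged usage sketch (illustrative) ---
# patches: [num_imgs, num_rays, patch_size * patch_size, channels], where
# index 0 is the reference view and the remaining entries are source views:
#
#   mv = MultiViewLoss(patch_size=11, topk=4)
#   patches = torch.rand(6, 128, 11 * 11, 3)
#   valid = torch.ones(6, 128, 11 * 11, 1, dtype=torch.bool)
#   loss = mv(patches, valid)   # scalar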
# sensor depth loss, adapted from https://github.com/dazinovic/neural-rgbd-surface-reconstruction/blob/main/losses.py
# class SensorDepthLoss(nn.Module):
# """Sensor Depth loss"""
# def __init__(self, truncation: float):
# super(SensorDepthLoss, self).__init__()
# self.truncation = truncation # 0.05 * 0.3 5cm scaled
# def forward(self, batch, outputs):
# """take the mim
# Args:
# batch (Dict): inputs
# outputs (Dict): outputs data from surface model
# Returns:
# l1_loss: l1 loss
# freespace_loss: free space loss
# sdf_loss: sdf loss
# """
# depth_pred = outputs["depth"]
# depth_gt = batch["sensor_depth"].to(depth_pred.device)[..., None]
# valid_gt_mask = depth_gt > 0.0
# l1_loss = torch.sum(valid_gt_mask * torch.abs(depth_gt - depth_pred)) / (valid_gt_mask.sum() + 1e-6)
# # free space loss and sdf loss
# ray_samples = outputs["ray_samples"]
# field_outputs = outputs["field_outputs"]
# pred_sdf = field_outputs[FieldHeadNames.SDF][..., 0]
# directions_norm = outputs["directions_norm"]
# z_vals = ray_samples.frustums.starts[..., 0] / directions_norm
# truncation = self.truncation
# front_mask = valid_gt_mask & (z_vals < (depth_gt - truncation))
# back_mask = valid_gt_mask & (z_vals > (depth_gt + truncation))
# sdf_mask = valid_gt_mask & (~front_mask) & (~back_mask)
# num_fs_samples = front_mask.sum()
# num_sdf_samples = sdf_mask.sum()
# num_samples = num_fs_samples + num_sdf_samples + 1e-6
# fs_weight = 1.0 - num_fs_samples / num_samples
# sdf_weight = 1.0 - num_sdf_samples / num_samples
# free_space_loss = torch.mean((F.relu(truncation - pred_sdf) * front_mask) ** 2) * fs_weight
# sdf_loss = torch.mean(((z_vals + pred_sdf) - depth_gt) ** 2 * sdf_mask) * sdf_weight
# return l1_loss, free_space_loss, sdf_loss
r"""Implements Stochastic Structural SIMilarity(S3IM) algorithm.
It is proposed in the ICCV2023 paper
`S3IM: Stochastic Structural SIMilarity and Its Unreasonable Effectiveness for Neural Fields`.
Arguments:
s3im_kernel_size (int): kernel size in ssim's convolution(default: 4)
s3im_stride (int): stride in ssim's convolution(default: 4)
s3im_repeat_time (int): repeat time in re-shuffle virtual patch(default: 10)
s3im_patch_height (height): height of virtual patch(default: 64)
"""
class S3IM(torch.nn.Module):
def __init__(self, s3im_kernel_size = 4, s3im_stride=4, s3im_repeat_time=10, s3im_patch_height=64, size_average = True):
super(S3IM, self).__init__()
self.s3im_kernel_size = s3im_kernel_size
self.s3im_stride = s3im_stride
self.s3im_repeat_time = s3im_repeat_time
self.s3im_patch_height = s3im_patch_height
self.size_average = size_average
self.channel = 1
self.s3im_kernel = self.create_kernel(s3im_kernel_size, self.channel)
def gaussian(self, s3im_kernel_size, sigma):
gauss = torch.Tensor([exp(-(x - s3im_kernel_size//2)**2/float(2*sigma**2)) for x in range(s3im_kernel_size)])
return gauss/gauss.sum()
def create_kernel(self, s3im_kernel_size, channel):
_1D_window = self.gaussian(s3im_kernel_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
s3im_kernel = Variable(_2D_window.expand(channel, 1, s3im_kernel_size, s3im_kernel_size).contiguous())
return s3im_kernel
def _ssim(self, img1, img2, s3im_kernel, s3im_kernel_size, channel, size_average = True, s3im_stride=None):
mu1 = F.conv2d(img1, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride)
mu2 = F.conv2d(img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu2_sq
sigma12 = F.conv2d(img1*img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
def ssim_loss(self, img1, img2):
"""
img1, img2: torch.Tensor([b,c,h,w])
"""
(_, channel, _, _) = img1.size()
if channel == self.channel and self.s3im_kernel.data.type() == img1.data.type():
s3im_kernel = self.s3im_kernel
else:
s3im_kernel = self.create_kernel(self.s3im_kernel_size, channel)
if img1.is_cuda:
s3im_kernel = s3im_kernel.cuda(img1.get_device())
s3im_kernel = s3im_kernel.type_as(img1)
self.s3im_kernel = s3im_kernel
self.channel = channel
return self._ssim(img1, img2, s3im_kernel, self.s3im_kernel_size, channel, self.size_average, s3im_stride=self.s3im_stride)
def forward(self, src_vec, tar_vec):
loss = 0.0
index_list = []
for i in range(self.s3im_repeat_time):
if i == 0:
tmp_index = torch.arange(len(tar_vec))
index_list.append(tmp_index)
else:
ran_idx = torch.randperm(len(tar_vec))
index_list.append(ran_idx)
res_index = torch.cat(index_list)
tar_all = tar_vec[res_index]
src_all = src_vec[res_index]
tar_patch = tar_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1)
src_patch = src_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1)
loss = (1 - self.ssim_loss(src_patch, tar_patch))
return loss
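
# --- hedged usage sketch (illustrative) ---
# S3IM consumes flat per-ray colors and reshapes them into a "virtual patch";
# s3im_repeat_time * num_rays must be divisible by s3im_patch_height:
#
#   s3im = S3IM()
#   tar = torch.rand(4096, 3)   # ground-truth ray colors
#   src = torch.rand(4096, 3)   # rendered ray colors
#   loss = s3im(src, tar)       # scalar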