# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Collection of losses.
"""
from math import exp

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchtyping import TensorType

# from nerfstudio.cameras.rays import RaySamples
# from nerfstudio.field_components.field_heads import FieldHeadNames

L1Loss = nn.L1Loss
MSELoss = nn.MSELoss

LOSSES = {"L1": L1Loss, "MSE": MSELoss}

EPS = 1.0e-7


def outer(
    t0_starts: TensorType[..., "num_samples_0"],
    t0_ends: TensorType[..., "num_samples_0"],
    t1_starts: TensorType[..., "num_samples_1"],
    t1_ends: TensorType[..., "num_samples_1"],
    y1: TensorType[..., "num_samples_1"],
) -> TensorType[..., "num_samples_0"]:
    """Faster version of
    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L117
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L64

    Args:
        t0_starts: start of the query interval edges
        t0_ends: end of the query interval edges
        t1_starts: start of the envelope interval edges
        t1_ends: end of the envelope interval edges
        y1: weights of the envelope histogram
    """
    cy1 = torch.cat([torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1)

    idx_lo = torch.searchsorted(t1_starts.contiguous(), t0_starts.contiguous(), side="right") - 1
    idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1)
    idx_hi = torch.searchsorted(t1_ends.contiguous(), t0_ends.contiguous(), side="right")
    idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1)
    cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1)
    cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1)
    y0_outer = cy1_hi - cy1_lo

    return y0_outer
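

# Example (a minimal sketch with toy values, not from the original sources):
# `outer` returns, per query interval, the summed weight of every envelope bin
# that could overlap it, so the result upper-bounds the true bin weights.
def _example_outer():
    t = torch.tensor([[0.0, 0.5, 1.0]])  # shared interval edges
    y1 = torch.tensor([[0.2, 0.8]])  # envelope weights
    bound = outer(t[..., :-1], t[..., 1:], t[..., :-1], t[..., 1:], y1)
    # bound == [[1.0, 0.8]] >= y1 elementwise (each bin also absorbs a neighbour)
    return bound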


def lossfun_outer(
    t: TensorType[..., "num_samples+1"],
    w: TensorType[..., "num_samples"],
    t_env: TensorType[..., "num_samples+1"],
    w_env: TensorType[..., "num_samples"],
):
    """
    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L136
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L80

    Args:
        t: interval edges
        w: weights
        t_env: interval edges of the upper-bound enveloping histogram
        w_env: weights that should upper bound the inner (t, w) histogram
    """
    w_outer = outer(t[..., :-1], t[..., 1:], t_env[..., :-1], t_env[..., 1:], w_env)
    return torch.clip(w - w_outer, min=0) ** 2 / (w + EPS)
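

# Example (a minimal sketch, toy values): a proposal histogram identical to the
# fine histogram upper-bounds it everywhere, so the loss is exactly zero.
def _example_lossfun_outer():
    t = torch.tensor([[0.0, 0.25, 0.5, 1.0]])
    w = torch.tensor([[0.1, 0.6, 0.3]])
    return lossfun_outer(t, w, t, w)  # all zeros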


def ray_samples_to_sdist(ray_samples):
    """Convert ray samples to s space"""
    starts = ray_samples.spacing_starts
    ends = ray_samples.spacing_ends
    sdist = torch.cat([starts[..., 0], ends[..., -1:, 0]], dim=-1)  # (num_rays, num_samples + 1)
    return sdist


def interlevel_loss(weights_list, ray_samples_list):
    """Calculates the proposal loss in the MipNeRF-360 paper.

    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/model.py#L515
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/train_utils.py#L133
    """
    c = ray_samples_to_sdist(ray_samples_list[-1]).detach()
    w = weights_list[-1][..., 0].detach()
    loss_interlevel = 0.0
    for ray_samples, weights in zip(ray_samples_list[:-1], weights_list[:-1]):
        sdist = ray_samples_to_sdist(ray_samples)
        cp = sdist  # (num_rays, num_samples + 1)
        wp = weights[..., 0]  # (num_rays, num_samples)
        loss_interlevel += torch.mean(lossfun_outer(c, w, cp, wp))
    return loss_interlevel
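

# Example (a minimal sketch): the loss consumes nerfstudio RaySamples; here a
# `SimpleNamespace` stands in, carrying only the two fields that are read.
def _example_interlevel_loss():
    from types import SimpleNamespace

    def make_samples(t):  # t: (num_rays, num_samples + 1)
        return SimpleNamespace(spacing_starts=t[..., :-1, None], spacing_ends=t[..., 1:, None])

    t_fine = torch.tensor([[0.0, 0.25, 0.5, 1.0]])
    w_fine = torch.tensor([[0.1, 0.6, 0.3]])[..., None]
    t_prop = torch.tensor([[0.0, 0.5, 1.0]])
    w_prop = torch.tensor([[0.7, 0.9]])[..., None]  # coarse envelope of the fine weights
    # zero here, because the proposal histogram upper-bounds the fine one
    return interlevel_loss([w_prop, w_fine], [make_samples(t_prop), make_samples(t_fine)])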


## zip-NeRF losses
def blur_stepfun(x, y, r):
    """Box-blur a step function: convolve the histogram (x, y) with a box filter of
    half-width r, returning the knots x_r and values y_r of the resulting
    piecewise-linear density."""
    x_c = torch.cat([x - r, x + r], dim=-1)
    x_r, x_idx = torch.sort(x_c, dim=-1)
    zeros = torch.zeros_like(y[:, :1])
    # slopes introduced at each original edge by the box filter
    y_1 = (torch.cat([y, zeros], dim=-1) - torch.cat([zeros, y], dim=-1)) / (2 * r)
    x_idx = x_idx[:, :-1]
    y_2 = torch.cat([y_1, -y_1], dim=-1)[
        torch.arange(x_idx.shape[0]).reshape(-1, 1).expand(x_idx.shape).to(x_idx.device), x_idx
    ]
    # accumulate slope changes across knots, then integrate to get knot values
    y_r = torch.cumsum((x_r[:, 1:] - x_r[:, :-1]) * torch.cumsum(y_2, dim=-1), dim=-1)
    y_r = torch.cat([zeros, y_r], dim=-1)
    return x_r, y_r
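

# Example (a minimal sketch, toy values): blurring a unit-mass step function
# yields a trapezoid whose trapezoidal integral is still 1.
def _example_blur_stepfun():
    x = torch.tensor([[0.0, 1.0]])
    y = torch.tensor([[1.0]])  # density 1 on [0, 1]
    x_r, y_r = blur_stepfun(x, y, 0.1)
    # x_r = [-0.1, 0.1, 0.9, 1.1], y_r = [0.0, 1.0, 1.0, 0.0]
    area = torch.sum((y_r[:, 1:] + y_r[:, :-1]) * 0.5 * (x_r[:, 1:] - x_r[:, :-1]))
    return area  # ~1.0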


def interlevel_loss_zip(weights_list, ray_samples_list):
    """Calculates the proposal loss in the Zip-NeRF paper."""
    c = ray_samples_to_sdist(ray_samples_list[-1]).detach()
    w = weights_list[-1][..., 0].detach()

    # 1. normalize weights into a density over s-space
    w_normalize = w / (c[:, 1:] - c[:, :-1])

    loss_interlevel = 0.0
    for ray_samples, weights, r in zip(ray_samples_list[:-1], weights_list[:-1], [0.03, 0.003]):
        # 2. blur the step function, with a different radius r per proposal level
        x_r, y_r = blur_stepfun(c, w_normalize, r)
        y_r = torch.clip(y_r, min=0)
        assert (y_r >= 0.0).all()

        # 3. accumulate into a CDF by trapezoidal integration
        y_cum = torch.cumsum((y_r[:, 1:] + y_r[:, :-1]) * 0.5 * (x_r[:, 1:] - x_r[:, :-1]), dim=-1)
        y_cum = torch.cat([torch.zeros_like(y_cum[:, :1]), y_cum], dim=-1)

        # 4. loss
        sdist = ray_samples_to_sdist(ray_samples)
        cp = sdist  # (num_rays, num_samples + 1)
        wp = weights[..., 0]  # (num_rays, num_samples)

        # resample the blurred CDF at the proposal's interval edges
        inds = torch.searchsorted(x_r, cp, side="right")
        below = torch.clamp(inds - 1, 0, x_r.shape[-1] - 1)
        above = torch.clamp(inds, 0, x_r.shape[-1] - 1)
        cdf_g0 = torch.gather(x_r, -1, below)
        bins_g0 = torch.gather(y_cum, -1, below)
        cdf_g1 = torch.gather(x_r, -1, above)
        bins_g1 = torch.gather(y_cum, -1, above)
        t = torch.clip(torch.nan_to_num((cp - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
        bins = bins_g0 + t * (bins_g1 - bins_g0)

        w_gt = bins[:, 1:] - bins[:, :-1]

        # TODO here might be unstable when wp is very small
        loss_interlevel += torch.mean(torch.clip(w_gt - wp, min=0) ** 2 / (wp + 1e-5))

    return loss_interlevel


# Verified
def lossfun_distortion(t, w):
    """
    https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L142
    https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L266
    """
    ut = (t[..., 1:] + t[..., :-1]) / 2
    dut = torch.abs(ut[..., :, None] - ut[..., None, :])
    loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1)
    loss_intra = torch.sum(w**2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3
    return loss_inter + loss_intra
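

# Example (a minimal sketch, toy values): distortion rewards compact weight
# distributions, so concentrated weights score lower than spread-out ones.
def _example_lossfun_distortion():
    t = torch.tensor([[0.0, 0.25, 0.5, 0.75, 1.0]])
    w_peaky = torch.tensor([[0.0, 1.0, 0.0, 0.0]])
    w_flat = torch.tensor([[0.25, 0.25, 0.25, 0.25]])
    # ~0.083 vs ~0.333
    return lossfun_distortion(t, w_peaky), lossfun_distortion(t, w_flat)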


def distortion_loss(weights_list, ray_samples_list):
    """From mipnerf360"""
    c = ray_samples_to_sdist(ray_samples_list[-1])
    w = weights_list[-1][..., 0]
    loss = torch.mean(lossfun_distortion(c, w))
    return loss


# def nerfstudio_distortion_loss(
#     ray_samples: RaySamples,
#     densities: TensorType["bs":..., "num_samples", 1] = None,
#     weights: TensorType["bs":..., "num_samples", 1] = None,
# ) -> TensorType["bs":..., 1]:
#     """Ray based distortion loss proposed in MipNeRF-360. Returns distortion Loss.
#     .. math::
#         \\mathcal{L}(\\mathbf{s}, \\mathbf{w}) =\\iint\\limits_{-\\infty}^{\\,\\,\\,\\infty}
#         \\mathbf{w}_\\mathbf{s}(u)\\mathbf{w}_\\mathbf{s}(v)|u - v|\\,d_{u}\\,d_{v}
#     where :math:`\\mathbf{w}_\\mathbf{s}(u)=\\sum_i w_i \\mathbb{1}_{[\\mathbf{s}_i, \\mathbf{s}_{i+1})}(u)`
#     is the weight at location :math:`u` between bin locations :math:`s_i` and :math:`s_{i+1}`.
#     Args:
#         ray_samples: Ray samples to compute loss over
#         densities: Predicted sample densities
#         weights: Predicted weights from densities and sample locations
#     """
#     if torch.is_tensor(densities):
#         assert not torch.is_tensor(weights), "Cannot use both densities and weights"
#         # Compute the weight at each sample location
#         weights = ray_samples.get_weights(densities)
#     if torch.is_tensor(weights):
#         assert not torch.is_tensor(densities), "Cannot use both densities and weights"
#     starts = ray_samples.spacing_starts
#     ends = ray_samples.spacing_ends
#     assert starts is not None and ends is not None, "Ray samples must have spacing starts and ends"
#     midpoints = (starts + ends) / 2.0  # (..., num_samples, 1)
#     loss = (
#         weights * weights[..., None, :, 0] * torch.abs(midpoints - midpoints[..., None, :, 0])
#     )  # (..., num_samples, num_samples)
#     loss = torch.sum(loss, dim=(-1, -2))[..., None]  # (..., num_samples)
#     loss = loss + 1 / 3.0 * torch.sum(weights**2 * (ends - starts), dim=-2)
#     return loss


def orientation_loss(
    weights: TensorType["bs":..., "num_samples", 1],
    normals: TensorType["bs":..., "num_samples", 3],
    viewdirs: TensorType["bs":..., 3],
):
    """Orientation loss proposed in Ref-NeRF.

    Encourages all visible normals to face toward the camera.
    """
    w = weights
    n = normals
    v = viewdirs
    n_dot_v = (n * v[..., None, :]).sum(dim=-1)
    return (w[..., 0] * torch.fmin(torch.zeros_like(n_dot_v), n_dot_v) ** 2).sum(dim=-1)
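

# Example (a minimal sketch; the sign convention assumed here follows the
# multinerf reference, where `viewdirs` point from the surface toward the
# camera): camera-facing normals incur zero loss, back-facing ones are penalized.
def _example_orientation_loss():
    weights = torch.full((1, 4, 1), 0.25)
    viewdirs = torch.tensor([[0.0, 0.0, 1.0]])
    facing = viewdirs[:, None, :].expand(1, 4, 3)
    return (
        orientation_loss(weights, facing, viewdirs),  # 0.0
        orientation_loss(weights, -facing, viewdirs),  # 1.0
    )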


def pred_normal_loss(
    weights: TensorType["bs":..., "num_samples", 1],
    normals: TensorType["bs":..., "num_samples", 3],
    pred_normals: TensorType["bs":..., "num_samples", 3],
):
    """Loss between normals calculated from density and normals from prediction network."""
    return (weights[..., 0] * (1.0 - torch.sum(normals * pred_normals, dim=-1))).sum(dim=-1)


def monosdf_normal_loss(normal_pred: torch.Tensor, normal_gt: torch.Tensor):
    """Normal consistency loss as in MonoSDF.

    Args:
        normal_pred (torch.Tensor): volume-rendered normal
        normal_gt (torch.Tensor): monocular normal
    """
    normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1)
    normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1)
    l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean()
    cos = (1.0 - torch.sum(normal_pred * normal_gt, dim=-1)).mean()
    return l1 + cos
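

# Example (a minimal sketch): identical directions give zero L1 and zero cosine
# terms; inputs need not be pre-normalized since the loss normalizes them.
def _example_monosdf_normal_loss():
    normal = torch.randn(128, 3)
    return monosdf_normal_loss(normal, 2.0 * normal)  # ~0.0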


# copy from MiDaS
def compute_scale_and_shift(prediction, target, mask):
    # system matrix: A = [[a_00, a_01], [a_10, a_11]]
    a_00 = torch.sum(mask * prediction * prediction, (1, 2))
    a_01 = torch.sum(mask * prediction, (1, 2))
    a_11 = torch.sum(mask, (1, 2))

    # right hand side: b = [b_0, b_1]
    b_0 = torch.sum(mask * prediction * target, (1, 2))
    b_1 = torch.sum(mask * target, (1, 2))

    # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
    x_0 = torch.zeros_like(b_0)
    x_1 = torch.zeros_like(b_1)

    det = a_00 * a_11 - a_01 * a_01
    valid = det.nonzero()

    x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
    x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]

    return x_0, x_1
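

# Example (a minimal sketch): the closed-form least-squares solve recovers a
# known per-image scale and shift exactly.
def _example_compute_scale_and_shift():
    prediction = torch.rand(2, 8, 8)
    target = 3.0 * prediction + 0.5
    mask = torch.ones_like(prediction)
    scale, shift = compute_scale_and_shift(prediction, target, mask)
    return scale, shift  # ~[3.0, 3.0], ~[0.5, 0.5]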


def reduction_batch_based(image_loss, M):
    # average of all valid pixels of the batch
    # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
    divisor = torch.sum(M)
    if divisor == 0:
        return 0
    else:
        return torch.sum(image_loss) / divisor


def reduction_image_based(image_loss, M):
    # mean of average of valid pixels of an image
    # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
    valid = M.nonzero()
    image_loss[valid] = image_loss[valid] / M[valid]
    return torch.mean(image_loss)


def mse_loss(prediction, target, mask, reduction=reduction_batch_based):
    M = torch.sum(mask, (1, 2))
    res = prediction - target
    image_loss = torch.sum(mask * res * res, (1, 2))
    return reduction(image_loss, 2 * M)


def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
    M = torch.sum(mask, (1, 2))

    diff = prediction - target
    diff = torch.mul(mask, diff)

    grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
    grad_x = torch.mul(mask_x, grad_x)

    grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
    grad_y = torch.mul(mask_y, grad_y)

    image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))

    return reduction(image_loss, M)


class MiDaSMSELoss(nn.Module):
    def __init__(self, reduction="batch-based"):
        super().__init__()

        if reduction == "batch-based":
            self.__reduction = reduction_batch_based
        else:
            self.__reduction = reduction_image_based

    def forward(self, prediction, target, mask):
        return mse_loss(prediction, target, mask, reduction=self.__reduction)


class GradientLoss(nn.Module):
    def __init__(self, scales=4, reduction="batch-based"):
        super().__init__()

        if reduction == "batch-based":
            self.__reduction = reduction_batch_based
        else:
            self.__reduction = reduction_image_based

        self.__scales = scales

    def forward(self, prediction, target, mask):
        total = 0

        for scale in range(self.__scales):
            step = pow(2, scale)
            total += gradient_loss(
                prediction[:, ::step, ::step],
                target[:, ::step, ::step],
                mask[:, ::step, ::step],
                reduction=self.__reduction,
            )

        return total


class ScaleAndShiftInvariantLoss(nn.Module):
    def __init__(self, alpha=0.5, scales=4, reduction="batch-based"):
        super().__init__()

        self.__data_loss = MiDaSMSELoss(reduction=reduction)
        self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
        self.__alpha = alpha

        self.__prediction_ssi = None

    def forward(self, prediction, target, mask):
        scale, shift = compute_scale_and_shift(prediction, target, mask)
        self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)

        total = self.__data_loss(self.__prediction_ssi, target, mask)
        if self.__alpha > 0:
            total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)

        return total

    def __get_prediction_ssi(self):
        return self.__prediction_ssi

    prediction_ssi = property(__get_prediction_ssi)


# end copy
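

# Example (a minimal sketch): the loss is invariant to an affine remapping of
# the prediction, so a depth map matching the target up to scale and shift
# scores (numerically) zero.
def _example_scale_and_shift_invariant_loss():
    loss_fn = ScaleAndShiftInvariantLoss(alpha=0.5, scales=4)
    prediction = torch.rand(1, 32, 32)
    target = 2.0 * prediction + 1.0
    mask = torch.ones_like(prediction)
    return loss_fn(prediction, target, mask)  # ~0.0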


# copy from https://github.com/svip-lab/Indoor-SfMLearner/blob/0d682b7ce292484e5e3e2161fc9fc07e2f5ca8d1/layers.py#L218
class SSIM(nn.Module):
    """Layer to compute the SSIM loss between a pair of images"""

    def __init__(self, patch_size):
        super(SSIM, self).__init__()
        self.mu_x_pool = nn.AvgPool2d(patch_size, 1)
        self.mu_y_pool = nn.AvgPool2d(patch_size, 1)
        self.sig_x_pool = nn.AvgPool2d(patch_size, 1)
        self.sig_y_pool = nn.AvgPool2d(patch_size, 1)
        self.sig_xy_pool = nn.AvgPool2d(patch_size, 1)

        self.refl = nn.ReflectionPad2d(patch_size // 2)

        self.C1 = 0.01**2
        self.C2 = 0.03**2

    def forward(self, x, y):
        x = self.refl(x)
        y = self.refl(y)

        mu_x = self.mu_x_pool(x)
        mu_y = self.mu_y_pool(y)

        sigma_x = self.sig_x_pool(x**2) - mu_x**2
        sigma_y = self.sig_y_pool(y**2) - mu_y**2
        sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y

        SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
        SSIM_d = (mu_x**2 + mu_y**2 + self.C1) * (sigma_x + sigma_y + self.C2)

        return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
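

# Example (a minimal sketch): identical inputs give a zero dissimilarity map;
# the output has the same spatial size as the (NCHW) inputs.
def _example_ssim():
    ssim = SSIM(patch_size=3)
    x = torch.rand(1, 3, 16, 16)
    return ssim(x, x)  # (1, 3, 16, 16) tensor of ~0.0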


# TODO test different losses
class NCC(nn.Module):
    """Layer to compute the normalized cross-correlation (NCC) of patches"""

    def __init__(self, patch_size: int = 11, min_patch_variance: float = 0.01):
        super(NCC, self).__init__()
        self.patch_size = patch_size
        self.min_patch_variance = min_patch_variance

    def forward(self, x, y):
        # TODO if we use gray images we should convert right after loading to save computation
        # convert to gray images
        x = torch.mean(x, dim=1)
        y = torch.mean(y, dim=1)

        x_mean = torch.mean(x, dim=(1, 2), keepdim=True)
        y_mean = torch.mean(y, dim=(1, 2), keepdim=True)

        x_normalized = x - x_mean
        y_normalized = y - y_mean

        norm = torch.sum(x_normalized * y_normalized, dim=(1, 2))
        var = torch.square(x_normalized).sum(dim=(1, 2)) * torch.square(y_normalized).sum(dim=(1, 2))
        denom = torch.sqrt(var + 1e-6)

        ncc = norm / (denom + 1e-6)

        # ignore patches with low variance
        not_valid = (torch.square(x_normalized).sum(dim=(1, 2)) < self.min_patch_variance) | (
            torch.square(y_normalized).sum(dim=(1, 2)) < self.min_patch_variance
        )
        ncc[not_valid] = 1.0

        score = 1 - ncc.clip(-1.0, 1.0)  # in [0, 2]; smaller is better
        return score[:, None, None, None]
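

# Example (a minimal sketch): perfectly correlated patches score ~0, inverted
# patches score ~2.
def _example_ncc():
    ncc = NCC(patch_size=11)
    x = torch.rand(4, 3, 11, 11)  # (num_patches, channels, patch_size, patch_size)
    return ncc(x, x), ncc(x, -x)  # ~0.0 each, ~2.0 each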


class MultiViewLoss(nn.Module):
    """compute multi-view consistency loss"""

    def __init__(self, patch_size: int = 11, topk: int = 4, min_patch_variance: float = 0.01):
        super(MultiViewLoss, self).__init__()
        self.patch_size = patch_size
        self.topk = topk
        self.min_patch_variance = min_patch_variance
        # TODO make metric configurable
        # self.ssim = SSIM(patch_size=patch_size)
        # self.ncc = NCC(patch_size=patch_size)
        self.ssim = NCC(patch_size=patch_size, min_patch_variance=min_patch_variance)

        self.iter = 0

    def forward(self, patches: torch.Tensor, valid: torch.Tensor):
        """Keep the topk smallest patch errors across source views.

        Args:
            patches (torch.Tensor): (num_imgs, num_rays, patch_size * patch_size, channels)
                sampled patches; the first image is the reference view
            valid (torch.Tensor): (num_imgs, num_rays, patch_size * patch_size, 1)
                mask of patch pixels that project inside the source images

        Returns:
            torch.Tensor: scalar multi-view consistency loss
        """
        num_imgs, num_rays, _, num_channels = patches.shape

        if num_rays <= 0:
            return torch.tensor(0.0).to(patches.device)

        ref_patches = (
            patches[:1, ...]
            .reshape(1, num_rays, self.patch_size, self.patch_size, num_channels)
            .expand(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels)
            .reshape(-1, self.patch_size, self.patch_size, num_channels)
            .permute(0, 3, 1, 2)
        )  # [N_src*N_rays, 3, patch_size, patch_size]
        src_patches = (
            patches[1:, ...]
            .reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels)
            .reshape(-1, self.patch_size, self.patch_size, num_channels)
            .permute(0, 3, 1, 2)
        )  # [N_src*N_rays, 3, patch_size, patch_size]

        # apply the same reshape to the valid mask
        src_patches_valid = (
            valid[1:, ...]
            .reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, 1)
            .reshape(-1, self.patch_size, self.patch_size, 1)
            .permute(0, 3, 1, 2)
        )  # [N_src*N_rays, 1, patch_size, patch_size]

        ssim = self.ssim(ref_patches.detach(), src_patches)
        ssim = torch.mean(ssim, dim=(1, 2, 3))
        ssim = ssim.reshape(num_imgs - 1, num_rays)

        # ignore invalid patches by setting their ssim error to a very large value
        ssim_valid = (
            src_patches_valid.reshape(-1, self.patch_size * self.patch_size).all(dim=-1).reshape(num_imgs - 1, num_rays)
        )
        # we should mask the error after we select the topk values, otherwise we might select faraway patches that happen to be inside the image
        # ssim[torch.logical_not(ssim_valid)] = 1.1  # max ssim_error is 1
        min_ssim, idx = torch.topk(ssim, k=self.topk, largest=False, dim=0, sorted=True)

        min_ssim_valid = ssim_valid[idx, torch.arange(num_rays)[None].expand_as(idx)]
        # TODO how to set this value for better visualization
        min_ssim[torch.logical_not(min_ssim_valid)] = 0.0  # max ssim_error is 1

        if False:
            # visualization of topK error computations
            import cv2
            import numpy as np

            vis_patch_num = num_rays
            K = min(100, vis_patch_num)

            image = (
                patches[:, :vis_patch_num, :, :]
                .reshape(-1, vis_patch_num, self.patch_size, self.patch_size, 3)
                .permute(1, 2, 0, 3, 4)
                .reshape(vis_patch_num * self.patch_size, -1, 3)
            )

            src_patches_reshaped = src_patches.reshape(
                num_imgs - 1, num_rays, 3, self.patch_size, self.patch_size
            ).permute(1, 0, 3, 4, 2)
            idx = idx.permute(1, 0)

            selected_patch = (
                src_patches_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx]
                .permute(0, 2, 1, 3, 4)
                .reshape(num_rays, self.patch_size, self.topk * self.patch_size, 3)[:vis_patch_num]
                .reshape(-1, self.topk * self.patch_size, 3)
            )

            # apply the same reshape to the valid mask
            src_patches_valid_reshaped = src_patches_valid.reshape(
                num_imgs - 1, num_rays, 1, self.patch_size, self.patch_size
            ).permute(1, 0, 3, 4, 2)
            selected_patch_valid = (
                src_patches_valid_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx]
                .permute(0, 2, 1, 3, 4)
                .reshape(num_rays, self.patch_size, self.topk * self.patch_size, 1)[:vis_patch_num]
                .reshape(-1, self.topk * self.patch_size, 1)
            )
            # valid mask to image
            selected_patch_valid = selected_patch_valid.expand_as(selected_patch).float()
            # breakpoint()

            image = torch.cat([selected_patch_valid, selected_patch, image], dim=1)
            # select top rays with the highest errors
            image = image.reshape(num_rays, self.patch_size, -1, 3)
            _, idx2 = torch.topk(
                torch.sum(min_ssim, dim=0) / (min_ssim_valid.float().sum(dim=0) + 1e-6),
                k=K,
                largest=True,
                dim=0,
                sorted=True,
            )
            image = image[idx2].reshape(K * self.patch_size, -1, 3)

            cv2.imwrite(f"vis/{self.iter}.png", (image.detach().cpu().numpy() * 255).astype(np.uint8)[..., ::-1])
            self.iter += 1
            if self.iter == 9:
                breakpoint()

        return torch.sum(min_ssim) / (min_ssim_valid.float().sum() + 1e-6)
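

# Example (a minimal sketch with random data): one reference view plus five
# source views, eight rays, full-patch validity.
def _example_multi_view_loss():
    loss_fn = MultiViewLoss(patch_size=11, topk=4)
    patches = torch.rand(6, 8, 11 * 11, 3)  # (num_imgs, num_rays, patch_size**2, 3)
    valid = torch.ones(6, 8, 11 * 11, 1, dtype=torch.bool)
    return loss_fn(patches, valid)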


# sensor depth loss, adapted from https://github.com/dazinovic/neural-rgbd-surface-reconstruction/blob/main/losses.py
# class SensorDepthLoss(nn.Module):
#     """Sensor Depth loss"""
#     def __init__(self, truncation: float):
#         super(SensorDepthLoss, self).__init__()
#         self.truncation = truncation  # 0.05 * 0.3 5cm scaled
#     def forward(self, batch, outputs):
#         """Compute sensor depth losses.
#         Args:
#             batch (Dict): inputs
#             outputs (Dict): outputs data from surface model
#         Returns:
#             l1_loss: l1 loss
#             freespace_loss: free space loss
#             sdf_loss: sdf loss
#         """
#         depth_pred = outputs["depth"]
#         depth_gt = batch["sensor_depth"].to(depth_pred.device)[..., None]
#         valid_gt_mask = depth_gt > 0.0
#         l1_loss = torch.sum(valid_gt_mask * torch.abs(depth_gt - depth_pred)) / (valid_gt_mask.sum() + 1e-6)
#         # free space loss and sdf loss
#         ray_samples = outputs["ray_samples"]
#         field_outputs = outputs["field_outputs"]
#         pred_sdf = field_outputs[FieldHeadNames.SDF][..., 0]
#         directions_norm = outputs["directions_norm"]
#         z_vals = ray_samples.frustums.starts[..., 0] / directions_norm
#         truncation = self.truncation
#         front_mask = valid_gt_mask & (z_vals < (depth_gt - truncation))
#         back_mask = valid_gt_mask & (z_vals > (depth_gt + truncation))
#         sdf_mask = valid_gt_mask & (~front_mask) & (~back_mask)
#         num_fs_samples = front_mask.sum()
#         num_sdf_samples = sdf_mask.sum()
#         num_samples = num_fs_samples + num_sdf_samples + 1e-6
#         fs_weight = 1.0 - num_fs_samples / num_samples
#         sdf_weight = 1.0 - num_sdf_samples / num_samples
#         free_space_loss = torch.mean((F.relu(truncation - pred_sdf) * front_mask) ** 2) * fs_weight
#         sdf_loss = torch.mean(((z_vals + pred_sdf) - depth_gt) ** 2 * sdf_mask) * sdf_weight
#         return l1_loss, free_space_loss, sdf_loss
r"""Implements Stochastic Structural SIMilarity(S3IM) algorithm. | |
It is proposed in the ICCV2023 paper | |
`S3IM: Stochastic Structural SIMilarity and Its Unreasonable Effectiveness for Neural Fields`. | |
Arguments: | |
s3im_kernel_size (int): kernel size in ssim's convolution(default: 4) | |
s3im_stride (int): stride in ssim's convolution(default: 4) | |
s3im_repeat_time (int): repeat time in re-shuffle virtual patch(default: 10) | |
s3im_patch_height (height): height of virtual patch(default: 64) | |
""" | |
class S3IM(torch.nn.Module): | |
def __init__(self, s3im_kernel_size = 4, s3im_stride=4, s3im_repeat_time=10, s3im_patch_height=64, size_average = True): | |
super(S3IM, self).__init__() | |
self.s3im_kernel_size = s3im_kernel_size | |
self.s3im_stride = s3im_stride | |
self.s3im_repeat_time = s3im_repeat_time | |
self.s3im_patch_height = s3im_patch_height | |
self.size_average = size_average | |
self.channel = 1 | |
self.s3im_kernel = self.create_kernel(s3im_kernel_size, self.channel) | |
def gaussian(self, s3im_kernel_size, sigma): | |
gauss = torch.Tensor([exp(-(x - s3im_kernel_size//2)**2/float(2*sigma**2)) for x in range(s3im_kernel_size)]) | |
return gauss/gauss.sum() | |
def create_kernel(self, s3im_kernel_size, channel): | |
_1D_window = self.gaussian(s3im_kernel_size, 1.5).unsqueeze(1) | |
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) | |
s3im_kernel = Variable(_2D_window.expand(channel, 1, s3im_kernel_size, s3im_kernel_size).contiguous()) | |
return s3im_kernel | |
def _ssim(self, img1, img2, s3im_kernel, s3im_kernel_size, channel, size_average = True, s3im_stride=None): | |
mu1 = F.conv2d(img1, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) | |
mu2 = F.conv2d(img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) | |
mu1_sq = mu1.pow(2) | |
mu2_sq = mu2.pow(2) | |
mu1_mu2 = mu1*mu2 | |
sigma1_sq = F.conv2d(img1*img1, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu1_sq | |
sigma2_sq = F.conv2d(img2*img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu2_sq | |
sigma12 = F.conv2d(img1*img2, s3im_kernel, padding = (s3im_kernel_size-1)//2, groups = channel, stride=s3im_stride) - mu1_mu2 | |
C1 = 0.01**2 | |
C2 = 0.03**2 | |
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) | |
if size_average: | |
return ssim_map.mean() | |
else: | |
return ssim_map.mean(1).mean(1).mean(1) | |
def ssim_loss(self, img1, img2): | |
""" | |
img1, img2: torch.Tensor([b,c,h,w]) | |
""" | |
(_, channel, _, _) = img1.size() | |
if channel == self.channel and self.s3im_kernel.data.type() == img1.data.type(): | |
s3im_kernel = self.s3im_kernel | |
else: | |
s3im_kernel = self.create_kernel(self.s3im_kernel_size, channel) | |
if img1.is_cuda: | |
s3im_kernel = s3im_kernel.cuda(img1.get_device()) | |
s3im_kernel = s3im_kernel.type_as(img1) | |
self.s3im_kernel = s3im_kernel | |
self.channel = channel | |
return self._ssim(img1, img2, s3im_kernel, self.s3im_kernel_size, channel, self.size_average, s3im_stride=self.s3im_stride) | |
def forward(self, src_vec, tar_vec): | |
loss = 0.0 | |
index_list = [] | |
for i in range(self.s3im_repeat_time): | |
if i == 0: | |
tmp_index = torch.arange(len(tar_vec)) | |
index_list.append(tmp_index) | |
else: | |
ran_idx = torch.randperm(len(tar_vec)) | |
index_list.append(ran_idx) | |
res_index = torch.cat(index_list) | |
tar_all = tar_vec[res_index] | |
src_all = src_vec[res_index] | |
tar_patch = tar_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1) | |
src_patch = src_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1) | |
loss = (1 - self.ssim_loss(src_patch, tar_patch)) | |
return loss | |
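

# Example (a minimal sketch with random data): S3IM consumes flat per-ray color
# vectors; the ray count should divide evenly by `s3im_patch_height` so the
# shuffled rays reshape cleanly into virtual patches.
def _example_s3im():
    s3im = S3IM(s3im_patch_height=64)
    src = torch.rand(64 * 64, 3)  # rendered colors, one row per ray
    tar = torch.rand(64 * 64, 3)  # ground-truth colors
    return s3im(src, tar)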