# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Jiale Xu
# The modifications are subject to the same license as the original.

"""
The renderer is a module that takes in rays, decides where to sample along each
ray, and computes pixel colors using the volume rendering equation.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import math_utils


def generate_planes():
    """
    Defines planes by the three vectors that form the "axes" of the
    plane. Should work with arbitrary number of planes and planes of
    arbitrary orientation.

    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67

    Returns:
        A float32 tensor of shape (3, 3, 3): three planes, each described
        by a 3x3 matrix whose rows are the plane's axis vectors.
    """
    return torch.tensor(
        [
            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
            [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
        ],
        dtype=torch.float32,
    )


def project_onto_planes(planes, coordinates):
    """
    Does a projection of a 3D point onto a batch of 2D planes,
    returning 2D plane coordinates.

    Args:
        planes: plane axes of shape (n_planes, 3, 3). Each 3x3 matrix must
            be invertible.
        coordinates: points of shape (N, M, 3).

    Returns:
        Projections of shape (N * n_planes, M, 2). Along the first
        dimension the planes of one batch element are contiguous.
    """
    N, M, _ = coordinates.shape
    n_planes = planes.shape[0]

    # Repeat every batch element once per plane: (N, M, 3) -> (N*n_planes, M, 3).
    coordinates = (
        coordinates.unsqueeze(1)
        .expand(-1, n_planes, -1, -1)
        .reshape(N * n_planes, M, 3)
    )

    # Invert in the planes' own dtype, then match the coordinates' device and
    # dtype so that torch.bmm does not fail when the two tensors differ
    # (a no-op when they already agree).
    inv_planes = torch.linalg.inv(planes).to(
        device=coordinates.device, dtype=coordinates.dtype
    )
    inv_planes = (
        inv_planes.unsqueeze(0)
        .expand(N, -1, -1, -1)
        .reshape(N * n_planes, 3, 3)
    )

    projections = torch.bmm(coordinates, inv_planes)
    # Only the first two components are the in-plane coordinates.
    return projections[..., :2]


def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
    """
    Samples per-point features from a set of feature planes.

    Args:
        plane_axes: plane axes of shape (n_planes, 3, 3), as produced by
            generate_planes().
        plane_features: feature planes of shape (N, n_planes, C, H, W).
        coordinates: query points of shape (N, M, 3).
        mode: interpolation mode forwarded to grid_sample.
        padding_mode: must be 'zeros'; any other value is rejected.
        box_warp: side length of the bounding box. Coordinates are scaled by
            2 / box_warp before sampling, so points in
            [-box_warp / 2, box_warp / 2] land in grid_sample's [-1, 1]
            range. If None, the coordinates are used unscaled.

    Returns:
        Sampled features of shape (N, n_planes, M, C).
    """
    assert padding_mode == 'zeros', "only padding_mode='zeros' is supported"
    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape

    # reshape (rather than view) also accepts non-contiguous feature tensors.
    plane_features = plane_features.reshape(N * n_planes, C, H, W)
    dtype = plane_features.dtype

    # add specific box bounds; the default box_warp=None previously raised a
    # TypeError here, so scaling is simply skipped in that case.
    if box_warp is not None:
        coordinates = (2 / box_warp) * coordinates

    # (N*n_planes, M, 2) -> (N*n_planes, 1, M, 2): a 1 x M sampling grid.
    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)

    # grid_sample requires input and grid to share a dtype.
    output_features = F.grid_sample(
        plane_features,
        projected_coordinates.to(dtype),
        mode=mode,
        padding_mode=padding_mode,
        align_corners=False,
    )

    # (N*n_planes, C, 1, M) -> (N*n_planes, M, 1, C) -> (N, n_planes, M, C)
    return output_features.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)


def sample_from_3dgrid(grid, coordinates):
    """
    Samples per-point features from a dense 3D feature grid.

    Expects coordinates in shape (batch_size, num_points_per_batch, 3)
    Expects grid in shape (1, channels, H, W, D)
    (Also works if grid has batch size)

    Returns:
        Sampled features of shape
        (batch_size, num_points_per_batch, feature_channels).
    """
    batch_size, _, n_dims = coordinates.shape

    sampled_features = F.grid_sample(
        grid.expand(batch_size, -1, -1, -1, -1),
        # Lay the points out as a 1 x 1 x M sampling volume and match the
        # grid dtype, as grid_sample requires input and grid to agree.
        coordinates.reshape(batch_size, 1, 1, -1, n_dims).to(grid.dtype),
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False,
    )

    # Output is (N, C, 1, 1, M); move channels last and flatten the spatial
    # dimensions into a single per-point dimension.
    N, C, out_d, out_h, out_w = sampled_features.shape
    return sampled_features.permute(0, 4, 3, 2, 1).reshape(N, out_d * out_h * out_w, C)