import math
import trimesh
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from packaging import version as pver

import tinycudann as tcnn
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

import raymarching


def custom_meshgrid(*args):
    """Call ``torch.meshgrid`` with 'ij' indexing on old and new torch versions.

    The ``indexing`` keyword only exists from torch 1.10; older versions
    always behave as 'ij'.
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')


def sample_pdf(bins, weights, n_samples, det=False):
    """Draw new sample positions by inverse-transform sampling of ``weights``.

    This implementation is from NeRF.

    Args:
        bins: [B, T] bin positions (old z_vals).
        weights: [B, T - 1] per-bin weights.
        n_samples: number of new samples per row.
        det: if True, use evenly spaced quantiles instead of uniform noise.

    Returns:
        [B, n_samples] new z_vals.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    # avoid dividing by (near-)zero-width CDF intervals
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples


def plot_pointcloud(pc, color=None):
    """Debug helper: show a point cloud with axes and a unit sphere in trimesh.

    Args:
        pc: [N, 3] points.
        color: optional [N, 3/4] per-point colors.
    """
    print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))
    pc = trimesh.PointCloud(pc, color)
    # axis
    axes = trimesh.creation.axis(axis_length=4)
    # sphere
    sphere = trimesh.creation.icosphere(radius=1)
    trimesh.Scene([pc, axes, sphere]).show()


def _compute_weights(z_vals, sigmas, sample_dist, density_scale):
    """Return ``(deltas, weights)`` for alpha compositing along each ray.

    Args:
        z_vals: [N, T] sorted sample depths.
        sigmas: [N, T] densities at those samples.
        sample_dist: [N, 1] length used for the last interval of every ray.
        density_scale: multiplier applied to ``deltas * sigmas``.

    Returns:
        deltas: [N, T] interval lengths.
        weights: [N, T] with weights_i = alpha_i * prod_{j<i} (1 - alpha_j).
    """
    deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]
    deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)  # [N, T]
    alphas = 1 - torch.exp(-deltas * density_scale * sigmas)  # [N, T]
    # the 1e-15 keeps the transmittance product from collapsing to exactly 0
    alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1)  # [N, T+1]
    weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T]
    return deltas, weights


class NGPRenderer(nn.Module):
    """Base volume renderer.

    Subclasses provide ``forward(x, d) -> (sigma, rgb)``,
    ``density(x) -> dict`` (must contain 'sigma') and
    ``color(x, d, mask=None, **density_outputs)``.
    """

    def __init__(self,
                 bound=1,
                 cuda_ray=True,
                 density_scale=1,  # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance.
                 min_near=0.2,
                 density_thresh=0.01,
                 bg_radius=-1,
                 ):
        super().__init__()

        self.bound = bound
        self.cascade = 1
        self.grid_size = 128
        self.density_scale = density_scale
        self.min_near = min_near
        self.density_thresh = density_thresh
        self.bg_radius = bg_radius  # radius of the background sphere.

        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
        aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])
        aabb_infer = aabb_train.clone()
        self.register_buffer('aabb_train', aabb_train)
        self.register_buffer('aabb_infer', aabb_infer)

        # extra state for cuda raymarching
        self.cuda_ray = cuda_ray
        if cuda_ray:
            # density grid
            density_grid = torch.zeros([self.cascade, self.grid_size ** 3])  # [CAS, H * H * H]
            density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8)  # [CAS * H * H * H // 8]
            self.register_buffer('density_grid', density_grid)
            self.register_buffer('density_bitfield', density_bitfield)
            self.mean_density = 0
            self.iter_density = 0
            # step counter
            step_counter = torch.zeros(16, 2, dtype=torch.int32)  # 16 is hardcoded for averaging...
            self.register_buffer('step_counter', step_counter)
            self.mean_count = 0
            self.local_step = 0

    def forward(self, x, d):
        raise NotImplementedError()

    # separated density and color query (can accelerate non-cuda-ray mode.)
    def density(self, x):
        raise NotImplementedError()

    def color(self, x, d, mask=None, **kwargs):
        raise NotImplementedError()

    def reset_extra_state(self):
        """Zero the density grid and step counters (no-op without cuda_ray)."""
        if not self.cuda_ray:
            return
        # density grid
        self.density_grid.zero_()
        self.mean_density = 0
        self.iter_density = 0
        # step counter
        self.step_counter.zero_()
        self.mean_count = 0
        self.local_step = 0

    def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs):
        """Render rays with pure-PyTorch stratified + importance sampling.

        Args:
            rays_o, rays_d: [B, N, 3] ray origins / directions, assumes B == 1.
            num_steps: number of uniform (coarse) samples per ray.
            upsample_steps: number of extra importance samples per ray (0 disables).
            bg_color: [3] in range [0, 1]; defaults to 1 (white) when None.
                Ignored when bg_radius > 0 (self.background is used instead).
            perturb: jitter the coarse samples.

        Returns:
            dict with 'depth' [B, N], 'image' [B, N, 3], 'weights_sum' [N].
        """
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # choose aabb
        aabb = self.aabb_train if self.training else self.aabb_infer

        # sample steps
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
        nears.unsqueeze_(-1)
        fars.unsqueeze_(-1)

        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0)  # [1, T]
        z_vals = z_vals.expand((N, num_steps))  # [N, T]
        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]

        # perturb z_vals
        sample_dist = (fars - nears) / num_steps
        if perturb:
            z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist

        # generate xyzs
        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.

        # query density (the dict may carry extra outputs, e.g. geo_feat for color)
        density_outputs = self.density(xyzs.reshape(-1, 3))
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(N, num_steps, -1)

        # upsample z_vals (nerf-like)
        if upsample_steps > 0:
            with torch.no_grad():
                deltas, weights = _compute_weights(z_vals, density_outputs['sigma'].squeeze(-1),
                                                   sample_dist, self.density_scale)  # [N, T]

                # sample new z_vals
                z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1])  # [N, T-1]
                new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach()  # [N, t]

                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1)  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
                new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:])  # a manual clip.

            # only forward new points to save computation
            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
            for k, v in new_density_outputs.items():
                new_density_outputs[k] = v.view(N, upsample_steps, -1)

            # re-order
            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]
            z_vals, z_index = torch.sort(z_vals, dim=1)

            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]
            xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

            for k in density_outputs:
                tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
                density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

        _, weights = _compute_weights(z_vals, density_outputs['sigma'].squeeze(-1),
                                      sample_dist, self.density_scale)  # [N, T+t]

        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(-1, v.shape[-1])

        # only evaluate color where the sample contributes noticeably
        mask = weights > 1e-4  # hard coded
        rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs)
        rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]

        # calculate weight_sum (mask)
        weights_sum = weights.sum(dim=-1)  # [N]

        # calculate depth (normalized to [0, 1] within [nears, fars])
        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
        depth = torch.sum(weights * ori_z_vals, dim=-1)

        # calculate color
        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius)  # [N, 2] in [-1, 1]
            bg_color = self.background(sph, rays_d.reshape(-1, 3))  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)

        return {
            'depth': depth,
            'image': image,
            'weights_sum': weights_sum,
        }

    def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
        """Render rays with the CUDA ray marcher, skipping space via the density bitfield.

        Args:
            rays_o, rays_d: [B, N, 3] ray origins / directions, assumes B == 1.
            dt_gamma, perturb, force_all_rays, max_steps, T_thresh: forwarded to
                the raymarching kernels.
            bg_color: background color; defaults to 1 when None. Ignored when
                bg_radius > 0 (self.background is used instead).

        Returns:
            dict with 'weights_sum' [N], 'depth' [B, N], 'image' [B, N, 3].
        """
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # pre-calculate near far
        nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near)

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius)  # [N, 2] in [-1, 1]
            bg_color = self.background(sph, rays_d)  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        results = {}

        if self.training:
            # setup counter: one of 16 rotating slots, averaged in update_extra_state
            counter = self.step_counter[self.local_step % 16]
            counter.zero_()  # set to 0
            self.local_step += 1

            xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps)

            sigmas, rgbs = self(xyzs, dirs)
            sigmas = self.density_scale * sigmas

            weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)

        else:
            # allocate outputs
            # output should always be float32! only network inference uses half.
            dtype = torch.float32

            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)

            n_alive = N
            rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device)  # [N]
            rays_t = nears.clone()  # [N]

            step = 0

            while step < max_steps:
                # count alive rays
                n_alive = rays_alive.shape[0]

                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps: march more steps per iteration as rays terminate
                n_step = max(min(N // n_alive, 8), 1)

                xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

                sigmas, rgbs = self(xyzs, dirs)
                sigmas = self.density_scale * sigmas

                # composite_rays updates weights_sum / depth / image in place;
                # rays marked < 0 in rays_alive are dropped below.
                raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)

                rays_alive = rays_alive[rays_alive >= 0]

                step += n_step

        # shared post-processing for both branches
        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
        depth = torch.clamp(depth - nears, min=0) / (fars - nears)
        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)

        results['weights_sum'] = weights_sum
        results['depth'] = depth
        results['image'] = image

        return results

    @torch.no_grad()
    def mark_untrained_grid(self, poses, intrinsic, S=64):
        """Set density_grid cells seen by no camera to -1 (no-op without cuda_ray).

        Args:
            poses: [B, 4, 4] camera-to-world matrices (numpy array or tensor).
            intrinsic: (fx, fy, cx, cy).
            S: chunk size for both grid coordinates and pose batches.
        """
        if not self.cuda_ray:
            return

        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)

        B = poses.shape[0]

        fx, fy, cx, cy = intrinsic

        X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

        count = torch.zeros_like(self.density_grid)
        poses = poses.to(count.device)

        # 5-level loop, forgive me...
        for xs in X:
            for ys in Y:
                for zs in Z:

                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long()  # [N]
                    world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0)  # [1, N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_world_xyzs = world_xyzs * (bound - half_grid_size)

                        # split batch to avoid OOM
                        head = 0
                        while head < B:
                            tail = min(head + S, B)

                            # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
                            cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1)
                            cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3]  # [S, N, 3]

                            # query if point is covered by any camera
                            mask_z = cam_xyzs[:, :, 2] > 0  # [S, N]
                            mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
                            mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
                            mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1)  # [N]

                            # update count
                            count[cas, indices] += mask
                            head += S

        # mark untrained grid as -1
        self.density_grid[count == 0] = -1

        print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}')

    @torch.no_grad()
    def update_extra_state(self, decay=0.95, S=128):
        """Refresh the density grid, bitfield and mean step count (no-op without cuda_ray).

        Call before each epoch. The first 16 calls evaluate every grid cell;
        later calls evaluate a random subset plus randomly chosen occupied
        cells. Grid values are merged as max(old * decay, new); cells that are
        -1 in either the old or new grid are left unchanged.

        Args:
            decay: EMA decay applied to the existing grid values.
            S: chunk size for the full-grid sweep.
        """
        if not self.cuda_ray:
            return

        ### update density grid
        tmp_grid = - torch.ones_like(self.density_grid)

        # full update.
        if self.iter_density < 16:
            X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
            Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)
            Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S)

            for xs in X:
                for ys in Y:
                    for zs in Z:

                        # construct points
                        xx, yy, zz = custom_meshgrid(xs, ys, zs)
                        coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                        indices = raymarching.morton3D(coords).long()  # [N]
                        xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]

                        # cascading
                        for cas in range(self.cascade):
                            bound = min(2 ** cas, self.bound)
                            half_grid_size = bound / self.grid_size
                            # scale to current cascade's resolution
                            cas_xyzs = xyzs * (bound - half_grid_size)
                            # add noise in [-hgs, hgs]
                            cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                            # query density
                            sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                            sigmas *= self.density_scale
                            # assign
                            tmp_grid[cas, indices] = sigmas

        # partial update (half the computation)
        # TODO: why no need of maxpool ?
        else:
            N = self.grid_size ** 3 // 4  # H * H * H / 4
            for cas in range(self.cascade):
                # random sample some positions
                coords = torch.randint(0, self.grid_size, (N, 3), device=self.density_bitfield.device)  # [N, 3], in [0, 128)
                indices = raymarching.morton3D(coords).long()  # [N]
                # random sample occupied positions
                occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(-1)  # [Nz]
                # torch.randint(0, 0, ...) raises, so only sample when something is occupied
                if occ_indices.shape[0] > 0:
                    rand_mask = torch.randint(0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_bitfield.device)
                    occ_indices = occ_indices[rand_mask]  # [Nz] --> [N], allow for duplication
                    occ_coords = raymarching.morton3D_invert(occ_indices)  # [N, 3]
                    # concat
                    indices = torch.cat([indices, occ_indices], dim=0)
                    coords = torch.cat([coords, occ_coords], dim=0)
                # same below
                xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]
                bound = min(2 ** cas, self.bound)
                half_grid_size = bound / self.grid_size
                # scale to current cascade's resolution
                cas_xyzs = xyzs * (bound - half_grid_size)
                # add noise in [-hgs, hgs]
                cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                # query density
                sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
                sigmas *= self.density_scale
                # assign
                tmp_grid[cas, indices] = sigmas

        # NOTE: max-pooling tmp_grid (3x3x3) for less aggressive culling was tried; no significant improvement.

        # ema update
        valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
        self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
        self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item()  # -1 regions are viewed as 0 density.
        self.iter_density += 1

        # convert to bitfield
        density_thresh = min(self.mean_density, self.density_thresh)
        self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)

        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
        self.local_step = 0

    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
        """Render rays with run_cuda (cuda_ray=True) or run (cuda_ray=False).

        Args:
            rays_o, rays_d: [B, N, 3], assumes B == 1.
            staged, max_ray_batch: accepted for interface compatibility but
                currently ignored; all rays are rendered in one call.
            **kwargs: forwarded to run_cuda / run.

        Returns:
            The results dict from run_cuda / run ('image', 'depth', 'weights_sum').
        """
        _run = self.run_cuda if self.cuda_ray else self.run
        return _run(rays_o, rays_d, **kwargs)


class _trunc_exp(Function):
    """exp(x) whose backward pass uses exp(clamp(x, -15, 15)) to avoid overflow."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # cast to float32
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-15, 15))


trunc_exp = _trunc_exp.apply


class NGPNetwork(NGPRenderer):
    """Hash-grid density MLP + spherical-harmonics color MLP (tiny-cuda-nn)."""

    def __init__(self,
                 num_layers=2,
                 hidden_dim=64,
                 geo_feat_dim=15,
                 num_layers_color=3,
                 hidden_dim_color=64,
                 bound=0.5,
                 max_resolution=128,
                 base_resolution=16,
                 n_levels=16,
                 **kwargs
                 ):
        super().__init__(bound, **kwargs)

        # sigma network
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim
        self.bound = bound

        log2_hashmap_size = 19
        n_features_per_level = 2
        per_level_scale = np.exp2(np.log2(max_resolution / base_resolution) / (n_levels - 1))

        self.encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "HashGrid",
                "n_levels": n_levels,
                "n_features_per_level": n_features_per_level,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_resolution,
                "per_level_scale": per_level_scale,
            },
        )

        self.sigma_net = tcnn.Network(
            # one feature vector of n_features_per_level per hash-grid level
            n_input_dims=n_levels * n_features_per_level,
            n_output_dims=1 + self.geo_feat_dim,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim,
                "n_hidden_layers": num_layers - 1,
            },
        )

        # color network
        self.num_layers_color = num_layers_color
        self.hidden_dim_color = hidden_dim_color

        self.encoder_dir = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )

        self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim

        self.color_net = tcnn.Network(
            n_input_dims=self.in_dim_color,
            n_output_dims=3,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim_color,
                "n_hidden_layers": num_layers_color - 1,
            },
        )

        # Radial density-bias parameters used by density():
        # bias = (1 - |x| / density_std) * density_scale is added before softplus.
        # NOTE(review): this overwrites NGPRenderer.density_scale (set from the
        # `density_scale` kwarg), which run()/run_cuda()/update_extra_state()
        # also use to scale sigmas -- confirm the shared value is intended.
        self.density_scale, self.density_std = 10.0, 0.25

    def forward(self, x, d):
        """Return (sigma, color) for points x [N, 3] in [-bound, bound] and
        directions d [N, 3] normalized to [-1, 1]."""
        density_outputs = self.density(x)
        color = self.color(x, d, geo_feat=density_outputs['geo_feat'])
        return density_outputs['sigma'], color

    def density(self, x):
        """Query density for x [N, 3] in [-bound, bound].

        Returns:
            dict with 'sigma' [N] (softplus of network output + radial bias)
            and 'geo_feat' [N, geo_feat_dim].
        """
        x_raw = x
        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]
        x = self.encoder(x)
        h = self.sigma_net(x)

        density = h[..., 0]
        # add density bias: positive near the origin, decreasing with distance
        dist = torch.norm(x_raw, dim=-1)
        density_bias = (1 - dist / self.density_std) * self.density_scale
        density = density_bias + density
        sigma = F.softplus(density)
        geo_feat = h[..., 1:]

        return {
            'sigma': sigma,
            'geo_feat': geo_feat,
        }

    # allow masked inference
    def color(self, x, d, mask=None, geo_feat=None, **kwargs):
        """Query RGB in [0, 1] from directions d and geometry features.

        Args:
            x: [N, 3] in [-bound, bound]; only its dtype/device are used here.
            d: [N, 3] directions normalized to [-1, 1].
            mask: optional [N] bool; rgb is only computed where True, zeros elsewhere.
            geo_feat: [N, geo_feat_dim] features from density().
            **kwargs: ignored (lets callers pass the whole density() dict).
        """
        if mask is not None:
            rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device)  # [N, 3]
            # in case of empty mask
            if not mask.any():
                return rgbs
            d = d[mask]
            geo_feat = geo_feat[mask]

        # color
        d = (d + 1) / 2  # tcnn SH encoding requires inputs to be in [0, 1]
        d = self.encoder_dir(d)
        h = torch.cat([d, geo_feat], dim=-1)
        h = self.color_net(h)

        # sigmoid activation for rgb
        h = torch.sigmoid(h)

        if mask is not None:
            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32
        else:
            rgbs = h

        return rgbs