import os from dataclasses import dataclass, field import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import threestudio from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere from threestudio.models.mesh import Mesh from threestudio.models.networks import get_encoding, get_mlp from threestudio.utils.misc import broadcast, get_rank from threestudio.utils.typing import * @threestudio.register("implicit-sdf") class ImplicitSDF(BaseImplicitGeometry): @dataclass class Config(BaseImplicitGeometry.Config): n_input_dims: int = 3 n_feature_dims: int = 3 pos_encoding_config: dict = field( default_factory=lambda: { "otype": "HashGrid", "n_levels": 16, "n_features_per_level": 2, "log2_hashmap_size": 19, "base_resolution": 16, "per_level_scale": 1.447269237440378, } ) mlp_network_config: dict = field( default_factory=lambda: { "otype": "VanillaMLP", "activation": "ReLU", "output_activation": "none", "n_neurons": 64, "n_hidden_layers": 1, } ) normal_type: Optional[ str ] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian'] finite_difference_normal_eps: Union[ float, str ] = 0.01 # in [float, "progressive"] shape_init: Optional[str] = None shape_init_params: Optional[Any] = None shape_init_mesh_up: str = "+z" shape_init_mesh_front: str = "+x" force_shape_init: bool = False sdf_bias: Union[float, str] = 0.0 sdf_bias_params: Optional[Any] = None # no need to removal outlier for SDF isosurface_remove_outliers: bool = False cfg: Config def configure(self) -> None: super().configure() self.encoding = get_encoding( self.cfg.n_input_dims, self.cfg.pos_encoding_config ) self.sdf_network = get_mlp( self.encoding.n_output_dims, 1, self.cfg.mlp_network_config ) if self.cfg.n_feature_dims > 0: self.feature_network = get_mlp( self.encoding.n_output_dims, self.cfg.n_feature_dims, self.cfg.mlp_network_config, ) if self.cfg.normal_type == "pred": self.normal_network = get_mlp( self.encoding.n_output_dims, 3, self.cfg.mlp_network_config ) if self.cfg.isosurface_deformable_grid: assert ( self.cfg.isosurface_method == "mt" ), "isosurface_deformable_grid only works with mt" self.deformation_network = get_mlp( self.encoding.n_output_dims, 3, self.cfg.mlp_network_config ) self.finite_difference_normal_eps: Optional[float] = None def initialize_shape(self) -> None: if self.cfg.shape_init is None and not self.cfg.force_shape_init: return # do not initialize shape if weights are provided if self.cfg.weights is not None and not self.cfg.force_shape_init: return if self.cfg.sdf_bias != 0.0: threestudio.warn( "shape_init and sdf_bias are both specified, which may lead to unexpected results." ) get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]] assert isinstance(self.cfg.shape_init, str) if self.cfg.shape_init == "ellipsoid": assert ( isinstance(self.cfg.shape_init_params, Sized) and len(self.cfg.shape_init_params) == 3 ) size = torch.as_tensor(self.cfg.shape_init_params).to(self.device) def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]: return ((points_rand / size) ** 2).sum( dim=-1, keepdim=True ).sqrt() - 1.0 # pseudo signed distance of an ellipsoid get_gt_sdf = func elif self.cfg.shape_init == "sphere": assert isinstance(self.cfg.shape_init_params, float) radius = self.cfg.shape_init_params def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]: return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius get_gt_sdf = func elif self.cfg.shape_init.startswith("mesh:"): assert isinstance(self.cfg.shape_init_params, float) mesh_path = self.cfg.shape_init[5:] if not os.path.exists(mesh_path): raise ValueError(f"Mesh file {mesh_path} does not exist.") import trimesh scene = trimesh.load(mesh_path) if isinstance(scene, trimesh.Trimesh): mesh = scene elif isinstance(scene, trimesh.scene.Scene): mesh = trimesh.Trimesh() for obj in scene.geometry.values(): mesh = trimesh.util.concatenate([mesh, obj]) else: raise ValueError(f"Unknown mesh type at {mesh_path}.") # move to center centroid = mesh.vertices.mean(0) mesh.vertices = mesh.vertices - centroid # align to up-z and front-x dirs = ["+x", "+y", "+z", "-x", "-y", "-z"] dir2vec = { "+x": np.array([1, 0, 0]), "+y": np.array([0, 1, 0]), "+z": np.array([0, 0, 1]), "-x": np.array([-1, 0, 0]), "-y": np.array([0, -1, 0]), "-z": np.array([0, 0, -1]), } if ( self.cfg.shape_init_mesh_up not in dirs or self.cfg.shape_init_mesh_front not in dirs ): raise ValueError( f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}." ) if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]: raise ValueError( "shape_init_mesh_up and shape_init_mesh_front must be orthogonal." ) z_, x_ = ( dir2vec[self.cfg.shape_init_mesh_up], dir2vec[self.cfg.shape_init_mesh_front], ) y_ = np.cross(z_, x_) std2mesh = np.stack([x_, y_, z_], axis=0).T mesh2std = np.linalg.inv(std2mesh) # scaling scale = np.abs(mesh.vertices).max() mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T from pysdf import SDF sdf = SDF(mesh.vertices, mesh.faces) def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]: # add a negative signed here # as in pysdf the inside of the shape has positive signed distance return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to( points_rand )[..., None] get_gt_sdf = func else: raise ValueError( f"Unknown shape initialization type: {self.cfg.shape_init}" ) # Initialize SDF to a given shape when no weights are provided or force_shape_init is True optim = torch.optim.Adam(self.parameters(), lr=1e-3) from tqdm import tqdm for _ in tqdm( range(1000), desc=f"Initializing SDF to a(n) {self.cfg.shape_init}:", disable=get_rank() != 0, ): points_rand = ( torch.rand((10000, 3), dtype=torch.float32).to(self.device) * 2.0 - 1.0 ) sdf_gt = get_gt_sdf(points_rand) sdf_pred = self.forward_sdf(points_rand) loss = F.mse_loss(sdf_pred, sdf_gt) optim.zero_grad() loss.backward() optim.step() # explicit broadcast to ensure param consistency across ranks for param in self.parameters(): broadcast(param, src=0) def get_shifted_sdf( self, points: Float[Tensor, "*N Di"], sdf: Float[Tensor, "*N 1"] ) -> Float[Tensor, "*N 1"]: sdf_bias: Union[float, Float[Tensor, "*N 1"]] if self.cfg.sdf_bias == "ellipsoid": assert ( isinstance(self.cfg.sdf_bias_params, Sized) and len(self.cfg.sdf_bias_params) == 3 ) size = torch.as_tensor(self.cfg.sdf_bias_params).to(points) sdf_bias = ((points / size) ** 2).sum( dim=-1, keepdim=True ).sqrt() - 1.0 # pseudo signed distance of an ellipsoid elif self.cfg.sdf_bias == "sphere": assert isinstance(self.cfg.sdf_bias_params, float) radius = self.cfg.sdf_bias_params sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius elif isinstance(self.cfg.sdf_bias, float): sdf_bias = self.cfg.sdf_bias else: raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}") return sdf + sdf_bias def forward( self, points: Float[Tensor, "*N Di"], output_normal: bool = False ) -> Dict[str, Float[Tensor, "..."]]: grad_enabled = torch.is_grad_enabled() if output_normal and self.cfg.normal_type == "analytic": torch.set_grad_enabled(True) points.requires_grad_(True) points_unscaled = points # points in the original scale points = contract_to_unisphere( points, self.bbox, self.unbounded ) # points normalized to (0, 1) enc = self.encoding(points.view(-1, self.cfg.n_input_dims)) sdf = self.sdf_network(enc).view(*points.shape[:-1], 1) sdf = self.get_shifted_sdf(points_unscaled, sdf) output = {"sdf": sdf} if self.cfg.n_feature_dims > 0: features = self.feature_network(enc).view( *points.shape[:-1], self.cfg.n_feature_dims ) output.update({"features": features}) if output_normal: if ( self.cfg.normal_type == "finite_difference" or self.cfg.normal_type == "finite_difference_laplacian" ): assert self.finite_difference_normal_eps is not None eps: float = self.finite_difference_normal_eps if self.cfg.normal_type == "finite_difference_laplacian": offsets: Float[Tensor, "6 3"] = torch.as_tensor( [ [eps, 0.0, 0.0], [-eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, -eps, 0.0], [0.0, 0.0, eps], [0.0, 0.0, -eps], ] ).to(points_unscaled) points_offset: Float[Tensor, "... 6 3"] = ( points_unscaled[..., None, :] + offsets ).clamp(-self.cfg.radius, self.cfg.radius) sdf_offset: Float[Tensor, "... 6 1"] = self.forward_sdf( points_offset ) sdf_grad = ( 0.5 * (sdf_offset[..., 0::2, 0] - sdf_offset[..., 1::2, 0]) / eps ) else: offsets: Float[Tensor, "3 3"] = torch.as_tensor( [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]] ).to(points_unscaled) points_offset: Float[Tensor, "... 3 3"] = ( points_unscaled[..., None, :] + offsets ).clamp(-self.cfg.radius, self.cfg.radius) sdf_offset: Float[Tensor, "... 3 1"] = self.forward_sdf( points_offset ) sdf_grad = (sdf_offset[..., 0::1, 0] - sdf) / eps normal = F.normalize(sdf_grad, dim=-1) elif self.cfg.normal_type == "pred": normal = self.normal_network(enc).view(*points.shape[:-1], 3) normal = F.normalize(normal, dim=-1) sdf_grad = normal elif self.cfg.normal_type == "analytic": sdf_grad = -torch.autograd.grad( sdf, points_unscaled, grad_outputs=torch.ones_like(sdf), create_graph=True, )[0] normal = F.normalize(sdf_grad, dim=-1) if not grad_enabled: sdf_grad = sdf_grad.detach() normal = normal.detach() else: raise AttributeError(f"Unknown normal type {self.cfg.normal_type}") output.update( {"normal": normal, "shading_normal": normal, "sdf_grad": sdf_grad} ) return output def forward_sdf(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]: points_unscaled = points points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) sdf = self.sdf_network( self.encoding(points.reshape(-1, self.cfg.n_input_dims)) ).reshape(*points.shape[:-1], 1) sdf = self.get_shifted_sdf(points_unscaled, sdf) return sdf def forward_field( self, points: Float[Tensor, "*N Di"] ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: points_unscaled = points points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims)) sdf = self.sdf_network(enc).reshape(*points.shape[:-1], 1) sdf = self.get_shifted_sdf(points_unscaled, sdf) deformation: Optional[Float[Tensor, "*N 3"]] = None if self.cfg.isosurface_deformable_grid: deformation = self.deformation_network(enc).reshape(*points.shape[:-1], 3) return sdf, deformation def forward_level( self, field: Float[Tensor, "*N 1"], threshold: float ) -> Float[Tensor, "*N 1"]: return field - threshold def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]: out: Dict[str, Any] = {} if self.cfg.n_feature_dims == 0: return out points_unscaled = points points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded) enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims)) features = self.feature_network(enc).view( *points.shape[:-1], self.cfg.n_feature_dims ) out.update( { "features": features, } ) return out def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): if ( self.cfg.normal_type == "finite_difference" or self.cfg.normal_type == "finite_difference_laplacian" ): if isinstance(self.cfg.finite_difference_normal_eps, float): self.finite_difference_normal_eps = ( self.cfg.finite_difference_normal_eps ) elif self.cfg.finite_difference_normal_eps == "progressive": # progressive finite difference eps from Neuralangelo # https://arxiv.org/abs/2306.03092 hg_conf: Any = self.cfg.pos_encoding_config assert ( hg_conf.otype == "ProgressiveBandHashGrid" ), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid" current_level = min( hg_conf.start_level + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps, hg_conf.n_levels, ) grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** ( current_level - 1 ) grid_size = 2 * self.cfg.radius / grid_res if grid_size != self.finite_difference_normal_eps: threestudio.info( f"Update finite_difference_normal_eps to {grid_size}" ) self.finite_difference_normal_eps = grid_size else: raise ValueError( f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}" )