import os
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.geometry.base import (
    BaseExplicitGeometry,
    BaseGeometry,
    contract_to_unisphere,
)
from threestudio.models.mesh import Mesh
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.ops import scale_tensor
from threestudio.utils.typing import *

# Signed unit axis for each accepted value of ``shape_init_mesh_up`` /
# ``shape_init_mesh_front``. Insertion order is relied on so the error message
# listing the valid choices stays stable.
_DIR2VEC = {
    "+x": np.array([1, 0, 0]),
    "+y": np.array([0, 1, 0]),
    "+z": np.array([0, 0, 1]),
    "-x": np.array([-1, 0, 0]),
    "-y": np.array([0, -1, 0]),
    "-z": np.array([0, 0, -1]),
}


@threestudio.register("custom-mesh")
class CustomMesh(BaseExplicitGeometry):
    """Explicit geometry backed by a fixed triangle mesh loaded from disk.

    With ``shape_init="mesh:<path>"`` the mesh is centered on its vertex mean,
    rescaled so its largest absolute coordinate equals ``shape_init_params``,
    and rotated from the file's up/front convention into +z up / +x front.
    Vertices and faces are kept as module buffers (not trainable parameters).
    Per-point features come from a positional encoding followed by an MLP.
    """

    @dataclass
    class Config(BaseExplicitGeometry.Config):
        n_input_dims: int = 3
        n_feature_dims: int = 3
        pos_encoding_config: dict = field(
            default_factory=lambda: {
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.447269237440378,
            }
        )
        mlp_network_config: dict = field(
            default_factory=lambda: {
                "otype": "VanillaMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            }
        )
        # Must be "mesh:<path>"; any other value raises in configure().
        shape_init: str = ""
        # For mesh init: a positive finite number, the largest absolute vertex
        # coordinate after centering and rescaling.
        shape_init_params: Optional[Any] = None
        # Which axis of the mesh file points up / toward the front; each must be
        # one of "+x", "+y", "+z", "-x", "-y", "-z" and they must be orthogonal.
        shape_init_mesh_up: str = "+z"
        shape_init_mesh_front: str = "+x"

    cfg: Config

    def configure(self) -> None:
        """Build the feature network and load/normalize the initial mesh.

        Raises:
            ValueError: if ``shape_init`` is not ``"mesh:<path>"``, the file is
                missing or holds no usable triangle mesh, ``shape_init_params``
                is not a positive finite number, the up/front directions are
                invalid or not orthogonal, or the mesh has zero extent.
        """
        super().configure()

        self.encoding = get_encoding(
            self.cfg.n_input_dims, self.cfg.pos_encoding_config
        )
        self.feature_network = get_mlp(
            self.encoding.n_output_dims,
            self.cfg.n_feature_dims,
            self.cfg.mlp_network_config,
        )

        if not self.cfg.shape_init.startswith("mesh:"):
            raise ValueError(
                f"Unknown shape initialization type: {self.cfg.shape_init}"
            )

        init_scale = self.cfg.shape_init_params
        # bool is a subclass of int, so reject it explicitly; NaN/inf and
        # non-positive values would silently collapse or corrupt the mesh.
        if (
            isinstance(init_scale, bool)
            or not isinstance(init_scale, (int, float))
            or not np.isfinite(init_scale)
            or init_scale <= 0
        ):
            raise ValueError(
                "shape_init_params must be a positive finite number for mesh "
                f"initialization, got {init_scale!r}."
            )

        mesh_path = self.cfg.shape_init[len("mesh:") :]
        if not os.path.exists(mesh_path):
            raise ValueError(f"Mesh file {mesh_path} does not exist.")

        # Validate the orientation config before the (potentially slow) load.
        mesh2std = self._mesh_to_std_rotation()
        mesh = self._load_trimesh(mesh_path)

        vertices = np.asarray(mesh.vertices, dtype=np.float64)
        if vertices.shape[0] == 0:
            raise ValueError(f"Mesh file {mesh_path} contains no vertices.")

        # Center on the vertex mean.
        vertices = vertices - vertices.mean(axis=0)

        # Rescale so the largest absolute coordinate equals init_scale.
        extent = np.abs(vertices).max()
        if not extent > 0:  # also catches NaN
            raise ValueError(
                f"Mesh file {mesh_path} has zero or non-finite extent."
            )
        vertices = vertices / extent * init_scale

        # Row-vector form of mesh2std @ v for every vertex. mesh2std is a signed
        # permutation, so the max |coordinate| set above is preserved.
        vertices = vertices @ mesh2std.T

        v_pos = torch.tensor(vertices, dtype=torch.float32, device=self.device)
        t_pos_idx = torch.tensor(
            np.asarray(mesh.faces), dtype=torch.int64, device=self.device
        )
        self.mesh = Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
        # Buffers make the mesh part of the state dict and follow device moves.
        self.register_buffer("v_buffer", v_pos)
        self.register_buffer("t_buffer", t_pos_idx)

    @staticmethod
    def _load_trimesh(mesh_path: str):
        """Load ``mesh_path`` with trimesh and merge a scene into one Trimesh."""
        import trimesh  # local import: only needed for mesh initialization

        scene = trimesh.load(mesh_path)
        if isinstance(scene, trimesh.Trimesh):
            return scene
        if isinstance(scene, trimesh.scene.Scene):
            # Only triangle meshes are merged; points/paths carry no faces.
            parts = [
                geom
                for geom in scene.geometry.values()
                if isinstance(geom, trimesh.Trimesh)
            ]
            if not parts:
                raise ValueError(f"No triangle mesh found in scene at {mesh_path}.")
            # NOTE(review): geometries are merged in their local frames; the
            # scene graph's node transforms are not applied (same as before).
            return trimesh.util.concatenate(parts)
        raise ValueError(f"Unknown mesh type at {mesh_path}.")

    def _mesh_to_std_rotation(self) -> np.ndarray:
        """Return the 3x3 matrix mapping mesh-file axes to +z up / +x front."""
        up = self.cfg.shape_init_mesh_up
        front = self.cfg.shape_init_mesh_front
        dirs = list(_DIR2VEC)
        if up not in dirs or front not in dirs:
            raise ValueError(
                f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
            )
        # Same axis letter (e.g. "+x" and "-x") means parallel, not orthogonal.
        if up[1] == front[1]:
            raise ValueError(
                "shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
            )
        z_, x_ = _DIR2VEC[up], _DIR2VEC[front]
        y_ = np.cross(z_, x_)  # right-handed completion of the frame
        std2mesh = np.stack([x_, y_, z_], axis=0).T
        return np.linalg.inv(std2mesh)

    def isosurface(self) -> Mesh:
        """Return the stored mesh, rebuilding it from the buffers when needed.

        The cached ``Mesh`` wraps the tensors created in ``configure``. Moving
        the module to another device replaces the registered buffers, so the
        mesh is rebuilt whenever its vertices live on a different device than
        ``v_buffer`` (or when no mesh has been cached yet).

        Raises:
            ValueError: if neither a mesh nor the vertex buffer exists.
        """
        mesh = getattr(self, "mesh", None)
        if hasattr(self, "v_buffer") and (
            mesh is None or mesh.v_pos.device != self.v_buffer.device
        ):
            mesh = Mesh(v_pos=self.v_buffer, t_pos_idx=self.t_buffer)
            self.mesh = mesh
        if mesh is None:
            raise ValueError("custom mesh is not initialized")
        return mesh

    def _query_features(
        self, points: Float[Tensor, "*N Di"]
    ) -> Float[Tensor, "*N Df"]:
        """Encode ``points`` and run the feature MLP, keeping leading dims."""
        # Bring points into the encoding's input domain via bbox contraction
        # (presumably normalized to (0, 1) — see contract_to_unisphere).
        points_normed = contract_to_unisphere(points, self.bbox)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        enc = self.encoding(points_normed.reshape(-1, self.cfg.n_input_dims))
        return self.feature_network(enc).view(
            *points_normed.shape[:-1], self.cfg.n_feature_dims
        )

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Return ``{"features": ...}`` for world-space ``points``.

        Normals are not supported; requesting them trips an assertion.
        """
        assert (
            not output_normal
        ), f"Normal output is not supported for {self.__class__.__name__}"
        return {"features": self._query_features(points)}

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        """Return per-point features for export; empty when there are none."""
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out
        out["features"] = self._query_features(points)
        return out