diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b597ee4e8097c1563905d8d1a6769f9b9e9041cc --- /dev/null +++ b/app.py @@ -0,0 +1,152 @@ +import gradio as gr +# from gradio_litmodel3d import LitModel3D + +import os +from typing import * +import imageio +import uuid +from PIL import Image +from trellis.pipelines import TrellisImageTo3DPipeline +from trellis.utils import render_utils, postprocessing_utils + + +def preprocess_image(image: Image.Image) -> Image.Image: + """ + Preprocess the input image. + + Args: + image (Image.Image): The input image. + + Returns: + Image.Image: The preprocessed image. + """ + return pipeline.preprocess_image(image) + + +def image_to_3d(image: Image.Image) -> Tuple[dict, str]: + """ + Convert an image to a 3D model. + + Args: + image (Image.Image): The input image. + + Returns: + dict: The information of the generated 3D model. + str: The path to the video of the 3D model. + """ + outputs = pipeline(image, formats=["gaussian", "mesh"], preprocess_image=False) + video = render_utils.render_video(outputs['gaussian'][0])['color'] + model_id = uuid.uuid4() + video_path = f"/tmp/Trellis-demo/{model_id}.mp4" + os.makedirs(os.path.dirname(video_path), exist_ok=True) + imageio.mimsave(video_path, video, fps=30) + model = {'gaussian': outputs['gaussian'][0], 'mesh': outputs['mesh'][0], 'model_id': model_id} + return model, video_path + + +def extract_glb(model: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]: + """ + Extract a GLB file from the 3D model. + + Args: + model (dict): The generated 3D model. + mesh_simplify (float): The mesh simplification factor. + texture_size (int): The texture resolution. + + Returns: + str: The path to the extracted GLB file. + """ + glb = postprocessing_utils.to_glb(model['gaussian'], model['mesh'], simplify=mesh_simplify, texture_size=texture_size) + glb_path = f"/tmp/Trellis-demo/{model['model_id']}.glb" + glb.export(glb_path) + return glb_path, glb_path + + +def activate_button() -> gr.Button: + return gr.Button(interactive=True) + + +def deactivate_button() -> gr.Button: + return gr.Button(interactive=False) + + +with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300) + generate_btn = gr.Button("Generate", interactive=False) + + mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01) + texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512) + extract_glb_btn = gr.Button("Extract GLB", interactive=False) + + with gr.Column(): + video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300) + model_output = gr.Model3D(label="Extracted GLB", height=300) + download_glb = gr.DownloadButton(label="Download GLB", interactive=False) + + # Example images at the bottom of the page + with gr.Row(): + examples = gr.Examples( + examples=[ + f'assets/example_image/{image}' + for image in os.listdir("assets/example_image") + ], + inputs=[image_prompt], + fn=lambda image: (preprocess_image(image), gr.Button(interactive=True)), + outputs=[image_prompt, generate_btn], + run_on_click=True, + examples_per_page=64, + ) + + model = gr.State() + + # Handlers + image_prompt.upload( + preprocess_image, + inputs=[image_prompt], + outputs=[image_prompt], + ).then( + activate_button, + outputs=[generate_btn], + ) + + image_prompt.clear( + deactivate_button, + outputs=[generate_btn], + ) + + generate_btn.click( + image_to_3d, + inputs=[image_prompt], + outputs=[model, video_output], + ).then( + activate_button, + outputs=[extract_glb_btn], + ) + + video_output.clear( + deactivate_button, + outputs=[extract_glb_btn], + ) + + extract_glb_btn.click( + extract_glb, + inputs=[model, mesh_simplify, texture_size], + outputs=[model_output, download_glb], + ).then( + activate_button, + outputs=[download_glb], + ) + + model_output.clear( + deactivate_button, + outputs=[download_glb], + ) + + +# Launch the Gradio app +if __name__ == "__main__": + pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") + pipeline.cuda() + demo.launch() diff --git a/assets/example_image/T.png b/assets/example_image/T.png new file mode 100644 index 0000000000000000000000000000000000000000..79af51bc9d711951fbc63be16b7d07c84294355b Binary files /dev/null and b/assets/example_image/T.png differ diff --git a/assets/example_image/typical_building_building.png b/assets/example_image/typical_building_building.png new file mode 100644 index 0000000000000000000000000000000000000000..4f9adcf79d4297bb9b906608c23c311f9b8f23d2 Binary files /dev/null and b/assets/example_image/typical_building_building.png differ diff --git a/assets/example_image/typical_building_castle.png b/assets/example_image/typical_building_castle.png new file mode 100644 index 0000000000000000000000000000000000000000..5f5f50733f3b8ed168026340ed679df357ccb9ec Binary files /dev/null and b/assets/example_image/typical_building_castle.png differ diff --git a/assets/example_image/typical_building_colorful_cottage.png b/assets/example_image/typical_building_colorful_cottage.png new file mode 100644 index 0000000000000000000000000000000000000000..94616b19be6a896413c287b3168b3b7886d64a56 Binary files /dev/null and b/assets/example_image/typical_building_colorful_cottage.png differ diff --git a/assets/example_image/typical_building_maya_pyramid.png b/assets/example_image/typical_building_maya_pyramid.png new file mode 100644 index 0000000000000000000000000000000000000000..1d87f3c3980f80eee878b4f8ab69e32279a5ea50 Binary files /dev/null and b/assets/example_image/typical_building_maya_pyramid.png differ diff --git a/assets/example_image/typical_building_mushroom.png b/assets/example_image/typical_building_mushroom.png new file mode 100644 index 0000000000000000000000000000000000000000..c4db49d4284e6fec83b2ea548e7e85d489711b84 Binary files /dev/null and b/assets/example_image/typical_building_mushroom.png differ diff --git a/assets/example_image/typical_building_space_station.png b/assets/example_image/typical_building_space_station.png new file mode 100644 index 0000000000000000000000000000000000000000..e37a5806d403dccc098d82492d687d71afa36850 Binary files /dev/null and b/assets/example_image/typical_building_space_station.png differ diff --git a/assets/example_image/typical_creature_dragon.png b/assets/example_image/typical_creature_dragon.png new file mode 100644 index 0000000000000000000000000000000000000000..c3fb92ff0400451c69f44fb75156f50db93da6ec Binary files /dev/null and b/assets/example_image/typical_creature_dragon.png differ diff --git a/assets/example_image/typical_creature_elephant.png b/assets/example_image/typical_creature_elephant.png new file mode 100644 index 0000000000000000000000000000000000000000..6fc3cf1776c66b91e739cad8e839b52074a57e4c Binary files /dev/null and b/assets/example_image/typical_creature_elephant.png differ diff --git a/assets/example_image/typical_creature_furry.png b/assets/example_image/typical_creature_furry.png new file mode 100644 index 0000000000000000000000000000000000000000..eb4e8d6c6cac1e03a206429eaf7de261c6f14072 Binary files /dev/null and b/assets/example_image/typical_creature_furry.png differ diff --git a/assets/example_image/typical_creature_quadruped.png b/assets/example_image/typical_creature_quadruped.png new file mode 100644 index 0000000000000000000000000000000000000000..b246e08e05702051fb22cced1366ab765cd6fbb0 Binary files /dev/null and b/assets/example_image/typical_creature_quadruped.png differ diff --git a/assets/example_image/typical_creature_robot_crab.png b/assets/example_image/typical_creature_robot_crab.png new file mode 100644 index 0000000000000000000000000000000000000000..8b4e10b353e0e9b60634ea272ff8fd9135fdd640 Binary files /dev/null and b/assets/example_image/typical_creature_robot_crab.png differ diff --git a/assets/example_image/typical_creature_robot_dinosour.png b/assets/example_image/typical_creature_robot_dinosour.png new file mode 100644 index 0000000000000000000000000000000000000000..7f8f51728fe1fecb0532673756b1601ef46edc2c Binary files /dev/null and b/assets/example_image/typical_creature_robot_dinosour.png differ diff --git a/assets/example_image/typical_creature_rock_monster.png b/assets/example_image/typical_creature_rock_monster.png new file mode 100644 index 0000000000000000000000000000000000000000..29dc243b197d9b3ee4df9355a5f08752ef0b9b9e Binary files /dev/null and b/assets/example_image/typical_creature_rock_monster.png differ diff --git a/assets/example_image/typical_humanoid_block_robot.png b/assets/example_image/typical_humanoid_block_robot.png new file mode 100644 index 0000000000000000000000000000000000000000..195212e38e6a8e331b02c2d58728ba41dba429a1 Binary files /dev/null and b/assets/example_image/typical_humanoid_block_robot.png differ diff --git a/assets/example_image/typical_humanoid_dragonborn.png b/assets/example_image/typical_humanoid_dragonborn.png new file mode 100644 index 0000000000000000000000000000000000000000..61ca2d9e69634c12ee9ae6f7e77f84839df83fdb Binary files /dev/null and b/assets/example_image/typical_humanoid_dragonborn.png differ diff --git a/assets/example_image/typical_humanoid_dwarf.png b/assets/example_image/typical_humanoid_dwarf.png new file mode 100644 index 0000000000000000000000000000000000000000..16de1631fff3cc42a3a5d6a8b0f638da75ad7b2f Binary files /dev/null and b/assets/example_image/typical_humanoid_dwarf.png differ diff --git a/assets/example_image/typical_humanoid_goblin.png b/assets/example_image/typical_humanoid_goblin.png new file mode 100644 index 0000000000000000000000000000000000000000..4e4fe04517801d5722817e8dfaed2af83b31d67e Binary files /dev/null and b/assets/example_image/typical_humanoid_goblin.png differ diff --git a/assets/example_image/typical_humanoid_mech.png b/assets/example_image/typical_humanoid_mech.png new file mode 100644 index 0000000000000000000000000000000000000000..f0fbbdf6cda5636f517b6e2fa3f20e15e56e3777 Binary files /dev/null and b/assets/example_image/typical_humanoid_mech.png differ diff --git a/assets/example_image/typical_misc_crate.png b/assets/example_image/typical_misc_crate.png new file mode 100644 index 0000000000000000000000000000000000000000..c3086f885bf9fc27c398b5bacfb04a65bd7dfbd9 Binary files /dev/null and b/assets/example_image/typical_misc_crate.png differ diff --git a/assets/example_image/typical_misc_fireplace.png b/assets/example_image/typical_misc_fireplace.png new file mode 100644 index 0000000000000000000000000000000000000000..82d79bc10346604a8b8b9cc8e2c317e8dc6d8c47 Binary files /dev/null and b/assets/example_image/typical_misc_fireplace.png differ diff --git a/assets/example_image/typical_misc_gate.png b/assets/example_image/typical_misc_gate.png new file mode 100644 index 0000000000000000000000000000000000000000..fa77919f9d9faabc26b9287b35c4dd3b4006163e Binary files /dev/null and b/assets/example_image/typical_misc_gate.png differ diff --git a/assets/example_image/typical_misc_lantern.png b/assets/example_image/typical_misc_lantern.png new file mode 100644 index 0000000000000000000000000000000000000000..4c93f5dea2638a5a169dd1557a36f6d34b57144d Binary files /dev/null and b/assets/example_image/typical_misc_lantern.png differ diff --git a/assets/example_image/typical_misc_magicbook.png b/assets/example_image/typical_misc_magicbook.png new file mode 100644 index 0000000000000000000000000000000000000000..7dc521a10fda176694c30170811809050f478a66 Binary files /dev/null and b/assets/example_image/typical_misc_magicbook.png differ diff --git a/assets/example_image/typical_misc_mailbox.png b/assets/example_image/typical_misc_mailbox.png new file mode 100644 index 0000000000000000000000000000000000000000..b6e8bc50cd270bb7462eee2af7a6d5649ef54cf2 Binary files /dev/null and b/assets/example_image/typical_misc_mailbox.png differ diff --git a/assets/example_image/typical_misc_monster_chest.png b/assets/example_image/typical_misc_monster_chest.png new file mode 100644 index 0000000000000000000000000000000000000000..6d544370fa306138e7dbab3e548d6e05b8ef2317 Binary files /dev/null and b/assets/example_image/typical_misc_monster_chest.png differ diff --git a/assets/example_image/typical_misc_paper_machine.png b/assets/example_image/typical_misc_paper_machine.png new file mode 100644 index 0000000000000000000000000000000000000000..a630074dbfe32c53f52f2f27e5b6b3eff8469a9e Binary files /dev/null and b/assets/example_image/typical_misc_paper_machine.png differ diff --git a/assets/example_image/typical_misc_phonograph.png b/assets/example_image/typical_misc_phonograph.png new file mode 100644 index 0000000000000000000000000000000000000000..668662d741344ac16427259fc966186ef8ca97a9 Binary files /dev/null and b/assets/example_image/typical_misc_phonograph.png differ diff --git a/assets/example_image/typical_misc_portal2.png b/assets/example_image/typical_misc_portal2.png new file mode 100644 index 0000000000000000000000000000000000000000..666daa75fbaf7df55585f7143906d158175be6be Binary files /dev/null and b/assets/example_image/typical_misc_portal2.png differ diff --git a/assets/example_image/typical_misc_storage_chest.png b/assets/example_image/typical_misc_storage_chest.png new file mode 100644 index 0000000000000000000000000000000000000000..38f4bd31f8eb62badcc5e1a51d4612e528b4069e Binary files /dev/null and b/assets/example_image/typical_misc_storage_chest.png differ diff --git a/assets/example_image/typical_misc_telephone.png b/assets/example_image/typical_misc_telephone.png new file mode 100644 index 0000000000000000000000000000000000000000..a0a7d65a300d9f1adc55b2fc36951731f5abb355 Binary files /dev/null and b/assets/example_image/typical_misc_telephone.png differ diff --git a/assets/example_image/typical_misc_television.png b/assets/example_image/typical_misc_television.png new file mode 100644 index 0000000000000000000000000000000000000000..1d6b5882b42ce532f6a60080ad55bda7053530c0 Binary files /dev/null and b/assets/example_image/typical_misc_television.png differ diff --git a/assets/example_image/typical_misc_workbench.png b/assets/example_image/typical_misc_workbench.png new file mode 100644 index 0000000000000000000000000000000000000000..88024f960ff56aa619b0c496f85de390076bbf5a Binary files /dev/null and b/assets/example_image/typical_misc_workbench.png differ diff --git a/assets/example_image/typical_vehicle_biplane.png b/assets/example_image/typical_vehicle_biplane.png new file mode 100644 index 0000000000000000000000000000000000000000..7427cad3270d8ed33dad05c7a2ae1b0092b4beb2 Binary files /dev/null and b/assets/example_image/typical_vehicle_biplane.png differ diff --git a/assets/example_image/typical_vehicle_bulldozer.png b/assets/example_image/typical_vehicle_bulldozer.png new file mode 100644 index 0000000000000000000000000000000000000000..17ffe389498d9561ef92766de654ef17b5755f60 Binary files /dev/null and b/assets/example_image/typical_vehicle_bulldozer.png differ diff --git a/assets/example_image/typical_vehicle_cart.png b/assets/example_image/typical_vehicle_cart.png new file mode 100644 index 0000000000000000000000000000000000000000..137bb4887f3879691ffff21227951790eb1840b4 Binary files /dev/null and b/assets/example_image/typical_vehicle_cart.png differ diff --git a/assets/example_image/typical_vehicle_excavator.png b/assets/example_image/typical_vehicle_excavator.png new file mode 100644 index 0000000000000000000000000000000000000000..c434e8b0ab142ecc35caf91d42df5f4541825b8f Binary files /dev/null and b/assets/example_image/typical_vehicle_excavator.png differ diff --git a/assets/example_image/typical_vehicle_helicopter.png b/assets/example_image/typical_vehicle_helicopter.png new file mode 100644 index 0000000000000000000000000000000000000000..39c2497d22ea519cddf576c6504954c338f943e4 Binary files /dev/null and b/assets/example_image/typical_vehicle_helicopter.png differ diff --git a/assets/example_image/typical_vehicle_locomotive.png b/assets/example_image/typical_vehicle_locomotive.png new file mode 100644 index 0000000000000000000000000000000000000000..dac6a2a2de9e8830bac53d3893aa1d3741916b1a Binary files /dev/null and b/assets/example_image/typical_vehicle_locomotive.png differ diff --git a/assets/example_image/typical_vehicle_pirate_ship.png b/assets/example_image/typical_vehicle_pirate_ship.png new file mode 100644 index 0000000000000000000000000000000000000000..9eed1529f309c64fc6237caba97631cc1f2bab53 Binary files /dev/null and b/assets/example_image/typical_vehicle_pirate_ship.png differ diff --git a/assets/example_image/weatherworn_misc_paper_machine3.png b/assets/example_image/weatherworn_misc_paper_machine3.png new file mode 100644 index 0000000000000000000000000000000000000000..46e8a9dc123aaf71a41e994a8e50eabb4e53e721 Binary files /dev/null and b/assets/example_image/weatherworn_misc_paper_machine3.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..93a5d99a3436c988b46fe75255da205d66dde789 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 +--find-links https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.4.0_cu121.html + + +torch==2.4.0 +torchvision==0.19.0 +pillow==10.4.0 +imageio==2.36.1 +imageio-ffmpeg==0.5.1 +tqdm==4.67.1 +easydict==1.13 +opencv-python-headless==4.10.0.84 +scipy==1.14.1 +rembg==2.0.60 +onnxruntime==1.20.1 +trimesh==4.5.3 +xatlas==0.0.9 +pyvista==0.44.2 +pymeshfix==0.17.0 +igraph==0.11.8 +git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8 +xformers==0.0.27.post2+cu118 +flash-attn==2.7.0.post2 +kaolin==0.17.0 +spconv-cu118==2.3.6 +transformers==4.46.3 +wheels/nvdiffrast-0.3.3-py3-none-any.whl +wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl \ No newline at end of file diff --git a/trellis/__init__.py b/trellis/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..20d240afc9c26a21aee76954628b3d4ef9a1ccbd --- /dev/null +++ b/trellis/__init__.py @@ -0,0 +1,6 @@ +from . import models +from . import modules +from . import pipelines +from . import renderers +from . import representations +from . import utils diff --git a/trellis/models/__init__.py b/trellis/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d90e9f9ab48e7028a370a0df663182f4b8ccadc5 --- /dev/null +++ b/trellis/models/__init__.py @@ -0,0 +1,70 @@ +import importlib + +__attributes = { + 'SparseStructureEncoder': 'sparse_structure_vae', + 'SparseStructureDecoder': 'sparse_structure_vae', + 'SparseStructureFlowModel': 'sparse_structure_flow', + 'SLatEncoder': 'structured_latent_vae', + 'SLatGaussianDecoder': 'structured_latent_vae', + 'SLatRadianceFieldDecoder': 'structured_latent_vae', + 'SLatMeshDecoder': 'structured_latent_vae', + 'SLatFlowModel': 'structured_latent_flow', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +def from_pretrained(path: str, **kwargs): + """ + Load a model from a pretrained checkpoint. + + Args: + path: The path to the checkpoint. Can be either local path or a Hugging Face model name. + NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. + **kwargs: Additional arguments for the model constructor. + """ + import os + import json + from safetensors.torch import load_file + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + path_parts = path.split('/') + repo_id = f'{path_parts[0]}/{path_parts[1]}' + model_name = '/'.join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, 'r') as f: + config = json.load(f) + model = __getattr__(config['name'])(**config['args'], **kwargs) + model.load_state_dict(load_file(model_file)) + + return model + + +# For Pylance +if __name__ == '__main__': + from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder + from .sparse_structure_flow import SparseStructureFlowModel + from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatRadianceFieldDecoder, SLatMeshDecoder + from .structured_latent_flow import SLatFlowModel diff --git a/trellis/models/sparse_structure_flow.py b/trellis/models/sparse_structure_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..aee71a9686fd3795960cf1df970e9b8db0ebd57a --- /dev/null +++ b/trellis/models/sparse_structure_flow.py @@ -0,0 +1,200 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to_f16, convert_module_to_f32 +from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock +from ..modules.spatial import patchify, unpatchify + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + self.register_buffer("pos_emb", pos_emb) + + self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = patchify(x, self.patch_size) + h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + h = h.type(self.dtype) + cond = cond.type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + h = h.type(x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3) + h = unpatchify(h, self.patch_size).contiguous() + + return h diff --git a/trellis/models/sparse_structure_vae.py b/trellis/models/sparse_structure_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e09136cf294c4c1b47b0f09fa6ee57bad2166d --- /dev/null +++ b/trellis/models/sparse_structure_vae.py @@ -0,0 +1,306 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..modules.norm import GroupNorm32, ChannelLayerNorm32 +from ..modules.spatial import pixel_shuffle_3d +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. + """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + DownsampleBlock3d(ch, channels[i+1]) + ) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ]) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor: + h = self.input_layer(x) + h = h.type(self.dtype) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + + mean, logvar = h.chunk(2, dim=1) + + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + + if return_raw: + return z, mean, logvar + return z + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h diff --git a/trellis/models/structured_latent_flow.py b/trellis/models/structured_latent_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..f1463d79bc472ce3ef6859a42e10a06de1f9ebf7 --- /dev/null +++ b/trellis/models/structured_latent_flow.py @@ -0,0 +1,262 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules.norm import LayerNorm32 +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + emb_channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + + assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(emb_channels, 2 * self.out_channels, bias=True), + ) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.updown = None + if self.downsample: + self.updown = sp.SparseDownsample(2) + elif self.upsample: + self.updown = sp.SparseUpsample(2) + + def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.updown is not None: + x = self.updown(x) + return x + + def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor: + emb_out = self.emb_layers(emb).type(x.dtype) + scale, shift = torch.chunk(emb_out, 2, dim=1) + + x = self._updown(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + + return h + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + num_io_res_blocks: int = 2, + io_block_channels: List[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + use_skip_connection: bool = True, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.num_io_res_blocks = num_io_res_blocks + self.io_block_channels = io_block_channels + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.use_skip_connection = use_skip_connection + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + + assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2" + assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages" + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0]) + self.input_blocks = nn.ModuleList([]) + for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]): + self.input_blocks.extend([ + SparseResBlock3d( + chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + self.input_blocks.append( + SparseResBlock3d( + chs, + model_channels, + out_channels=next_chs, + downsample=True, + ) + ) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_blocks = nn.ModuleList([]) + for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))): + self.out_blocks.append( + SparseResBlock3d( + prev_chs * 2 if self.use_skip_connection else prev_chs, + model_channels, + out_channels=chs, + upsample=True, + ) + ) + self.out_blocks.extend([ + SparseResBlock3d( + chs * 2 if self.use_skip_connection else chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.blocks.apply(convert_module_to_f16) + self.out_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.blocks.apply(convert_module_to_f32) + self.out_blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor: + h = self.input_layer(x).type(self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + cond = cond.type(self.dtype) + + skips = [] + # pack with input blocks + for block in self.input_blocks: + h = block(h, t_emb) + skips.append(h.feats) + + if self.pe_mode == "ape": + h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + # unpack with output blocks + for block, skip in zip(self.out_blocks, reversed(skips)): + if self.use_skip_connection: + h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb) + else: + h = block(h, t_emb) + + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h.type(x.dtype)) + return h diff --git a/trellis/models/structured_latent_vae/__init__.py b/trellis/models/structured_latent_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..75603bc1d86c3036972c3d740ca7cb93d872f836 --- /dev/null +++ b/trellis/models/structured_latent_vae/__init__.py @@ -0,0 +1,4 @@ +from .encoder import SLatEncoder +from .decoder_gs import SLatGaussianDecoder +from .decoder_rf import SLatRadianceFieldDecoder +from .decoder_mesh import SLatMeshDecoder diff --git a/trellis/models/structured_latent_vae/base.py b/trellis/models/structured_latent_vae/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0bf6a850b1c146e081c32ad92c7c44ead5ef6e --- /dev/null +++ b/trellis/models/structured_latent_vae/base.py @@ -0,0 +1,117 @@ +from typing import * +import torch +import torch.nn as nn +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules.sparse.transformer import SparseTransformerBlock + + +def block_attn_config(self): + """ + Return the attention configuration of the model. + """ + for i in range(self.num_blocks): + if self.attn_mode == "shift_window": + yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_sequence": + yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_order": + yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] + elif self.attn_mode == "full": + yield "full", None, None, None, None + elif self.attn_mode == "swin": + yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None + + +class SparseTransformerBase(nn.Module): + """ + Sparse Transformer without output layers. + Serve as the base class for encoder and decoder. + """ + def __init__( + self, + in_channels: int, + model_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.window_size = window_size + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.attn_mode = attn_mode + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.qk_rms_norm = qk_rms_norm + self.dtype = torch.float16 if use_fp16 else torch.float32 + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + self.blocks = nn.ModuleList([ + SparseTransformerBlock( + model_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + qk_rms_norm=self.qk_rms_norm, + ) + for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self) + ]) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.input_layer(x) + if self.pe_mode == "ape": + h = h + self.pos_embedder(x.coords[:, 1:]) + h = h.type(self.dtype) + for block in self.blocks: + h = block(h) + return h diff --git a/trellis/models/structured_latent_vae/decoder_gs.py b/trellis/models/structured_latent_vae/decoder_gs.py new file mode 100644 index 0000000000000000000000000000000000000000..b893cfcfb2a166c7d57f96086a79317bd91884b9 --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_gs.py @@ -0,0 +1,122 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from ...utils.random_utils import hammersley_sequence +from .base import SparseTransformerBase +from ...representations import Gaussian + + +class SLatGaussianDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + self._build_perturbation() + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _build_perturbation(self) -> None: + perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])] + perturbation = torch.tensor(perturbation).float() * 2 - 1 + perturbation = perturbation / self.rep_config['voxel_size'] + perturbation = torch.atanh(perturbation).to(self.device) + self.register_buffer('offset_perturbation', perturbation) + + def _calc_layout(self) -> None: + self.layout = { + '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4}, + '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']}, + } + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Gaussian( + sh_degree=0, + aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], + mininum_kernel_size = self.rep_config['3d_filter_kernel_size'], + scaling_bias = self.rep_config['scaling_bias'], + opacity_bias = self.rep_config['opacity_bias'], + scaling_activation = self.rep_config['scaling_activation'] + ) + xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + for k, v in self.layout.items(): + if k == '_xyz': + offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']) + offset = offset * self.rep_config['lr'][k] + if self.rep_config['perturb_offset']: + offset = offset + self.offset_perturbation + offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size'] + _xyz = xyz.unsqueeze(1) + offset + setattr(representation, k, _xyz.flatten(0, 1)) + else: + feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1) + feats = feats * self.rep_config['lr'][k] + setattr(representation, k, feats) + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Gaussian]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/decoder_mesh.py b/trellis/models/structured_latent_vae/decoder_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..75c1b1ec7b6fdc28e787be283e55589b36461e50 --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_mesh.py @@ -0,0 +1,167 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import MeshExtractResult +from ...representations.mesh import SparseFeatures2Mesh + + +class SparseSubdivideBlock3d(nn.Module): + """ + A 3D subdivide block that can subdivide the sparse tensor. + + Args: + channels: channels in the inputs and outputs. + out_channels: if specified, the number of output channels. + num_groups: the number of groups for the group norm. + """ + def __init__( + self, + channels: int, + resolution: int, + out_channels: Optional[int] = None, + num_groups: int = 32 + ): + super().__init__() + self.channels = channels + self.resolution = resolution + self.out_resolution = resolution * 2 + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), + sp.SparseSiLU() + ) + + self.sub = sp.SparseSubdivide() + + self.out_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}") + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Args: + x: an [N x C x ...] Tensor of features. + Returns: + an [N x C x ...] Tensor of outputs. + """ + h = self.act_layers(x) + h = self.sub(h) + x = self.sub(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h + + +class SLatMeshDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False)) + self.out_channels = self.mesh_extractor.feats_channels + self.upsample = nn.ModuleList([ + SparseSubdivideBlock3d( + channels=model_channels, + resolution=resolution, + out_channels=model_channels // 4 + ), + SparseSubdivideBlock3d( + channels=model_channels // 4, + resolution=resolution * 2, + out_channels=model_channels // 8 + ) + ]) + self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + super().convert_to_fp16() + self.upsample.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + super().convert_to_fp32() + self.upsample.apply(convert_module_to_f32) + + def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + mesh = self.mesh_extractor(x[i], training=self.training) + ret.append(mesh) + return ret + + def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + h = super().forward(x) + for block in self.upsample: + h = block(h) + h = h.type(x.dtype) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/decoder_rf.py b/trellis/models/structured_latent_vae/decoder_rf.py new file mode 100644 index 0000000000000000000000000000000000000000..968bb30596647224292da0392dfdefeed49d214d --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_rf.py @@ -0,0 +1,104 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import Strivec + + +class SLatRadianceFieldDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _calc_layout(self) -> None: + self.layout = { + 'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']}, + 'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']}, + 'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3}, + } + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Strivec]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Strivec( + sh_degree=0, + resolution=self.resolution, + aabb=[-0.5, -0.5, -0.5, 1, 1, 1], + rank=self.rep_config['rank'], + dim=self.rep_config['dim'], + device='cuda', + ) + representation.density_shift = 0.0 + representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda') + for k, v in self.layout.items(): + setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])) + representation.trivec = representation.trivec + 1 + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Strivec]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/trellis/models/structured_latent_vae/encoder.py b/trellis/models/structured_latent_vae/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8370921d8d61954b43dcf3e251b8d9b315f4f536 --- /dev/null +++ b/trellis/models/structured_latent_vae/encoder.py @@ -0,0 +1,72 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .base import SparseTransformerBase + + +class SLatEncoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False): + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z diff --git a/trellis/modules/attention/__init__.py b/trellis/modules/attention/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..f452320d5dbc4c0aa1664e33f76c56ff4bbe2039 --- /dev/null +++ b/trellis/modules/attention/__init__.py @@ -0,0 +1,36 @@ +from typing import * + +BACKEND = 'flash_attn' +DEBUG = False + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get('ATTN_BACKEND') + env_sttn_debug = os.environ.get('ATTN_DEBUG') + + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']: + BACKEND = env_attn_backend + if env_sttn_debug is not None: + DEBUG = env_sttn_debug == '1' + + print(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal['xformers', 'flash_attn']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + + +from .full_attn import * +from .modules import * diff --git a/trellis/modules/attention/full_attn.py b/trellis/modules/attention/full_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..d9ebf6380a78906d4c6e969c63223fb7b398e5a7 --- /dev/null +++ b/trellis/modules/attention/full_attn.py @@ -0,0 +1,140 @@ +from typing import * +import torch +import math +from . import DEBUG, BACKEND + +if BACKEND == 'xformers': + import xformers.ops as xops +elif BACKEND == 'flash_attn': + import flash_attn +elif BACKEND == 'sdpa': + from torch.nn.functional import scaled_dot_product_attention as sdpa +elif BACKEND == 'naive': + pass +else: + raise ValueError(f"Unknown attention backend: {BACKEND}") + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if BACKEND == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif BACKEND == 'flash_attn': + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif BACKEND == 'sdpa': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {BACKEND}") + + return out diff --git a/trellis/modules/attention/modules.py b/trellis/modules/attention/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..dbe6235c27134f0477e48d3e12de3068c6a500ef --- /dev/null +++ b/trellis/modules/attention/modules.py @@ -0,0 +1,146 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class RotaryPositionEmbedder(nn.Module): + def __init__(self, hidden_size: int, in_channels: int = 3): + super().__init__() + assert hidden_size % 2 == 0, "Hidden size must be divisible by 2" + self.hidden_size = hidden_size + self.in_channels = in_channels + self.freq_dim = hidden_size // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (sp.SparseTensor): [..., N, D] tensor of queries + k (sp.SparseTensor): [..., N, D] tensor of keys + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + if indices is None: + indices = torch.arange(q.shape[-2], device=q.device) + if len(q.shape) > 2: + indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) + + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[1] < self.hidden_size // 2: + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device), + torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device) + )], dim=-1) + q_embed = self._rotary_embedding(q, phases) + k_embed = self._rotary_embedding(k, phases) + return q_embed, k_embed + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + if self.use_rope: + q, k, v = qkv.unbind(dim=2) + q, k = self.rope(q, k, indices) + qkv = torch.stack([q, k, v], dim=2) + if self.attn_mode == "full": + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=2) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h diff --git a/trellis/modules/norm.py b/trellis/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..09035726081fb7afda2c62504d5474cfa483c58f --- /dev/null +++ b/trellis/modules/norm.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/trellis/modules/sparse/__init__.py b/trellis/modules/sparse/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..5b04954c494a12c9534e36724d0113cb67916788 --- /dev/null +++ b/trellis/modules/sparse/__init__.py @@ -0,0 +1,100 @@ +from typing import * + +BACKEND = 'spconv' +DEBUG = False +ATTN = 'flash_attn' + +def __from_env(): + import os + + global BACKEND + global DEBUG + global ATTN + + env_sparse_backend = os.environ.get('SPARSE_BACKEND') + env_sparse_debug = os.environ.get('SPARSE_DEBUG') + env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND') + if env_sparse_attn is None: + env_sparse_attn = os.environ.get('ATTN_BACKEND') + + if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']: + BACKEND = env_sparse_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == '1' + if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']: + ATTN = env_sparse_attn + + +__from_env() + + +def set_backend(backend: Literal['spconv', 'torchsparse']): + global BACKEND + BACKEND = backend + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + +def set_attn(attn: Literal['xformers', 'flash_attn']): + global ATTN + ATTN = attn + + +import importlib + +__attributes = { + 'SparseTensor': 'basic', + 'sparse_batch_broadcast': 'basic', + 'sparse_batch_op': 'basic', + 'sparse_cat': 'basic', + 'sparse_unbind': 'basic', + 'SparseGroupNorm': 'norm', + 'SparseLayerNorm': 'norm', + 'SparseGroupNorm32': 'norm', + 'SparseLayerNorm32': 'norm', + 'SparseReLU': 'nonlinearity', + 'SparseSiLU': 'nonlinearity', + 'SparseGELU': 'nonlinearity', + 'SparseActivation': 'nonlinearity', + 'SparseLinear': 'linear', + 'sparse_scaled_dot_product_attention': 'attention', + 'SerializeMode': 'attention', + 'sparse_serialized_scaled_dot_product_self_attention': 'attention', + 'sparse_windowed_scaled_dot_product_self_attention': 'attention', + 'SparseMultiHeadAttention': 'attention', + 'SparseConv3d': 'conv', + 'SparseInverseConv3d': 'conv', + 'SparseDownsample': 'spatial', + 'SparseUpsample': 'spatial', + 'SparseSubdivide' : 'spatial' +} + +__submodules = ['transformer'] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .basic import * + from .norm import * + from .nonlinearity import * + from .linear import * + from .attention import * + from .conv import * + from .spatial import * + import transformer diff --git a/trellis/modules/sparse/attention/__init__.py b/trellis/modules/sparse/attention/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..32b3c2c837c613e41755ac4c85f9ed057a6f5bfb --- /dev/null +++ b/trellis/modules/sparse/attention/__init__.py @@ -0,0 +1,4 @@ +from .full_attn import * +from .serialized_attn import * +from .windowed_attn import * +from .modules import * diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..e9e27aeb98419621f3f9999fd3b11eebf2b90a40 --- /dev/null +++ b/trellis/modules/sparse/attention/full_attn.py @@ -0,0 +1,215 @@ +from typing import * +import torch +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + +@overload +def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, SparseTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, SparseTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if DEBUG: + if s is not None: + for i in range(s.shape[0]): + assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" + if num_all_args in [2, 3]: + assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" + if num_all_args == 3: + assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" + assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" + + if ATTN == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif ATTN == 'flash_attn': + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + else: + raise ValueError(f"Unknown attention module: {ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/trellis/modules/sparse/attention/modules.py b/trellis/modules/sparse/attention/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..5d2fe782b0947700e308e9ec0325e7e91c84e3c2 --- /dev/null +++ b/trellis/modules/sparse/attention/modules.py @@ -0,0 +1,139 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. import SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from ...attention import RotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, SparseTensor): + x = x.replace(F.normalize(x.feats, dim=-1)) + else: + x = F.normalize(x, dim=-1) + return (x * self.gamma * self.scale).to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "serialized", "windowed"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}" + assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" + assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention" + self.channels = channels + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_sequence = shift_sequence + self.shift_window = shift_window + self.serialize_mode = serialize_mode + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + @staticmethod + def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats + + def _rope(self, qkv: SparseTensor) -> SparseTensor: + q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] + q, k = self.rope(q, k, qkv.coords[:, 1:]) + qkv = qkv.replace(torch.stack([q, k, v], dim=1)) + return qkv + + def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.use_rope: + qkv = self._rope(qkv) + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=1) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "serialized": + h = sparse_serialized_scaled_dot_product_self_attention( + qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window + ) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=1) + k = self.k_rms_norm(k) + kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..5950b75b2f5a6d6e79ab6d472b8501aaa5ec4a26 --- /dev/null +++ b/trellis/modules/sparse/attention/serialized_attn.py @@ -0,0 +1,193 @@ +from typing import * +from enum import Enum +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_serialized_scaled_dot_product_self_attention', +] + + +class SerializeMode(Enum): + Z_ORDER = 0 + Z_ORDER_TRANSPOSED = 1 + HILBERT = 2 + HILBERT_TRANSPOSED = 3 + + +SerializeModes = [ + SerializeMode.Z_ORDER, + SerializeMode.Z_ORDER_TRANSPOSED, + SerializeMode.HILBERT, + SerializeMode.HILBERT_TRANSPOSED +] + + +def calc_serialization( + tensor: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (torch.Tensor, torch.Tensor): Forwards and backwards indices. + """ + fwd_indices = [] + bwd_indices = [] + seq_lens = [] + seq_batch_indices = [] + offsets = [0] + + if 'vox2seq' not in globals(): + import vox2seq + + # Serialize the input + serialize_coords = tensor.coords[:, 1:].clone() + serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) + if serialize_mode == SerializeMode.Z_ORDER: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) + elif serialize_mode == SerializeMode.HILBERT: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) + else: + raise ValueError(f"Unknown serialize mode: {serialize_mode}") + + for bi, s in enumerate(tensor.layout): + num_points = s.stop - s.start + num_windows = (num_points + window_size - 1) // window_size + valid_window_size = num_points / num_windows + to_ordered = torch.argsort(code[s.start:s.stop]) + if num_windows == 1: + fwd_indices.append(to_ordered) + bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) + fwd_indices[-1] += s.start + bwd_indices[-1] += offsets[-1] + seq_lens.append(num_points) + seq_batch_indices.append(bi) + offsets.append(offsets[-1] + seq_lens[-1]) + else: + # Partition the input + offset = 0 + mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] + split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] + bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) + for i in range(num_windows): + mid = mids[i] + valid_start = split[i] + valid_end = split[i + 1] + padded_start = math.floor(mid - 0.5 * window_size) + padded_end = padded_start + window_size + fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) + offset += valid_start - padded_start + bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) + offset += padded_end - valid_start + fwd_indices[-1] += s.start + seq_lens.extend([window_size] * num_windows) + seq_batch_indices.extend([bi] * num_windows) + bwd_indices.append(bwd_index + offsets[-1]) + offsets.append(offsets[-1] + num_windows * window_size) + + fwd_indices = torch.cat(fwd_indices) + bwd_indices = torch.cat(bwd_indices) + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_serialized_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply serialized scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..cd642c5252e29a3a5e59fad7ed3880b7b00bcf9a --- /dev/null +++ b/trellis/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,135 @@ +from typing import * +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_windowed_scaled_dot_product_self_attention', +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0 +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (List[int]): Sequence lengths. + (List[int]): Sequence batch indices. + """ + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0] + mask = seq_lens != 0 + seq_lens = seq_lens[mask].tolist() + seq_batch_indices = seq_batch_indices[mask].tolist() + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start:start+seq_lens[i]] + assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \ + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/trellis/modules/sparse/basic.py b/trellis/modules/sparse/basic.py new file mode 100755 index 0000000000000000000000000000000000000000..8837f44052f6d573d09e3bfb897e659e10516bb5 --- /dev/null +++ b/trellis/modules/sparse/basic.py @@ -0,0 +1,459 @@ +from typing import * +import torch +import torch.nn as nn +from . import BACKEND, DEBUG +SparseTensorData = None # Lazy import + + +__all__ = [ + 'SparseTensor', + 'sparse_batch_broadcast', + 'sparse_batch_op', + 'sparse_cat', + 'sparse_unbind', +] + + +class SparseTensor: + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + global SparseTensorData + if SparseTensorData is None: + import importlib + if BACKEND == 'torchsparse': + SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif BACKEND == 'spconv': + SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape, layout = args + (None,) * (4 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + if shape is None: + shape = self.__cal_shape(feats, coords) + if layout is None: + layout = self.__cal_layout(coords, shape[0]) + if BACKEND == 'torchsparse': + self.data = SparseTensorData(feats, coords, **kwargs) + elif BACKEND == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1)[1:] + self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs) + self.data._features = feats + elif method_id == 1: + data, shape, layout = args + (None,) * (3 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + if 'layout' in kwargs: + layout = kwargs['layout'] + del kwargs['layout'] + + self.data = data + if shape is None: + shape = self.__cal_shape(self.feats, self.coords) + if layout is None: + layout = self.__cal_layout(self.coords, shape[0]) + + self._shape = shape + self._layout = layout + self._scale = kwargs.get('scale', (1, 1, 1)) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + if DEBUG: + try: + assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous" + except Exception as e: + print('Debugging information:') + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + @property + def shape(self) -> torch.Size: + return self._shape + + def dim(self) -> int: + return len(self.shape) + + @property + def layout(self) -> List[slice]: + return self._layout + + @property + def feats(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.F + elif BACKEND == 'spconv': + return self.data.features + + @feats.setter + def feats(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.F = value + elif BACKEND == 'spconv': + self.data.features = value + + @property + def coords(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.C + elif BACKEND == 'spconv': + return self.data.indices + + @coords.setter + def coords(self, value: torch.Tensor): + if BACKEND == 'torchsparse': + self.data.C = value + elif BACKEND == 'spconv': + self.data.indices = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @overload + def to(self, dtype: torch.dtype) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + + new_feats = self.feats.to(device=device, dtype=dtype) + new_coords = self.coords.to(device=device) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def dense(self) -> torch.Tensor: + if BACKEND == 'torchsparse': + return self.data.dense() + elif BACKEND == 'spconv': + return self.data.dense() + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + new_shape = [self.shape[0]] + new_shape.extend(feats.shape[1:]) + if BACKEND == 'torchsparse': + new_data = SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif BACKEND == 'spconv': + new_data = SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache) + return new_tensor + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __neg__(self) -> 'SparseTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = sparse_batch_broadcast(self, other) + except: + pass + if isinstance(other, SparseTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + coords = [] + feats = [] + for new_idx, old_idx in enumerate(idx): + coords.append(self.coords[self.layout[old_idx]].clone()) + coords[-1][:, 0] = new_idx + feats.append(self.feats[self.layout[old_idx]]) + coords = torch.cat(coords, dim=0).contiguous() + feats = torch.cat(feats, dim=0).contiguous() + return SparseTensor(feats=feats, coords=coords) + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + +def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + coords, feats = input.coords, input.feats + broadcasted = torch.zeros_like(feats) + for k in range(input.shape[0]): + broadcasted[input.layout[k]] = other[k] + return broadcasted + + +def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + return input.replace(op(input.feats, sparse_batch_broadcast(input, other))) + + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/trellis/modules/sparse/conv/__init__.py b/trellis/modules/sparse/conv/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..8fe2fefb0823b7e489d75d392b2cf6d1a6fa34e7 --- /dev/null +++ b/trellis/modules/sparse/conv/__init__.py @@ -0,0 +1,6 @@ +from .. import BACKEND + +if BACKEND == 'torchsparse': + from .conv_torchsparse import * +elif BACKEND == 'spconv': + from .conv_spconv import * \ No newline at end of file diff --git a/trellis/modules/sparse/conv/conv_spconv.py b/trellis/modules/sparse/conv/conv_spconv.py new file mode 100755 index 0000000000000000000000000000000000000000..6cfa3d33d20fe7d9f980b4278b0f9c5a747bec81 --- /dev/null +++ b/trellis/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +from .. import SparseTensor +from .. import DEBUG + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key) + else: + self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + self.padding = padding + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device)) + sorted_feats = new_data.features[fwd] + sorted_coords = new_data.indices[fwd] + unsorted_data = new_data + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore + + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data) + out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd) + + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'spconv' not in globals(): + import spconv.pytorch as spconv + self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key) + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride) + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data') + bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd') + data = data.replace_feature(x.feats[bwd]) + if DEBUG: + assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed' + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, shape=torch.Size(new_shape), layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git a/trellis/modules/sparse/conv/conv_torchsparse.py b/trellis/modules/sparse/conv/conv_torchsparse.py new file mode 100755 index 0000000000000000000000000000000000000000..1d612582d4b31f90aca3c00b693bbbc2550dc62c --- /dev/null +++ b/trellis/modules/sparse/conv/conv_torchsparse.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn +from .. import SparseTensor + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + +class SparseInverseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None): + super(SparseInverseConv3d, self).__init__() + if 'torchsparse' not in globals(): + import torchsparse + self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True) + + def forward(self, x: SparseTensor) -> SparseTensor: + out = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None) + out._spatial_cache = x._spatial_cache + out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)]) + return out + + + diff --git a/trellis/modules/sparse/linear.py b/trellis/modules/sparse/linear.py new file mode 100755 index 0000000000000000000000000000000000000000..a854e77ce87d1a190b9730d91f363a821ff250bd --- /dev/null +++ b/trellis/modules/sparse/linear.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseLinear' +] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) diff --git a/trellis/modules/sparse/nonlinearity.py b/trellis/modules/sparse/nonlinearity.py new file mode 100755 index 0000000000000000000000000000000000000000..f200098dd82011a3aeee1688b9eb17018fa78295 --- /dev/null +++ b/trellis/modules/sparse/nonlinearity.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' +] + + +class SparseReLU(nn.ReLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseGELU(nn.GELU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(self.activation(input.feats)) + diff --git a/trellis/modules/sparse/norm.py b/trellis/modules/sparse/norm.py new file mode 100755 index 0000000000000000000000000000000000000000..6b38a36682c098210000dc31d68ddc31ccd2929d --- /dev/null +++ b/trellis/modules/sparse/norm.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from . import SparseTensor +from . import DEBUG + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + if DEBUG: + assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) diff --git a/trellis/modules/sparse/spatial.py b/trellis/modules/sparse/spatial.py new file mode 100755 index 0000000000000000000000000000000000000000..ad7121473f335b307e2f7ea5f05c964d3aec0440 --- /dev/null +++ b/trellis/modules/sparse/spatial.py @@ -0,0 +1,110 @@ +from typing import * +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseDownsample', + 'SparseUpsample', + 'SparseSubdivide' +] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as average pooling. + """ + def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): + super(SparseDownsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' + + coord = list(input.coords.unbind(dim=-1)) + for i, f in enumerate(factor): + coord[i+1] = coord[i+1] // f + + MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_feats = torch.scatter_reduce( + torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), + dim=0, + index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), + src=input.feats, + reduce='mean' + ) + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + out = SparseTensor(new_feats, new_coords, input.shape,) + out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + + out.register_spatial_cache(f'upsample_{factor}_coords', input.coords) + out.register_spatial_cache(f'upsample_{factor}_layout', input.layout) + out.register_spatial_cache(f'upsample_{factor}_idx', idx) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): + super(SparseUpsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.' + + new_coords = input.get_spatial_cache(f'upsample_{factor}_coords') + new_layout = input.get_spatial_cache(f'upsample_{factor}_layout') + idx = input.get_spatial_cache(f'upsample_{factor}_idx') + if any([x is None for x in [new_coords, new_layout, idx]]): + raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.') + new_feats = input.feats[idx] + out = SparseTensor(new_feats, new_coords, input.shape, new_layout) + out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + return out + +class SparseSubdivide(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + def __init__(self): + super(SparseSubdivide, self).__init__() + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + # upsample scale=2^DIM + n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) + n_coords = torch.nonzero(n_cube) + n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) + factor = n_coords.shape[0] + assert factor == 2 ** DIM + # print(n_coords.shape) + new_coords = input.coords.clone() + new_coords[:, 1:] *= 2 + new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype) + + new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:]) + out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape) + out._scale = input._scale * 2 + out._spatial_cache = input._spatial_cache + return out + diff --git a/trellis/modules/sparse/transformer/__init__.py b/trellis/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/trellis/modules/sparse/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis/modules/sparse/transformer/blocks.py b/trellis/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..9d037a49bf83e1c2dfb2f8c4b23d2e9d6c51e9f0 --- /dev/null +++ b/trellis/modules/sparse/transformer/blocks.py @@ -0,0 +1,151 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: SparseTensor) -> SparseTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) diff --git a/trellis/modules/sparse/transformer/modulated.py b/trellis/modules/sparse/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..4a8416559f39acbed9e5996e9891c97f95c80c8f --- /dev/null +++ b/trellis/modules/sparse/transformer/modulated.py @@ -0,0 +1,166 @@ +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) diff --git a/trellis/modules/spatial.py b/trellis/modules/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..79e268d36c2ba49b0275744022a1a1e19983dae3 --- /dev/null +++ b/trellis/modules/spatial.py @@ -0,0 +1,48 @@ +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], [])) + x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)])) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:])) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)]) + return x diff --git a/trellis/modules/transformer/__init__.py b/trellis/modules/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08b0d4e5bc24060a2cdc8df75d06dce122972bd --- /dev/null +++ b/trellis/modules/transformer/__init__.py @@ -0,0 +1,2 @@ +from .blocks import * +from .modulated import * \ No newline at end of file diff --git a/trellis/modules/transformer/blocks.py b/trellis/modules/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..c37eb7ed92f4aacfc9e974a63b247589d95977da --- /dev/null +++ b/trellis/modules/transformer/blocks.py @@ -0,0 +1,182 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000 ** self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. + """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + N, D = x.shape + assert D == self.in_channels, "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False) + else: + return self._forward(x) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor): + h = self.norm1(x) + h = self.self_attn(h) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False) + else: + return self._forward(x, context) + \ No newline at end of file diff --git a/trellis/modules/transformer/modulated.py b/trellis/modules/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..d4aeca0689e68f656b08f7aa822b7be839aa727d --- /dev/null +++ b/trellis/modules/transformer/modulated.py @@ -0,0 +1,157 @@ +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False) + else: + return self._forward(x, mod, context) + \ No newline at end of file diff --git a/trellis/modules/utils.py b/trellis/modules/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..f0afb1b6c767aa2ad00bad96649fb30315e696ea --- /dev/null +++ b/trellis/modules/utils.py @@ -0,0 +1,54 @@ +import torch.nn as nn +from ..modules import sparse as sp + +FP16_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, +) + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) diff --git a/trellis/pipelines/__init__.py b/trellis/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9e8548b894aeb3d354c739320ed3288be9c7b0e --- /dev/null +++ b/trellis/pipelines/__init__.py @@ -0,0 +1,24 @@ +from . import samplers +from .trellis_image_to_3d import TrellisImageTo3DPipeline + + +def from_pretrained(path: str): + """ + Load a pipeline from a model folder or a Hugging Face model hub. + + Args: + path: The path to the model. Can be either local path or a Hugging Face model name. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + config = json.load(f) + return globals()[config['name']].from_pretrained(path) diff --git a/trellis/pipelines/base.py b/trellis/pipelines/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9e0df4ec5fb915d57d30189cac854e3f095620 --- /dev/null +++ b/trellis/pipelines/base.py @@ -0,0 +1,66 @@ +from typing import * +import torch +import torch.nn as nn +from .. import models + + +class Pipeline: + """ + A base class for pipelines. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + ): + if models is None: + return + self.models = models + for model in self.models.values(): + model.eval() + + @staticmethod + def from_pretrained(path: str) -> "Pipeline": + """ + Load a pretrained model. + """ + import os + import json + is_local = os.path.exists(f"{path}/pipeline.json") + + if is_local: + config_file = f"{path}/pipeline.json" + else: + from huggingface_hub import hf_hub_download + config_file = hf_hub_download(path, "pipeline.json") + + with open(config_file, 'r') as f: + args = json.load(f)['args'] + + _models = { + k: models.from_pretrained(f"{path}/{v}") + for k, v in args['models'].items() + } + + new_pipeline = Pipeline(_models) + new_pipeline._pretrained_args = args + return new_pipeline + + @property + def device(self) -> torch.device: + for model in self.models.values(): + if hasattr(model, 'device'): + return model.device + for model in self.models.values(): + if hasattr(model, 'parameters'): + return next(model.parameters()).device + raise RuntimeError("No device found.") + + def to(self, device: torch.device) -> None: + for model in self.models.values(): + model.to(device) + + def cuda(self) -> None: + self.to(torch.device("cuda")) + + def cpu(self) -> None: + self.to(torch.device("cpu")) diff --git a/trellis/pipelines/samplers/__init__.py b/trellis/pipelines/samplers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..54d412fc5d8eb662081a92a56ad078243988c2f9 --- /dev/null +++ b/trellis/pipelines/samplers/__init__.py @@ -0,0 +1,2 @@ +from .base import Sampler +from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler \ No newline at end of file diff --git a/trellis/pipelines/samplers/base.py b/trellis/pipelines/samplers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1966ce787009a5ee0c1ed06dce491525ff1dbcbf --- /dev/null +++ b/trellis/pipelines/samplers/base.py @@ -0,0 +1,20 @@ +from typing import * +from abc import ABC, abstractmethod + + +class Sampler(ABC): + """ + A base class for samplers. + """ + + @abstractmethod + def sample( + self, + model, + **kwargs + ): + """ + Sample from a model. + """ + pass + \ No newline at end of file diff --git a/trellis/pipelines/samplers/classifier_free_guidance_mixin.py b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..5701b25f5d7a2197612eb256f8ee13e8c489da1f --- /dev/null +++ b/trellis/pipelines/samplers/classifier_free_guidance_mixin.py @@ -0,0 +1,12 @@ +from typing import * + + +class ClassifierFreeGuidanceSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs): + pred = super()._inference_model(model, x_t, t, cond, **kwargs) + neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred diff --git a/trellis/pipelines/samplers/flow_euler.py b/trellis/pipelines/samplers/flow_euler.py new file mode 100644 index 0000000000000000000000000000000000000000..d79124cf1b07515e8f0b88684e271028b1e3a71d --- /dev/null +++ b/trellis/pipelines/samplers/flow_euler.py @@ -0,0 +1,199 @@ +from typing import * +import torch +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from .base import Sampler +from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin +from .guidance_interval_mixin import GuidanceIntervalSamplerMixin + + +class FlowEulerSampler(Sampler): + """ + Generate samples from a flow-matching model using Euler sampling. + + Args: + sigma_min: The minimum scale of noise in flow. + """ + def __init__( + self, + sigma_min: float, + ): + self.sigma_min = sigma_min + + def _eps_to_xstart(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) + + def _xstart_to_eps(self, x_t, t, x_0): + assert x_t.shape == x_0.shape + return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _v_to_xstart_eps(self, x_t, t, v): + assert x_t.shape == v.shape + eps = (1 - t) * v + x_t + x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + return x_0, eps + + def _inference_model(self, model, x_t, t, cond=None, **kwargs): + t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + return model(x_t, t, cond, **kwargs) + + def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): + pred_v = self._inference_model(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) + return pred_x_0, pred_eps, pred_v + + @torch.no_grad() + def sample_once( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) + + @torch.no_grad() + def sample( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + +class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs) + + +class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) diff --git a/trellis/pipelines/samplers/guidance_interval_mixin.py b/trellis/pipelines/samplers/guidance_interval_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..7074a4d5fea20a8f799416aa6571faca4f9eea06 --- /dev/null +++ b/trellis/pipelines/samplers/guidance_interval_mixin.py @@ -0,0 +1,15 @@ +from typing import * + + +class GuidanceIntervalSamplerMixin: + """ + A mixin class for samplers that apply classifier-free guidance with interval. + """ + + def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs): + if cfg_interval[0] <= t <= cfg_interval[1]: + pred = super()._inference_model(model, x_t, t, cond, **kwargs) + neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs) + return (1 + cfg_strength) * pred - cfg_strength * neg_pred + else: + return super()._inference_model(model, x_t, t, cond, **kwargs) diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..59d5894d870b8c03e014c0c29bb47873c11c8c84 --- /dev/null +++ b/trellis/pipelines/trellis_image_to_3d.py @@ -0,0 +1,281 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from torchvision import transforms +from PIL import Image +import rembg +from .base import Pipeline +from . import samplers +from ..modules import sparse as sp +from ..representations import Gaussian, Strivec, MeshExtractResult + + +class TrellisImageTo3DPipeline(Pipeline): + """ + Pipeline for inferring Trellis image-to-3D models. + + Args: + models (dict[str, nn.Module]): The models to use in the pipeline. + sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure. + slat_sampler (samplers.Sampler): The sampler for the structured latent. + slat_normalization (dict): The normalization parameters for the structured latent. + image_cond_model (str): The name of the image conditioning model. + """ + def __init__( + self, + models: dict[str, nn.Module] = None, + sparse_structure_sampler: samplers.Sampler = None, + slat_sampler: samplers.Sampler = None, + slat_normalization: dict = None, + image_cond_model: str = None, + ): + if models is None: + return + super().__init__(models) + self.sparse_structure_sampler = sparse_structure_sampler + self.slat_sampler = slat_sampler + self.sparse_structure_sampler_params = {} + self.slat_sampler_params = {} + self.slat_normalization = slat_normalization + self.rembg_session = None + self._init_image_cond_model(image_cond_model) + + @staticmethod + def from_pretrained(path: str) -> "TrellisImageTo3DPipeline": + """ + Load a pretrained model. + + Args: + path (str): The path to the model. Can be either local path or a Hugging Face repository. + """ + pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path) + new_pipeline = TrellisImageTo3DPipeline() + new_pipeline.__dict__ = pipeline.__dict__ + args = pipeline._pretrained_args + + new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args']) + new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params'] + + new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args']) + new_pipeline.slat_sampler_params = args['slat_sampler']['params'] + + new_pipeline.slat_normalization = args['slat_normalization'] + + new_pipeline._init_image_cond_model(args['image_cond_model']) + + return new_pipeline + + def _init_image_cond_model(self, name: str): + """ + Initialize the image conditioning model. + """ + dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True) + dinov2_model.eval() + self.models['image_cond_model'] = dinov2_model + transform = transforms.Compose([ + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + self.image_cond_model_transform = transform + + def preprocess_image(self, input: Image.Image) -> Image.Image: + """ + Preprocess the input image. + """ + # if has alpha channel, use it directly; otherwise, remove background + has_alpha = False + if input.mode == 'RGBA': + alpha = np.array(input)[:, :, 3] + if not np.all(alpha == 255): + has_alpha = True + if has_alpha: + output = input + else: + input = input.convert('RGB') + max_size = max(input.size) + scale = min(1, 1024 / max_size) + if scale < 1: + input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + if getattr(self, 'rembg_session', None) is None: + self.rembg_session = rembg.new_session('u2net') + output = rembg.remove(input, session=self.rembg_session) + output_np = np.array(output) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1.2) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = output.crop(bbox) # type: ignore + output = output.resize((518, 518), Image.Resampling.LANCZOS) + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + output = Image.fromarray((output * 255).astype(np.uint8)) + return output + + @torch.no_grad() + def encode_image(self, image: Union[torch.Tensor, list[Image.Image]]) -> torch.Tensor: + """ + Encode the image. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image to encode + + Returns: + torch.Tensor: The encoded features. + """ + if isinstance(image, torch.Tensor): + assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)" + elif isinstance(image, list): + assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images" + image = [i.resize((518, 518), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).to(self.device) + else: + raise ValueError(f"Unsupported type of image: {type(image)}") + + image = self.image_cond_model_transform(image).to(self.device) + features = self.models['image_cond_model'](image, is_training=True)['x_prenorm'] + patchtokens = F.layer_norm(features, features.shape[-1:]) + return patchtokens + + def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + cond = self.encode_image(image) + neg_cond = torch.zeros_like(cond) + return { + 'cond': cond, + 'neg_cond': neg_cond, + } + + def sample_sparse_structure( + self, + cond: dict, + num_samples: int = 1, + sampler_params: dict = {}, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample occupancy latent + flow_model = self.models['sparse_structure_flow_model'] + reso = flow_model.resolution + noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + z_s = self.sparse_structure_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + # Decode occupancy latent + decoder = self.models['sparse_structure_decoder'] + coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int() + + return coords + + def decode_slat( + self, + slat: sp.SparseTensor, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + ) -> dict: + """ + Decode the structured latent. + + Args: + slat (sp.SparseTensor): The structured latent. + formats (List[str]): The formats to decode the structured latent to. + + Returns: + dict: The decoded structured latent. + """ + ret = {} + if 'mesh' in formats: + ret['mesh'] = self.models['slat_decoder_mesh'](slat) + if 'gaussian' in formats: + ret['gaussian'] = self.models['slat_decoder_gs'](slat) + if 'radiance_field' in formats: + ret['radiance_field'] = self.models['slat_decoder_rf'](slat) + return ret + + def sample_slat( + self, + cond: dict, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> sp.SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + flow_model = self.models['slat_flow_model'] + noise = sp.SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.slat_sampler_params, **sampler_params} + slat = self.slat_sampler.sample( + flow_model, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + + return slat + + @torch.no_grad() + def __call__( + self, + image: Image.Image, + num_samples: int = 1, + sparse_structure_sampler_params: dict = {}, + slat_sampler_params: dict = {}, + formats: List[str] = ['mesh', 'gaussian', 'radiance_field'], + preprocess_image: bool = True, + ) -> dict: + """ + Run the pipeline. + + Args: + image (Image.Image): The image prompt. + num_samples (int): The number of samples to generate. + sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler. + slat_sampler_params (dict): Additional parameters for the structured latent sampler. + preprocess_image (bool): Whether to preprocess the image. + """ + if preprocess_image: + image = self.preprocess_image(image) + cond = self.get_cond([image]) + coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params) + slat = self.sample_slat(cond, coords, slat_sampler_params) + return self.decode_slat(slat, formats) diff --git a/trellis/renderers/__init__.py b/trellis/renderers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..0339355c56b8d17f72e926650d140a658452fbe9 --- /dev/null +++ b/trellis/renderers/__init__.py @@ -0,0 +1,31 @@ +import importlib + +__attributes = { + 'OctreeRenderer': 'octree_renderer', + 'GaussianRenderer': 'gaussian_render', + 'MeshRenderer': 'mesh_renderer', +} + +__submodules = [] + +__all__ = list(__attributes.keys()) + __submodules + +def __getattr__(name): + if name not in globals(): + if name in __attributes: + module_name = __attributes[name] + module = importlib.import_module(f".{module_name}", __name__) + globals()[name] = getattr(module, name) + elif name in __submodules: + module = importlib.import_module(f".{name}", __name__) + globals()[name] = module + else: + raise AttributeError(f"module {__name__} has no attribute {name}") + return globals()[name] + + +# For Pylance +if __name__ == '__main__': + from .octree_renderer import OctreeRenderer + from .gaussian_render import GaussianRenderer + from .mesh_renderer import MeshRenderer \ No newline at end of file diff --git a/trellis/renderers/gaussian_render.py b/trellis/renderers/gaussian_render.py new file mode 100755 index 0000000000000000000000000000000000000000..57108e3cccf6aab8e3059431557c461de46aff1a --- /dev/null +++ b/trellis/renderers/gaussian_render.py @@ -0,0 +1,231 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from easydict import EasyDict as edict +import numpy as np +from ..representations.gaussian import Gaussian +from .sh_utils import eval_sh +import torch.nn.functional as F +from easydict import EasyDict as edict + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + # lazy import + if 'GaussianRasterizer' not in globals(): + from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except: + pass + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + kernel_size = pipe.kernel_size + subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda") + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + kernel_size=kernel_size, + subpixel_offset=subpixel_offset, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = pc.get_xyz + means2D = screenspace_points + opacity = pc.get_opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + scales = None + rotations = None + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + else: + scales = pc.get_scaling + rotations = pc.get_rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = None + colors_precomp = None + if override_color is None: + if pipe.convert_SHs_python: + shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) + dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) + dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + shs = pc.get_features + else: + colors_precomp = override_color + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp + ) + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return edict({"render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "radii": radii}) + + +class GaussianRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + self.pipe = edict({ + "kernel_size": 0.1, + "convert_SHs_python": False, + "compute_cov3D_python": False, + "scale_modifier": 1.0, + "debug": False + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + gausssian: Gaussian, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None + ) -> edict: + """ + Render the gausssian. + + Args: + gaussian : gaussianmodule + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color image + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict({ + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera + }) + + # Render + render_ret = render(camera_dict, gausssian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier) + + if ssaa > 1: + render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + ret = edict({ + 'color': render_ret['render'] + }) + return ret diff --git a/trellis/renderers/mesh_renderer.py b/trellis/renderers/mesh_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..837094cf8f2125b212d2bdd61a05d99fa39358a1 --- /dev/null +++ b/trellis/renderers/mesh_renderer.py @@ -0,0 +1,144 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +try: + import kaolin as kal + import nvdiffrast.torch as dr +except : + print("Kaolin and nvdiffrast are not installed. Please install them to use the mesh renderer.") +from easydict import EasyDict as edict +from ..representations.mesh import MeshExtractResult +import torch.nn.functional as F + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +class MeshRenderer: + """ + Renderer for the Mesh representation. + + Args: + rendering_options (dict): Rendering options. + glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop. + """ + def __init__(self, rendering_options={}, device='cuda'): + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1 + }) + self.rendering_options.update(rendering_options) + self.glctx = dr.RasterizeCudaContext(device=device) + self.device=device + + def render( + self, + mesh : MeshExtractResult, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_types = ["mask", "normal", "depth"] + ) -> edict: + """ + Render the mesh. + + Args: + mesh : meshmodel + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + return_types (list): list of return types, can be "mask", "depth", "normal_map", "normal", "color" + + Returns: + edict based on return_types containing: + color (torch.Tensor): [3, H, W] rendered color image + depth (torch.Tensor): [H, W] rendered depth image + normal (torch.Tensor): [3, H, W] rendered normal image + normal_map (torch.Tensor): [3, H, W] rendered normal map image + mask (torch.Tensor): [H, W] rendered mask image + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0: + default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device) + ret_dict = {k : default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types} + return ret_dict + + perspective = intrinsics_to_projection(intrinsics, near, far) + + RT = extrinsics.unsqueeze(0) + full_proj = (perspective @ extrinsics).unsqueeze(0) + + vertices = mesh.vertices.unsqueeze(0) + + vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2)) + vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2)) + faces_int = mesh.faces.int() + rast, _ = dr.rasterize( + self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa)) + + out_dict = edict() + for type in return_types: + img = None + if type == "mask" : + img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int) + elif type == "depth": + img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + elif type == "normal" : + img = dr.interpolate( + mesh.face_normal.reshape(1, -1, 3), rast, + torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3) + )[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + # normalize norm pictures + img = (img + 1) / 2 + elif type == "normal_map" : + img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + elif type == "color" : + img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0] + img = dr.antialias(img, rast, vertices_clip, faces_int) + + if ssaa > 1: + img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True) + img = img.squeeze() + else: + img = img.permute(0, 3, 1, 2).squeeze() + out_dict[type] = img + + return out_dict diff --git a/trellis/renderers/octree_renderer.py b/trellis/renderers/octree_renderer.py new file mode 100755 index 0000000000000000000000000000000000000000..136069cdb0645b5759d5d17f7815612a1dfc7bea --- /dev/null +++ b/trellis/renderers/octree_renderer.py @@ -0,0 +1,300 @@ +import numpy as np +import torch +import torch.nn.functional as F +import math +import cv2 +from scipy.stats import qmc +from easydict import EasyDict as edict +from ..representations.octree import DfsOctree + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + # lazy import + if 'OctreeTrivecRasterizer' not in globals(): + from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = edict( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=octree.active_sh_degree, + campos=viewpoint_camera.camera_center, + with_distloss=pipe.with_distloss, + jitter=pipe.jitter, + debug=pipe.debug, + ) + + positions = octree.get_xyz + if octree.primitive == "voxel": + densities = octree.get_density + elif octree.primitive == "gaussian": + opacities = octree.get_opacity + elif octree.primitive == "trivec": + trivecs = octree.get_trivec + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + elif octree.primitive == "decoupoly": + decoupolys_V, decoupolys_g = octree.get_decoupoly + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + else: + raise ValueError(f"Unknown primitive {octree.primitive}") + depths = octree.get_depth + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + colors_precomp = None + shs = octree.get_features + if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None: + colors_precomp = colors_overwrite + shs = None + + ret = edict() + + if octree.primitive == "voxel": + renderer = OctreeVoxelRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, distloss = renderer( + positions = positions, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + ret['distloss'] = distloss + elif octree.primitive == "gaussian": + renderer = OctreeGaussianRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + opacities = opacities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "trivec": + raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1] + renderer = OctreeTrivecRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, percent_depth = renderer( + positions = positions, + trivecs = trivecs, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + colors_overwrite = colors_overwrite, + depths = depths, + aabb = octree.aabb, + aux = aux, + halton_sampler = halton_sampler, + ) + ret['percent_depth'] = percent_depth + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "decoupoly": + raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1] + renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + decoupolys_V = decoupolys_V, + decoupolys_g = decoupolys_g, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + + return ret + + +class OctreeRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + try: + import diffoctreerast + except ImportError: + print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m") + self.unsupported = True + else: + self.unsupported = False + + self.pipe = edict({ + "with_distloss": False, + "with_aux": False, + "scale_modifier": 1.0, + "used_rank": None, + "jitter": False, + "debug": False, + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.halton_sampler = qmc.Halton(2, scramble=False) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + octree: DfsOctree, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + ) -> edict: + """ + Render the octree. + + Args: + octree (Octree): octree + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + distloss (Optional[torch.Tensor]): (H, W) rendered distance loss + percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth + aux (Optional[edict]): auxiliary tensors + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.unsupported: + image = np.zeros((512, 512, 3), dtype=np.uint8) + text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0] + origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2 + image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA) + return { + 'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255, + } + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + + if self.pipe["with_aux"]: + aux = { + 'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + 'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + } + for k in aux.keys(): + aux[k].requires_grad_() + aux[k].retain_grad() + else: + aux = None + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict({ + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera + }) + + # Render + render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler) + + if ssaa > 1: + render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + if hasattr(render_ret, 'percent_depth'): + render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + ret = edict({ + 'color': render_ret.rgb, + 'depth': render_ret.depth, + 'alpha': render_ret.alpha, + }) + if self.pipe["with_distloss"] and 'distloss' in render_ret: + ret['distloss'] = render_ret.distloss + if self.pipe["with_aux"]: + ret['aux'] = aux + if hasattr(render_ret, 'percent_depth'): + ret['percent_depth'] = render_ret.percent_depth + return ret diff --git a/trellis/renderers/sh_utils.py b/trellis/renderers/sh_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785 --- /dev/null +++ b/trellis/renderers/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/trellis/representations/__init__.py b/trellis/representations/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..549ffdb97e87181552e9b3e086766f873e4bfb5e --- /dev/null +++ b/trellis/representations/__init__.py @@ -0,0 +1,4 @@ +from .radiance_field import Strivec +from .octree import DfsOctree as Octree +from .gaussian import Gaussian +from .mesh import MeshExtractResult diff --git a/trellis/representations/gaussian/__init__.py b/trellis/representations/gaussian/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e3de6e180bd732836af876d748255595be2d4d74 --- /dev/null +++ b/trellis/representations/gaussian/__init__.py @@ -0,0 +1 @@ +from .gaussian_model import Gaussian \ No newline at end of file diff --git a/trellis/representations/gaussian/gaussian_model.py b/trellis/representations/gaussian/gaussian_model.py new file mode 100755 index 0000000000000000000000000000000000000000..8716bb5ed0b16b6d545d05b04191deeeee042bae --- /dev/null +++ b/trellis/representations/gaussian/gaussian_model.py @@ -0,0 +1,185 @@ +import torch +import numpy as np +from plyfile import PlyData, PlyElement +from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation + + +class Gaussian: + def __init__( + self, + aabb : list, + sh_degree : int = 0, + mininum_kernel_size : float = 0.0, + scaling_bias : float = 0.01, + opacity_bias : float = 0.1, + scaling_activation : str = "exp", + device='cuda' + ): + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.mininum_kernel_size = mininum_kernel_size + self.scaling_bias = scaling_bias + self.opacity_bias = opacity_bias + self.scaling_activation_type = scaling_activation + self.device = device + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.setup_functions() + + self._xyz = None + self._features_dc = None + self._features_rest = None + self._scaling = None + self._rotation = None + self._opacity = None + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + if self.scaling_activation_type == "exp": + self.scaling_activation = torch.exp + self.inverse_scaling_activation = torch.log + elif self.scaling_activation_type == "softplus": + self.scaling_activation = torch.nn.functional.softplus + self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x)) + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda() + self.rots_bias = torch.zeros((4)).cuda() + self.rots_bias[0] = 1 + self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda() + + @property + def get_scaling(self): + scales = self.scaling_activation(self._scaling + self.scale_bias) + scales = torch.square(scales) + self.mininum_kernel_size ** 2 + scales = torch.sqrt(scales) + return scales + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation + self.rots_bias[None, :]) + + @property + def get_xyz(self): + return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3] + + @property + def get_features(self): + return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity + self.opacity_bias) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :]) + + def from_scaling(self, scales): + scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2) + self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias + + def from_rotation(self, rots): + self._rotation = rots - self.rots_bias[None, :] + + def from_xyz(self, xyz): + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + + def from_features(self, features): + self._features_dc = features + + def from_opacity(self, opacities): + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path): + xyz = self.get_xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy() + scale = torch.log(self.get_scaling).detach().cpu().numpy() + rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def load_ply(self, path): + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + if self.sh_degree > 0: + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + # convert to actual gaussian attributes + xyz = torch.tensor(xyz, dtype=torch.float, device=self.device) + features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() + if self.sh_degree > 0: + features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous() + opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device)) + scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device)) + rots = torch.tensor(rots, dtype=torch.float, device=self.device) + + # convert to _hidden attributes + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + self._features_dc = features_dc + if self.sh_degree > 0: + self._features_rest = features_extra + else: + self._features_rest = None + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias + self._rotation = rots - self.rots_bias[None, :] + \ No newline at end of file diff --git a/trellis/representations/gaussian/general_utils.py b/trellis/representations/gaussian/general_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..541c0825229a2d86e84460b765879f86f724a59d --- /dev/null +++ b/trellis/representations/gaussian/general_utils.py @@ -0,0 +1,133 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + resized_image_PIL = pil_image.resize(resolution) + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/trellis/representations/mesh/__init__.py b/trellis/representations/mesh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38cf35c0853d11cf09bdc228a87ee9d0b2f34b62 --- /dev/null +++ b/trellis/representations/mesh/__init__.py @@ -0,0 +1 @@ +from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult diff --git a/trellis/representations/mesh/cube2mesh.py b/trellis/representations/mesh/cube2mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e32b51adc7755a5d6bdfa38e9f6b898a6aa7f8 --- /dev/null +++ b/trellis/representations/mesh/cube2mesh.py @@ -0,0 +1,153 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +from ...modules.sparse import SparseTensor +from easydict import EasyDict as edict +from .utils_cube import * +try: + from .flexicube import FlexiCubes +except: + print("Please install kaolin and diso to use the mesh extractor.") + + +class MeshExtractResult: + def __init__(self, + vertices, + faces, + vertex_attrs=None, + res=64 + ): + self.vertices = vertices + self.faces = faces.long() + self.vertex_attrs = vertex_attrs + self.face_normal = self.comput_face_normals(vertices, faces) + self.res = res + self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0) + + # training only + self.tsdf_v = None + self.tsdf_s = None + self.reg_loss = None + + def comput_face_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + face_normals = torch.nn.functional.normalize(face_normals, dim=1) + # print(face_normals.min(), face_normals.max(), face_normals.shape) + return face_normals[:, None, :].repeat(1, 3, 1) + + def comput_v_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + v_normals = torch.zeros_like(verts) + v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) + + v_normals = torch.nn.functional.normalize(v_normals, dim=1) + return v_normals + + +class SparseFeatures2Mesh: + def __init__(self, device="cuda", res=64, use_color=True): + ''' + a model to generate a mesh from sparse features structures using flexicube + ''' + super().__init__() + self.device=device + self.res = res + self.mesh_extractor = FlexiCubes(device=device) + self.sdf_bias = -1.0 / res + verts, cube = construct_dense_grid(self.res, self.device) + self.reg_c = cube.to(self.device) + self.reg_v = verts.to(self.device) + self.use_color = use_color + self._calc_layout() + + def _calc_layout(self): + LAYOUTS = { + 'sdf': {'shape': (8, 1), 'size': 8}, + 'deform': {'shape': (8, 3), 'size': 8 * 3}, + 'weights': {'shape': (21,), 'size': 21} + } + if self.use_color: + ''' + 6 channel color including normal map + ''' + LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6} + self.layouts = edict(LAYOUTS) + start = 0 + for k, v in self.layouts.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.feats_channels = start + + def get_layout(self, feats : torch.Tensor, name : str): + if name not in self.layouts: + return None + return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape']) + + def __call__(self, cubefeats : SparseTensor, training=False): + """ + Generates a mesh based on the specified sparse voxel structures. + Args: + cube_attrs [Nx21] : Sparse Tensor attrs about cube weights + verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal + Returns: + return the success tag and ni you loss, + """ + # add sdf bias to verts_attrs + coords = cubefeats.coords[:, 1:] + feats = cubefeats.feats + + sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']] + sdf += self.sdf_bias + v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] + v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training) + v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res+1, sdf_init=True) + weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False) + if self.use_color: + sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:] + else: + sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] + colors_d = None + + x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) + + vertices, faces, L_dev, colors = self.mesh_extractor( + voxelgrid_vertices=x_nx3, + scalar_field=sdf_d, + cube_idx=self.reg_c, + resolution=self.res, + beta=weights_d[:, :12], + alpha=weights_d[:, 12:20], + gamma_f=weights_d[:, 20], + voxelgrid_colors=colors_d, + training=training) + + mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res) + if training: + if mesh.success: + reg_loss += L_dev.mean() * 0.5 + reg_loss += (weights[:,:20]).abs().mean() * 0.2 + mesh.reg_loss = reg_loss + mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) + mesh.tsdf_s = v_attrs[:, 0] + return mesh diff --git a/trellis/representations/mesh/flexicube.py b/trellis/representations/mesh/flexicube.py new file mode 100644 index 0000000000000000000000000000000000000000..63c786b32bf53775a45f2bb4c343b009b81fec9c --- /dev/null +++ b/trellis/representations/mesh/flexicube.py @@ -0,0 +1,384 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from .tables import * +from kaolin.utils.testing import check_tensor + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + def __init__(self, device="cuda"): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + + def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3, + weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False): + assert torch.is_tensor(voxelgrid_vertices) and \ + check_tensor(voxelgrid_vertices, (None, 3), throw=False), \ + "'voxelgrid_vertices' should be a tensor of shape (num_vertices, 3)" + num_vertices = voxelgrid_vertices.shape[0] + assert torch.is_tensor(scalar_field) and \ + check_tensor(scalar_field, (num_vertices,), throw=False), \ + "'scalar_field' should be a tensor of shape (num_vertices,)" + assert torch.is_tensor(cube_idx) and \ + check_tensor(cube_idx, (None, 8), throw=False), \ + "'cube_idx' should be a tensor of shape (num_cubes, 8)" + num_cubes = cube_idx.shape[0] + assert beta is None or ( + torch.is_tensor(beta) and + check_tensor(beta, (num_cubes, 12), throw=False) + ), "'beta' should be a tensor of shape (num_cubes, 12)" + assert alpha is None or ( + torch.is_tensor(alpha) and + check_tensor(alpha, (num_cubes, 8), throw=False) + ), "'alpha' should be a tensor of shape (num_cubes, 8)" + assert gamma_f is None or ( + torch.is_tensor(gamma_f) and + check_tensor(gamma_f, (num_cubes,), throw=False) + ), "'gamma_f' should be a tensor of shape (num_cubes,)" + + surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx) + if surf_cubes.sum() == 0: + return ( + torch.zeros((0, 3), device=self.device), + torch.zeros((0, 3), dtype=torch.long, device=self.device), + torch.zeros((0), device=self.device), + torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None + ) + beta, alpha, gamma_f = self._normalize_weights( + beta, alpha, gamma_f, surf_cubes, weight_scale) + + if voxelgrid_colors is not None: + voxelgrid_colors = torch.sigmoid(voxelgrid_colors) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges( + scalar_field, cube_idx, surf_cubes + ) + + vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd( + voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors) + vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate( + scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, + vd_idx_map, surf_edges_mask, training, vd_color) + return vertices, faces, L_dev, vertices_color + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta is not None: + beta = (torch.tanh(beta) * weight_scale + 1) + else: + beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha is not None: + alpha = (torch.tanh(alpha) * weight_scale + 1) + else: + alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = scalar_field < 0 + all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, scalar_field, cube_idx): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = scalar_field < 0 + occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. + """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)] + , edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + if voxelgrid_colors is not None: + C = voxelgrid_colors.shape[-1] + surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + # if color is not None: + # vd_color = [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + # if color is not None: + # vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3)) + + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + # if color is not None: + # vd_color = torch.cat(vd_color) + # else: + # vd_color = None + + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + + ''' + interpolate colors use the same method as dual vertices + ''' + if voxelgrid_colors is not None: + vd_color = torch.zeros((total_num_vd, C), device=self.device) + c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C) + uc_group = self._linear_interp(s_group * alpha_group, c_group) + vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum + else: + vd_color = None + + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map, vd_color + + def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2] + gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3] + if not training: + mask = (gamma_02 > gamma_13) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2 + vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + + if vd_color is not None: + color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1]) + color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2 + color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2 + color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + vd_color = torch.cat([vd_color, color_center]) + + + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices, vd_color diff --git a/trellis/representations/mesh/tables.py b/trellis/representations/mesh/tables.py new file mode 100644 index 0000000000000000000000000000000000000000..7c02dd7f4133aef487f623c02b11e3075cab0916 --- /dev/null +++ b/trellis/representations/mesh/tables.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], +[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] \ No newline at end of file diff --git a/trellis/representations/mesh/utils_cube.py b/trellis/representations/mesh/utils_cube.py new file mode 100644 index 0000000000000000000000000000000000000000..23913c97bb2d57dfa0384667c69f9860ea0a4155 --- /dev/null +++ b/trellis/representations/mesh/utils_cube.py @@ -0,0 +1,61 @@ +import torch +cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int) +cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]]) +cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False) + +def construct_dense_grid(res, device='cuda'): + '''construct a dense grid based on resolution''' + res_v = res + 1 + vertsid = torch.arange(res_v ** 3, device=device) + coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() + cube_corners_bias = (cube_corners[:, 0] * res_v + cube_corners[:, 1]) * res_v + cube_corners[:, 2] + cube_fx8 = (coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device)) + verts = torch.stack([vertsid // (res_v ** 2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1) + return verts, cube_fx8 + + +def construct_voxel_grid(coords): + verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) + verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) + cubes = inverse_indices.reshape(-1, 8) + return verts_unique, cubes + + +def cubes_to_verts(num_verts, cubes, value, reduce='mean'): + """ + Args: + cubes [Vx8] verts index for each cube + value [Vx8xM] value to be scattered + Operation: + reduced[cubes[i][j]][k] += value[i][k] + """ + M = value.shape[2] # number of channels + reduced = torch.zeros(num_verts, M, device=cubes.device) + return torch.scatter_reduce(reduced, 0, + cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), + value.flatten(0, 1), reduce=reduce, include_self=False) + +def sparse_cube2verts(coords, feats, training=True): + new_coords, cubes = construct_voxel_grid(coords) + new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) + if training: + con_loss = torch.mean((feats - new_feats[cubes]) ** 2) + else: + con_loss = 0.0 + return new_coords, new_feats, con_loss + + +def get_dense_attrs(coords : torch.Tensor, feats : torch.Tensor, res : int, sdf_init=True): + F = feats.shape[-1] + dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device) + if sdf_init: + dense_attrs[..., 0] = 1 # initial outside sdf value + dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats + return dense_attrs.reshape(-1, F) + + +def get_defomed_verts(v_pos : torch.Tensor, deform : torch.Tensor, res): + return v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform) + \ No newline at end of file diff --git a/trellis/representations/octree/__init__.py b/trellis/representations/octree/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..f66a39a5a7498e2e99fe9d94d663796b3bc157b5 --- /dev/null +++ b/trellis/representations/octree/__init__.py @@ -0,0 +1 @@ +from .octree_dfs import DfsOctree \ No newline at end of file diff --git a/trellis/representations/octree/octree_dfs.py b/trellis/representations/octree/octree_dfs.py new file mode 100755 index 0000000000000000000000000000000000000000..9d1f7898f30414f304953cfb2d51d00511ec8325 --- /dev/null +++ b/trellis/representations/octree/octree_dfs.py @@ -0,0 +1,362 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +DEFAULT_TRIVEC_CONFIG = { + 'dim': 8, + 'rank': 8, +} + +DEFAULT_VOXEL_CONFIG = { + 'solid': False, +} + +DEFAULT_DECOPOLY_CONFIG = { + 'degree': 8, + 'rank': 16, +} + + +class DfsOctree: + """ + Sparse Voxel Octree (SVO) implementation for PyTorch. + Using Depth-First Search (DFS) order to store the octree. + DFS order suits rendering and ray tracing. + + The structure and data are separatedly stored. + Structure is stored as a continuous array, each element is a 3*32 bits descriptor. + |-----------------------------------------| + | 0:3 bits | 4:31 bits | + | leaf num | unused | + |-----------------------------------------| + | 0:31 bits | + | child ptr | + |-----------------------------------------| + | 0:31 bits | + | data ptr | + |-----------------------------------------| + Each element represents a non-leaf node in the octree. + The valid mask is used to indicate whether the children are valid. + The leaf mask is used to indicate whether the children are leaf nodes. + The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr. + The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr. + + There are also auxiliary arrays to store the additional structural information to facilitate parallel processing. + - Position: the position of the octree nodes. + - Depth: the depth of the octree nodes. + + Args: + depth (int): the depth of the octree. + """ + + def __init__( + self, + depth, + aabb=[0,0,0,1,1,1], + sh_degree=2, + primitive='voxel', + primitive_config={}, + device='cuda', + ): + self.max_depth = depth + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.device = device + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.primitive = primitive + self.primitive_config = primitive_config + + self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device) + self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device) + self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device) + self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device) + self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device) + self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device) + self.depth[:, 0] = 1 + + self.data = ['position', 'depth'] + self.param_names = [] + + if primitive == 'voxel': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac'] + self.param_names += ['features_dc', 'features_ac'] + if not primitive_config.get('solid', False): + self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data.append('density') + self.param_names.append('density') + elif primitive == 'gaussian': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac', 'opacity'] + self.param_names += ['features_dc', 'features_ac', 'opacity'] + elif primitive == 'trivec': + self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device) + self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['trivec', 'density', 'features_dc', 'features_ac'] + self.param_names += ['trivec', 'density', 'features_dc', 'features_ac'] + elif primitive == 'decoupoly': + self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device) + self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device) + self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + + self.setup_functions() + + def setup_functions(self): + self.density_activation = (lambda x: torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x) + self.opacity_activation = lambda x: torch.sigmoid(x - 6) + self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6 + self.color_activation = lambda x: torch.sigmoid(x) + + @property + def num_non_leaf_nodes(self): + return self.structure.shape[0] + + @property + def num_leaf_nodes(self): + return self.depth.shape[0] + + @property + def cur_depth(self): + return self.depth.max().item() + + @property + def occupancy(self): + return self.num_leaf_nodes / 8 ** self.cur_depth + + @property + def get_xyz(self): + return self.position + + @property + def get_depth(self): + return self.depth + + @property + def get_density(self): + if self.primitive == 'voxel' and self.voxel_config['solid']: + return torch.full((self.position.shape[0], 1), 1000, dtype=torch.float32, device=self.device) + return self.density_activation(self.density) + + @property + def get_opacity(self): + return self.opacity_activation(self.density) + + @property + def get_trivec(self): + return self.trivec + + @property + def get_decoupoly(self): + return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g + + @property + def get_color(self): + return self.color_activation(self.colors) + + @property + def get_features(self): + if self.sh_degree == 0: + return self.features_dc + return torch.cat([self.features_dc, self.features_ac], dim=-2) + + def state_dict(self): + ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'trivec_config': self.trivec_config, 'voxel_config': self.voxel_config, 'primitive': self.primitive} + if hasattr(self, 'density_shift'): + ret['density_shift'] = self.density_shift + for data in set(self.data + self.param_names): + if not isinstance(getattr(self, data), nn.Module): + ret[data] = getattr(self, data) + else: + ret[data] = getattr(self, data).state_dict() + return ret + + def load_state_dict(self, state_dict): + keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth'])) + for key in keys: + if key not in state_dict: + print(f"Warning: key {key} not found in the state_dict.") + continue + try: + if not isinstance(getattr(self, key), nn.Module): + setattr(self, key, state_dict[key]) + else: + getattr(self, key).load_state_dict(state_dict[key]) + except Exception as e: + print(e) + raise ValueError(f"Error loading key {key}.") + + def gather_from_leaf_children(self, data): + """ + Gather the data from the leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + leaf_cnt = self.structure[:, 0] + leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device) + for i in range(8): + if leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[leaf_cnt_masks[i], 2] + for j in range(i+1): + ret[leaf_cnt_masks[i]] += data[start + j] + return ret + + def gather_from_non_leaf_children(self, data): + """ + Gather the data from the non-leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + non_leaf_cnt = 8 - self.structure[:, 0] + non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros_like(data, device=self.device) + for i in range(8): + if non_leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[non_leaf_cnt_masks[i], 1] + for j in range(i+1): + ret[non_leaf_cnt_masks[i]] += data[start + j] + return ret + + def structure_control(self, mask): + """ + Control the structure of the octree. + + Args: + mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep. + """ + # Dont subdivide when the depth is the maximum. + mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0) + # Dont merge when the depth is the minimum. + mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0) + + # Gather control mask + structre_ctrl = self.gather_from_leaf_children(mask) + structre_ctrl[structre_ctrl==-8] = -1 + + new_leaf_num = self.structure[:, 0].clone() + # Modify the leaf num. + structre_valid = structre_ctrl >= 0 + new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes. + structre_delete = structre_ctrl < 0 + merged_nodes = self.gather_from_non_leaf_children(structre_delete.int()) + new_leaf_num += merged_nodes # Delete the merged nodes. + + # Update the structure array to allocate new nodes. + mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes. + mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes. + new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_structure_length = new_structre_idx[-1].item() + new_structre_idx = new_structre_idx[:-1] + new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device) + new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid] + + # Initialize the new nodes. + new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device) + new_node_mask[new_structre_idx[structre_valid]] = False + new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes. + new_node_num = new_node_mask.sum().item() + + # Rebuild child ptr. + non_leaf_cnt = 8 - new_structure[:, 0] + new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 1] = new_child_ptr + 1 + + # Rebuild data ptr with old data. + leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device) + leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0]) + old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + + # Update the data array + subdivide_mask = mask == 1 + merge_mask = mask == -1 + data_valid = ~(subdivide_mask | merge_mask) + mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes + mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes + mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes + mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes + new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_data_length = new_data_idx[-1].item() + new_data_idx = new_data_idx[:-1] + new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data} + for data in self.data: + new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid] + + # Rebuild data ptr + leaf_cnt = new_structure[:, 0] + new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 2] = new_data_ptr + + # Initialize the new data array + ## For subdivide nodes + if subdivide_mask.sum() > 0: + subdivide_data_ptr = new_structure[new_node_mask, 2] + for data in self.data: + for i in range(8): + if data == 'position': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5 + scale = 2 ** (-1.0 - self.depth[subdivide_mask]) + new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale + elif data == 'depth': + new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask]))) + elif data == 'trivec': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5 + coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1) + axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1) + coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1 + new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True) + else: + new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask] + ## For merge nodes + if merge_mask.sum() > 0: + merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device) + merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]]) + for i in range(8): + merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i + old_merge_data_ptr = self.structure[structre_delete, 2] + for data in self.data: + if data == 'position': + scale = 2 ** (1.0 - self.depth[old_merge_data_ptr]) + new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5 + elif data == 'depth': + new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[subdivide_mask])**2) + elif data == 'trivec': + new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr] + else: + new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr] + + # Update the structure and data array + self.structure = new_structure + for data in self.data: + setattr(self, data, new_data[data]) + + # Save data array control temp variables + self.data_rearrange_buffer = { + 'subdivide_mask': subdivide_mask, + 'merge_mask': merge_mask, + 'data_valid': data_valid, + 'new_data_idx': new_data_idx, + 'new_data_length': new_data_length, + 'new_data': new_data + } diff --git a/trellis/representations/radiance_field/__init__.py b/trellis/representations/radiance_field/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b72a1b7e76b509ee5a5e6979858eb17b4158a151 --- /dev/null +++ b/trellis/representations/radiance_field/__init__.py @@ -0,0 +1 @@ +from .strivec import Strivec \ No newline at end of file diff --git a/trellis/representations/radiance_field/strivec.py b/trellis/representations/radiance_field/strivec.py new file mode 100644 index 0000000000000000000000000000000000000000..8fc4b749786d934dae82864b560baccd91fcabbc --- /dev/null +++ b/trellis/representations/radiance_field/strivec.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..octree import DfsOctree as Octree + + +class Strivec(Octree): + def __init__( + self, + resolution: int, + aabb: list, + sh_degree: int = 0, + rank: int = 8, + dim: int = 8, + device: str = "cuda", + ): + assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2" + self.resolution = resolution + depth = int(np.round(np.log2(resolution))) + super().__init__( + depth=depth, + aabb=aabb, + sh_degree=sh_degree, + primitive="trivec", + primitive_config={"rank": rank, "dim": dim}, + device=device, + ) diff --git a/trellis/utils/__init__.py b/trellis/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/trellis/utils/general_utils.py b/trellis/utils/general_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..3b454d9c75521e33466055fe37c3fc1e37180a79 --- /dev/null +++ b/trellis/utils/general_utils.py @@ -0,0 +1,187 @@ +import numpy as np +import cv2 +import torch + + +# Dictionary utils +def _dict_merge(dicta, dictb, prefix=''): + """ + Merge two dictionaries. + """ + assert isinstance(dicta, dict), 'input must be a dictionary' + assert isinstance(dictb, dict), 'input must be a dictionary' + dict_ = {} + all_keys = set(dicta.keys()).union(set(dictb.keys())) + for key in all_keys: + if key in dicta.keys() and key in dictb.keys(): + if isinstance(dicta[key], dict) and isinstance(dictb[key], dict): + dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}') + else: + raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}') + elif key in dicta.keys(): + dict_[key] = dicta[key] + else: + dict_[key] = dictb[key] + return dict_ + + +def dict_merge(dicta, dictb): + """ + Merge two dictionaries. + """ + return _dict_merge(dicta, dictb, prefix='') + + +def dict_foreach(dic, func, special_func={}): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + dic[key] = dict_foreach(dic[key], func) + else: + if key in special_func.keys(): + dic[key] = special_func[key](dic[key]) + else: + dic[key] = func(dic[key]) + return dic + + +def dict_reduce(dicts, func, special_func={}): + """ + Reduce a list of dictionaries. Leaf values must be scalars. + """ + assert isinstance(dicts, list), 'input must be a list of dictionaries' + assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries' + assert len(dicts) > 0, 'input must be a non-empty list of dictionaries' + all_keys = set([key for dict_ in dicts for key in dict_.keys()]) + reduced_dict = {} + for key in all_keys: + vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()] + if isinstance(vlist[0], dict): + reduced_dict[key] = dict_reduce(vlist, func, special_func) + else: + if key in special_func.keys(): + reduced_dict[key] = special_func[key](vlist) + else: + reduced_dict[key] = func(vlist) + return reduced_dict + + +def dict_any(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if dict_any(dic[key], func): + return True + else: + if func(dic[key]): + return True + return False + + +def dict_all(dic, func): + """ + Recursively apply a function to all non-dictionary leaf values in a dictionary. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + for key in dic.keys(): + if isinstance(dic[key], dict): + if not dict_all(dic[key], func): + return False + else: + if not func(dic[key]): + return False + return True + + +def dict_flatten(dic, sep='.'): + """ + Flatten a nested dictionary into a dictionary with no nested dictionaries. + """ + assert isinstance(dic, dict), 'input must be a dictionary' + flat_dict = {} + for key in dic.keys(): + if isinstance(dic[key], dict): + sub_dict = dict_flatten(dic[key], sep=sep) + for sub_key in sub_dict.keys(): + flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key] + else: + flat_dict[key] = dic[key] + return flat_dict + + +def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): + num_images = len(images) + if nrow is None and ncol is None: + if aspect_ratio is not None: + nrow = int(np.round(np.sqrt(num_images / aspect_ratio))) + else: + nrow = int(np.sqrt(num_images)) + ncol = (num_images + nrow - 1) // nrow + elif nrow is None and ncol is not None: + nrow = (num_images + ncol - 1) // ncol + elif nrow is not None and ncol is None: + ncol = (num_images + nrow - 1) // nrow + else: + assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' + + grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype) + for i, img in enumerate(images): + row = i // ncol + col = i % ncol + grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img + return grid + + +def notes_on_image(img, notes=None): + img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + if notes is not None: + img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def save_image_with_notes(img, path, notes=None): + """ + Save an image with notes. + """ + if isinstance(img, torch.Tensor): + img = img.cpu().numpy().transpose(1, 2, 0) + if img.dtype == np.float32 or img.dtype == np.float64: + img = np.clip(img * 255, 0, 255).astype(np.uint8) + img = notes_on_image(img, notes) + cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + + +# debug utils + +def atol(x, y): + """ + Absolute tolerance. + """ + return torch.abs(x - y) + + +def rtol(x, y): + """ + Relative tolerance. + """ + return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12) + + +# print utils +def indent(s, n=4): + """ + Indent a string. + """ + lines = s.split('\n') + for i in range(1, len(lines)): + lines[i] = ' ' * n + lines[i] + return '\n'.join(lines) + diff --git a/trellis/utils/postprocessing_utils.py b/trellis/utils/postprocessing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d0709033d2f5b784767b028b51aa5d4014f9db34 --- /dev/null +++ b/trellis/utils/postprocessing_utils.py @@ -0,0 +1,458 @@ +from typing import * +import numpy as np +import torch +import utils3d +import nvdiffrast.torch as dr +from tqdm import tqdm +import trimesh +import trimesh.visual +import xatlas +import pyvista as pv +from pymeshfix import _meshfix +import igraph +import cv2 +from PIL import Image +from .random_utils import sphere_hammersley_sequence +from .render_utils import render_multiview +from ..representations import Strivec, Gaussian, MeshExtractResult + + +@torch.no_grad() +def _fill_holes( + verts, + faces, + max_hole_size=0.04, + max_hole_nbe=32, + resolution=128, + num_views=500, + debug=False, + verbose=False +): + """ + Rasterize a mesh from multiple views and remove invisible faces. + Also includes postprocessing to: + 1. Remove connected components that are have low visibility. + 2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole. + + Args: + verts (torch.Tensor): Vertices of the mesh. Shape (V, 3). + faces (torch.Tensor): Faces of the mesh. Shape (F, 3). + max_hole_size (float): Maximum area of a hole to fill. + resolution (int): Resolution of the rasterization. + num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + # Construct cameras + yaws = [] + pitchs = [] + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views) + yaws.append(y) + pitchs.append(p) + yaws = torch.tensor(yaws).cuda() + pitchs = torch.tensor(pitchs).cuda() + radius = 2.0 + fov = torch.deg2rad(torch.tensor(40)).cuda() + projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3) + views = [] + for (yaw, pitch) in zip(yaws, pitchs): + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda().float() * radius + view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + views.append(view) + views = torch.stack(views, dim=0) + + # Rasterize + visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device) + rastctx = utils3d.torch.RastContext(backend='cuda') + for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'): + view = views[i] + buffers = utils3d.torch.rasterize_triangle_faces( + rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection + ) + face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1 + face_id = torch.unique(face_id).long() + visblity[face_id] += 1 + visblity = visblity.float() / num_views + + # Mincut + ## construct outer faces + edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces) + boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1) + connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge) + outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device) + for i in range(len(connected_components)): + outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5) + outer_face_indices = outer_face_indices.nonzero().reshape(-1) + + ## construct inner faces + inner_face_indices = torch.nonzero(visblity == 0).reshape(-1) + if verbose: + tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces') + if inner_face_indices.shape[0] == 0: + return verts, faces + + ## Construct dual graph (faces as nodes, edges as edges) + dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge) + dual_edge2edge = edges[dual_edge2edge] + dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1) + if verbose: + tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges') + + ## solve mincut problem + ### construct main graph + g = igraph.Graph() + g.add_vertices(faces.shape[0]) + g.add_edges(dual_edges.cpu().numpy()) + g.es['weight'] = dual_edges_weights.cpu().numpy() + + ### source and target + g.add_vertex('s') + g.add_vertex('t') + + ### connect invisible faces to source + g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### connect outer faces to target + g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()}) + + ### solve mincut + cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist()) + remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device) + if verbose: + tqdm.write(f'Mincut solved, start checking the cut') + + ### check if the cut is valid with each connected component + to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices]) + if debug: + tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}') + valid_remove_cc = [] + cutting_edges = [] + for cc in to_remove_cc: + #### check if the connected component has low visibility + visblity_median = visblity[remove_face_indices[cc]].median() + if debug: + tqdm.write(f'visblity_median: {visblity_median}') + if visblity_median > 0.25: + continue + + #### check if the cuting loop is small enough + cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True) + cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1] + cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)] + if len(cc_new_boundary_edge_indices) > 0: + cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices]) + cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc] + cc_new_boundary_edges_cc_area = [] + for i, edge_cc in enumerate(cc_new_boundary_edge_cc): + _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i] + _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i] + cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5) + if debug: + cutting_edges.append(cc_new_boundary_edge_indices) + tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}') + if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]): + continue + + valid_remove_cc.append(cc) + + if debug: + face_v = verts[faces].mean(dim=1).cpu().numpy() + vis_dual_edges = dual_edges.cpu().numpy() + vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8) + vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255] + vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0] + vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255] + if len(valid_remove_cc) > 0: + vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0] + utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors) + + vis_verts = verts.cpu().numpy() + vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy() + utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges) + + + if len(valid_remove_cc) > 0: + remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)] + mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device) + mask[remove_face_indices] = 0 + faces = faces[mask] + faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts) + if verbose: + tqdm.write(f'Removed {(~mask).sum()} faces by mincut') + else: + if verbose: + tqdm.write(f'Removed 0 faces by mincut') + + mesh = _meshfix.PyTMesh() + mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy()) + mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True) + verts, faces = mesh.return_arrays() + verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32) + + return verts, faces + + +def postprocess_mesh( + vertices: np.array, + faces: np.array, + simplify: bool = True, + simplify_ratio: float = 0.9, + fill_holes: bool = True, + fill_holes_max_hole_size: float = 0.04, + fill_holes_max_hole_nbe: int = 32, + fill_holes_resolution: int = 1024, + fill_holes_num_views: int = 1000, + debug: bool = False, + verbose: bool = False, +): + """ + Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + simplify (bool): Whether to simplify the mesh, using quadric edge collapse. + simplify_ratio (float): Ratio of faces to keep after simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_hole_size (float): Maximum area of a hole to fill. + fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill. + fill_holes_resolution (int): Resolution of the rasterization. + fill_holes_num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + + if verbose: + tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Simplify + if simplify and simplify_ratio > 0: + mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1)) + mesh = mesh.decimate(simplify_ratio, progress_bar=verbose) + vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:] + if verbose: + tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + # Remove invisible faces + if fill_holes: + vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda() + vertices, faces = _fill_holes( + vertices, faces, + max_hole_size=fill_holes_max_hole_size, + max_hole_nbe=fill_holes_max_hole_nbe, + resolution=fill_holes_resolution, + num_views=fill_holes_num_views, + debug=debug, + verbose=verbose, + ) + vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy() + if verbose: + tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces') + + return vertices, faces + + +def parametrize_mesh(vertices: np.array, faces: np.array): + """ + Parametrize a mesh to a texture space, using xatlas. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + """ + + vmapping, indices, uvs = xatlas.parametrize(vertices, faces) + + vertices = vertices[vmapping] + faces = indices + + return vertices, faces, uvs + + +def bake_texture( + vertices: np.array, + faces: np.array, + uvs: np.array, + observations: List[np.array], + masks: List[np.array], + extrinsics: List[np.array], + intrinsics: List[np.array], + texture_size: int = 2048, + near: float = 0.1, + far: float = 10.0, + mode: Literal['fast', 'opt'] = 'opt', + lambda_tv: float = 1e-2, + verbose: bool = False, +): + """ + Bake texture to a mesh from multiple observations. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + uvs (np.array): UV coordinates of the mesh. Shape (V, 2). + observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3). + masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W). + extrinsics (List[np.array]): List of extrinsics. Shape (4, 4). + intrinsics (List[np.array]): List of intrinsics. Shape (3, 3). + texture_size (int): Size of the texture. + near (float): Near plane of the camera. + far (float): Far plane of the camera. + mode (Literal['fast', 'opt']): Mode of texture baking. + lambda_tv (float): Weight of total variation loss in optimization. + verbose (bool): Whether to print progress. + """ + vertices = torch.tensor(vertices).cuda() + faces = torch.tensor(faces.astype(np.int32)).cuda() + uvs = torch.tensor(uvs).cuda() + observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations] + masks = [torch.tensor(m>0).bool().cuda() for m in masks] + views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics] + projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics] + + if mode == 'fast': + texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda() + texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda() + rastctx = utils3d.torch.RastContext(backend='cuda') + for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + ) + uv_map = rast['uv'][0].detach().flip(0) + mask = rast['mask'][0].detach().bool() & masks[0] + + # nearest neighbor interpolation + uv_map = (uv_map * texture_size).floor().long() + obs = observation[mask] + uv_map = uv_map[mask] + idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size + texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs) + texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device)) + + mask = texture_weights > 0 + texture[mask] /= texture_weights[mask][:, None] + texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8) + + # inpaint + mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + + elif mode == 'opt': + rastctx = utils3d.torch.RastContext(backend='cuda') + observations = [observations.flip(0) for observations in observations] + masks = [m.flip(0) for m in masks] + _uv = [] + _uv_dr = [] + for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection + ) + _uv.append(rast['uv'].detach()) + _uv_dr.append(rast['uv_dr'].detach()) + + texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda()) + optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) + + def tv_loss(texture): + return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \ + torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :]) + + total_steps = 2500 + with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar: + for step in range(total_steps): + optimizer.zero_grad() + selected = np.random.randint(0, len(views)) + uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected] + render = dr.texture(texture, uv, uv_dr)[0] + loss = torch.nn.functional.l1_loss(render[mask], observation[mask]) + if lambda_tv > 0: + loss += lambda_tv * tv_loss(texture) + loss.backward() + optimizer.step() + # annealing + optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5) + pbar.set_postfix({'loss': loss.item()}) + pbar.update() + texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8) + mask = 1 - utils3d.torch.rasterize_triangle_faces( + rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size + )['mask'][0].detach().cpu().numpy().astype(np.uint8) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + else: + raise ValueError(f'Unknown mode: {mode}') + + return texture + + +def to_glb( + app_rep: Union[Strivec, Gaussian], + mesh: MeshExtractResult, + simplify: float = 0.95, + fill_holes: bool = True, + fill_holes_max_size: float = 0.04, + texture_size: int = 1024, + debug: bool = False, + verbose: bool = True, +) -> trimesh.Trimesh: + """ + Convert a generated asset to a glb file. + + Args: + app_rep (Union[Strivec, Gaussian]): Appearance representation. + mesh (MeshExtractResult): Extracted mesh. + simplify (float): Ratio of faces to remove in simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_size (float): Maximum area of a hole to fill. + texture_size (int): Size of the texture. + debug (bool): Whether to print debug information. + verbose (bool): Whether to print progress. + """ + vertices = mesh.vertices.cpu().numpy() + faces = mesh.faces.cpu().numpy() + + # mesh postprocess + vertices, faces = postprocess_mesh( + vertices, faces, + simplify=simplify > 0, + simplify_ratio=simplify, + fill_holes=fill_holes, + fill_holes_max_hole_size=fill_holes_max_size, + fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)), + fill_holes_resolution=1024, + fill_holes_num_views=1000, + debug=debug, + verbose=verbose, + ) + + # parametrize mesh + vertices, faces, uvs = parametrize_mesh(vertices, faces) + + # bake texture + observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100) + masks = [np.any(observation > 0, axis=-1) for observation in observations] + extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))] + intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))] + texture = bake_texture( + vertices, faces, uvs, + observations, masks, extrinsics, intrinsics, + texture_size=texture_size, mode='opt', + lambda_tv=0.01, + verbose=True + ) + texture = Image.fromarray(texture) + + # rotate mesh (from z-up to y-up) + vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture)) + return mesh diff --git a/trellis/utils/random_utils.py b/trellis/utils/random_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5b668c277b51f4930991912a80573adc79364028 --- /dev/null +++ b/trellis/utils/random_utils.py @@ -0,0 +1,30 @@ +import numpy as np + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + if remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] \ No newline at end of file diff --git a/trellis/utils/render_utils.py b/trellis/utils/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8187c84f305d51540e88ae5b634a484a74c16e95 --- /dev/null +++ b/trellis/utils/render_utils.py @@ -0,0 +1,116 @@ +import torch +import numpy as np +from tqdm import tqdm +import utils3d +from PIL import Image + +from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer +from ..representations import Octree, Gaussian, MeshExtractResult +from ..modules import sparse as sp +from .random_utils import sphere_hammersley_sequence + + +def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): + is_list = isinstance(yaws, list) + if not is_list: + yaws = [yaws] + pitchs = [pitchs] + if not isinstance(rs, list): + rs = [rs] * len(yaws) + if not isinstance(fovs, list): + fovs = [fovs] * len(yaws) + extrinsics = [] + intrinsics = [] + for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): + fov = torch.deg2rad(torch.tensor(float(fov))).cuda() + yaw = torch.tensor(float(yaw)).cuda() + pitch = torch.tensor(float(pitch)).cuda() + orig = torch.tensor([ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ]).cuda() * r + extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics.append(extr) + intrinsics.append(intr) + if not is_list: + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + return extrinsics, intrinsics + + +def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs): + if isinstance(sample, Octree): + renderer = OctreeRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 0.8) + renderer.rendering_options.far = options.get('far', 1.6) + renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) + renderer.rendering_options.ssaa = options.get('ssaa', 4) + renderer.pipe.primitive = sample.primitive + elif isinstance(sample, Gaussian): + renderer = GaussianRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 0.8) + renderer.rendering_options.far = options.get('far', 1.6) + renderer.rendering_options.bg_color = options.get('bg_color', (0, 0, 0)) + renderer.rendering_options.ssaa = options.get('ssaa', 1) + renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1) + renderer.pipe.use_mip_gaussian = True + elif isinstance(sample, MeshExtractResult): + renderer = MeshRenderer() + renderer.rendering_options.resolution = options.get('resolution', 512) + renderer.rendering_options.near = options.get('near', 1) + renderer.rendering_options.far = options.get('far', 100) + renderer.rendering_options.ssaa = options.get('ssaa', 4) + else: + raise ValueError(f'Unsupported sample type: {type(sample)}') + + rets = {} + for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose): + if not isinstance(sample, MeshExtractResult): + res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite) + if 'color' not in rets: rets['color'] = [] + if 'depth' not in rets: rets['depth'] = [] + rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + if 'percent_depth' in res: + rets['depth'].append(res['percent_depth'].detach().cpu().numpy()) + elif 'depth' in res: + rets['depth'].append(res['depth'].detach().cpu().numpy()) + else: + rets['depth'].append(None) + else: + res = renderer.render(sample, extr, intr) + if 'normal' not in rets: rets['normal'] = [] + rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)) + return rets + + +def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs): + yaws = torch.linspace(0, 2 * 3.1415, num_frames) + pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + yaws = yaws.tolist() + pitch = pitch.tolist() + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) + return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) + + +def render_multiview(sample, resolution=512, nviews=30): + r = 2 + fov = 40 + cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] + yaws = [cam[0] for cam in cams] + pitchs = [cam[1] for cam in cams] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov) + res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)}) + return res['color'], extrinsics, intrinsics + + +def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs): + yaw = [0, np.pi/2, np.pi, 3*np.pi/2] + yaw_offset = offset[0] + yaw = [y + yaw_offset for y in yaw] + pitch = [offset[1] for _ in range(4)] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov) + return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs) diff --git a/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl b/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl new file mode 100644 index 0000000000000000000000000000000000000000..5a7f93e9f56294b4e4079c7c6f32902f3c5890a8 Binary files /dev/null and b/wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl differ diff --git a/wheels/nvdiffrast-0.3.3-py3-none-any.whl b/wheels/nvdiffrast-0.3.3-py3-none-any.whl new file mode 100644 index 0000000000000000000000000000000000000000..cb56d6d359394e2829cd6a433f6f938a49ee733f Binary files /dev/null and b/wheels/nvdiffrast-0.3.3-py3-none-any.whl differ