diff --git a/.gitignore b/.gitignore
index 848f6b8c145be06f13f45f919289145621fa19ac..c7eea8f0d018962db22512844bad06b53db68e3d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,3 +20,5 @@ __pycache__
 .mypy_cache
 .ruff_cache
 venv
+
+shell.nix
diff --git a/CMakeLists.txt b/CMakeLists.txt
deleted file mode 100644
index 27d6ae32293b453758895027ef26fcaff10d592a..0000000000000000000000000000000000000000
--- a/CMakeLists.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-cmake_minimum_required(VERSION 3.16...3.27)
-
-project(PROJ_NAME LANGUAGES CXX)
-
-set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-
-if(NOT DEFINED CMAKE_CXX_STANDARD)
-    set(CMAKE_CXX_STANDARD 17)
-endif()
-
-# Rerun:
-include(FetchContent)
-FetchContent_Declare(rerun_sdk URL https://github.com/rerun-io/rerun/releases/download/0.15.1/rerun_cpp_sdk.zip)
-FetchContent_MakeAvailable(rerun_sdk)
-
-add_executable(PROJ_NAME src/main.cpp)
-target_link_libraries(PROJ_NAME rerun_sdk)
-target_include_directories(PROJ_NAME PRIVATE src)
diff --git a/Cargo.lock b/Cargo.lock
deleted file mode 100644
index e484c406c67c652b6f8cc478aafabf15499a984e..0000000000000000000000000000000000000000
--- a/Cargo.lock
+++ /dev/null
@@ -1,7 +0,0 @@
-# This file is automatically @generated by Cargo.
-# It is not intended for manual editing.
-version = 3
-
-[[package]]
-name = "new_project_name"
-version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
deleted file mode 100644
index 42d394dbbbd663017148e3bfd5893ad2258d802b..0000000000000000000000000000000000000000
--- a/Cargo.toml
+++ /dev/null
@@ -1,198 +0,0 @@
-[package]
-authors = ["rerun.io <opensource@rerun.io>"]
-categories = []                                                      # TODO: fill in if you plan on publishing the crate
-description = ""                                                     # TODO: fill in if you plan on publishing the crate
-edition = "2021"
-homepage = "https://github.com/rerun-io/new_repo_name"
-include = ["LICENSE-APACHE", "LICENSE-MIT", "**/*.rs", "Cargo.toml"]
-keywords = []                                                        # TODO: fill in if you plan on publishing the crate
-license = "MIT OR Apache-2.0"
-name = "new_project_name"
-publish = false                                                      # TODO: set to `true` if you plan on publishing the crate
-readme = "README.md"
-repository = "https://github.com/rerun-io/new_repo_name"
-rust-version = "1.76"
-version = "0.1.0"
-
-[package.metadata.docs.rs]
-all-features = true
-targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
-
-
-[features]
-default = []
-
-
-[dependencies]
-
-
-[dev-dependencies]
-
-
-[patch.crates-io]
-
-
-[lints]
-workspace = true
-
-
-[workspace.lints.rust]
-unsafe_code = "deny"
-
-elided_lifetimes_in_paths = "warn"
-future_incompatible = "warn"
-nonstandard_style = "warn"
-rust_2018_idioms = "warn"
-rust_2021_prelude_collisions = "warn"
-semicolon_in_expressions_from_macros = "warn"
-trivial_numeric_casts = "warn"
-unsafe_op_in_unsafe_fn = "warn"               # `unsafe_op_in_unsafe_fn` may become the default in future Rust versions: https://github.com/rust-lang/rust/issues/71668
-unused_extern_crates = "warn"
-unused_import_braces = "warn"
-unused_lifetimes = "warn"
-
-trivial_casts = "allow"
-unused_qualifications = "allow"
-
-[workspace.lints.rustdoc]
-all = "warn"
-missing_crate_level_docs = "warn"
-
-# See also clippy.toml
-[workspace.lints.clippy]
-as_ptr_cast_mut = "warn"
-await_holding_lock = "warn"
-bool_to_int_with_if = "warn"
-char_lit_as_u8 = "warn"
-checked_conversions = "warn"
-clear_with_drain = "warn"
-cloned_instead_of_copied = "warn"
-dbg_macro = "warn"
-debug_assert_with_mut_call = "warn"
-derive_partial_eq_without_eq = "warn"
-disallowed_macros = "warn"                  # See clippy.toml
-disallowed_methods = "warn"                 # See clippy.toml
-disallowed_names = "warn"                   # See clippy.toml
-disallowed_script_idents = "warn"           # See clippy.toml
-disallowed_types = "warn"                   # See clippy.toml
-doc_link_with_quotes = "warn"
-doc_markdown = "warn"
-empty_enum = "warn"
-enum_glob_use = "warn"
-equatable_if_let = "warn"
-exit = "warn"
-expl_impl_clone_on_copy = "warn"
-explicit_deref_methods = "warn"
-explicit_into_iter_loop = "warn"
-explicit_iter_loop = "warn"
-fallible_impl_from = "warn"
-filter_map_next = "warn"
-flat_map_option = "warn"
-float_cmp_const = "warn"
-fn_params_excessive_bools = "warn"
-fn_to_numeric_cast_any = "warn"
-from_iter_instead_of_collect = "warn"
-get_unwrap = "warn"
-if_let_mutex = "warn"
-implicit_clone = "warn"
-imprecise_flops = "warn"
-index_refutable_slice = "warn"
-inefficient_to_string = "warn"
-infinite_loop = "warn"
-into_iter_without_iter = "warn"
-invalid_upcast_comparisons = "warn"
-iter_not_returning_iterator = "warn"
-iter_on_empty_collections = "warn"
-iter_on_single_items = "warn"
-iter_over_hash_type = "warn"
-iter_without_into_iter = "warn"
-large_digit_groups = "warn"
-large_include_file = "warn"
-large_stack_arrays = "warn"
-large_stack_frames = "warn"
-large_types_passed_by_value = "warn"
-let_underscore_untyped = "warn"
-let_unit_value = "warn"
-linkedlist = "warn"
-lossy_float_literal = "warn"
-macro_use_imports = "warn"
-manual_assert = "warn"
-manual_clamp = "warn"
-manual_instant_elapsed = "warn"
-manual_let_else = "warn"
-manual_ok_or = "warn"
-manual_string_new = "warn"
-map_err_ignore = "warn"
-map_flatten = "warn"
-map_unwrap_or = "warn"
-match_on_vec_items = "warn"
-match_same_arms = "warn"
-match_wild_err_arm = "warn"
-match_wildcard_for_single_variants = "warn"
-mem_forget = "warn"
-mismatched_target_os = "warn"
-mismatching_type_param_order = "warn"
-missing_assert_message = "warn"
-missing_enforced_import_renames = "warn"
-missing_errors_doc = "warn"
-missing_safety_doc = "warn"
-mut_mut = "warn"
-mutex_integer = "warn"
-needless_borrow = "warn"
-needless_continue = "warn"
-needless_for_each = "warn"
-needless_pass_by_ref_mut = "warn"
-needless_pass_by_value = "warn"
-negative_feature_names = "warn"
-nonstandard_macro_braces = "warn"
-option_option = "warn"
-path_buf_push_overwrite = "warn"
-ptr_as_ptr = "warn"
-ptr_cast_constness = "warn"
-pub_without_shorthand = "warn"
-rc_mutex = "warn"
-readonly_write_lock = "warn"
-redundant_type_annotations = "warn"
-ref_option_ref = "warn"
-rest_pat_in_fully_bound_structs = "warn"
-same_functions_in_if_condition = "warn"
-semicolon_if_nothing_returned = "warn"
-should_panic_without_expect = "warn"
-significant_drop_tightening = "warn"
-single_match_else = "warn"
-str_to_string = "warn"
-string_add = "warn"
-string_add_assign = "warn"
-string_lit_as_bytes = "warn"
-string_lit_chars_any = "warn"
-string_to_string = "warn"
-suspicious_command_arg_space = "warn"
-suspicious_xor_used_as_pow = "warn"
-todo = "warn"
-too_many_lines = "warn"
-trailing_empty_array = "warn"
-trait_duplication_in_bounds = "warn"
-tuple_array_conversions = "warn"
-unchecked_duration_subtraction = "warn"
-undocumented_unsafe_blocks = "warn"
-unimplemented = "warn"
-uninhabited_references = "warn"
-uninlined_format_args = "warn"
-unnecessary_box_returns = "warn"
-unnecessary_safety_doc = "warn"
-unnecessary_struct_initialization = "warn"
-unnecessary_wraps = "warn"
-unnested_or_patterns = "warn"
-unused_peekable = "warn"
-unused_rounding = "warn"
-unused_self = "warn"
-unwrap_used = "warn"
-use_self = "warn"
-useless_transmute = "warn"
-verbose_file_reads = "warn"
-wildcard_dependencies = "warn"
-wildcard_imports = "warn"
-zero_sized_map_values = "warn"
-
-manual_range_contains = "allow" # this one is just worse imho
-ref_patterns = "allow"          # It's nice to avoid ref pattern, but there are some situations that are hard (impossible?) to express without.
diff --git a/README.md b/README.md
index e4555599f36a753de7da24c70cb21e4542a6de8d..36cdf72a33f1279dcd2400161c08ab96ad56bdcf 100644
--- a/README.md
+++ b/README.md
@@ -1,40 +1,3 @@
-# Rerun template repository
-Template for our private and public repos, containing CI, CoC, etc
+## Fork of the [InstantMesh space](https://huggingface.co/spaces/TencentARC/InstantMesh) that uses [Rerun](https://www.rerun.io) for visualization
 
-When creating a new Rerun repository, use this as a template, then modify it as it makes sense.
-
-This template should be the default for any repository of any kind, including:
-* Rust projects
-* C++ projects
-* Python projects
-* Other stuff
-
-This template includes
-* License files
-* Code of Conduct
-* Helpers for checking and linting Rust code
-  - `cargo-clippy`
-  - `cargo-deny`
-  - `rust-toolchain`
-  - …
-* CI for:
-  - Spell checking
-  - Link checking
-  - C++ checks
-  - Python checks
-  - Rust checks
-
-
-## How to use
-Start by clicking "Use this template" at https://github.com/rerun-io/rerun_template/ or follow [these instructions](https://docs.github.com/en/free-pro-team@latest/github/creating-cloning-and-archiving-repositories/creating-a-repository-from-a-template).
-
-Then follow these steps:
-* Run `scripts/template_update.py init --languages cpp,rust,python` to delete files you don't need (give the languages you need support for)
-* Search and replace all instances of `new_repo_name` with the name of the repository.
-* Search and replace all instances of `new_project_name` with the name of the project (crate/binary name).
-* Search for `TODO` and fill in all those places
-* Replace this `README.md` with something better
-* Commit!
-
-In the future you can always update this repository with the latest changes from the template by running:
-* `scripts/template_update.py update --languages cpp,rust,python`
+The resulting Hugging Face space can be found [here](https://huggingface.co/spaces/rerun/InstantMesh).
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace82ebe79d72d64878778ceb94e0e279acf2d41
--- /dev/null
+++ b/app.py
@@ -0,0 +1,308 @@
+from __future__ import annotations
+
+import os
+import shutil
+import threading
+from queue import SimpleQueue
+from typing import Any
+
+import gradio as gr
+import numpy as np
+import rembg
+import rerun as rr
+import rerun.blueprint as rrb
+import spaces
+import torch
+from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+from einops import rearrange
+from gradio_rerun import Rerun
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
+from PIL import Image
+from pytorch_lightning import seed_everything
+from torchvision.transforms import v2
+
+from src.models.lrm_mesh import InstantMesh
+from src.utils.camera_util import (
+    FOV_to_intrinsics,
+    get_circular_camera_poses,
+    get_zero123plus_input_cameras,
+)
+from src.utils.infer_util import remove_background, resize_foreground
+from src.utils.train_util import instantiate_from_config
+
+
+def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
+    """Get the rendering camera parameters."""
+    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
+    if is_flexicubes:
+        cameras = torch.linalg.inv(c2ws)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
+    else:
+        extrinsics = c2ws.flatten(-2)
+        intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
+        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+    return cameras
+
+
+###############################################################################
+# Configuration.
+###############################################################################
+
+
+def find_cuda():
+    # Check if CUDA_HOME or CUDA_PATH environment variables are set
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
+
+    if cuda_home and os.path.exists(cuda_home):
+        return cuda_home
+
+    # Search for the nvcc executable in the system's PATH
+    nvcc_path = shutil.which("nvcc")
+
+    if nvcc_path:
+        # Remove the 'bin/nvcc' part to get the CUDA installation path
+        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
+        return cuda_path
+
+    return None
+
+
+cuda_path = find_cuda()
+
+if cuda_path:
+    print(f"CUDA installation found at: {cuda_path}")
+else:
+    print("CUDA installation not found")
+
+config_path = "configs/instant-mesh-large.yaml"
+config = OmegaConf.load(config_path)
+config_name = os.path.basename(config_path).replace(".yaml", "")
+model_config = config.model_config
+infer_config = config.infer_config
+
+IS_FLEXICUBES = config_name.startswith("instant-mesh")
+
+device = torch.device("cuda")
+
+# load diffusion model
+print("Loading diffusion model ...")
+pipeline = DiffusionPipeline.from_pretrained(
+    "sudo-ai/zero123plus-v1.2",
+    custom_pipeline="zero123plus",
+    torch_dtype=torch.float16,
+)
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
+
+# load custom white-background UNet
+unet_ckpt_path = hf_hub_download(
+    repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model"
+)
+state_dict = torch.load(unet_ckpt_path, map_location="cpu")
+pipeline.unet.load_state_dict(state_dict, strict=True)
+
+pipeline = pipeline.to(device)
+print(f"type(pipeline)={type(pipeline)}")
+
+# load reconstruction model
+print("Loading reconstruction model ...")
+model_ckpt_path = hf_hub_download(
+    repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model"
+)
+model: InstantMesh = instantiate_from_config(model_config)
+state_dict = torch.load(model_ckpt_path, map_location="cpu")["state_dict"]
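+# keep only the LRM generator weights: strip the "lrm_generator." prefix (14 chars) and skip "source_camera" entries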
+state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith("lrm_generator.") and "source_camera" not in k}
+model.load_state_dict(state_dict, strict=True)
+
+model = model.to(device)
+
+print("Loading Finished!")
+
+
+def check_input_image(input_image):
+    if input_image is None:
+        raise gr.Error("No image uploaded!")
+
+
+def preprocess(input_image, do_remove_background):
+    rembg_session = rembg.new_session() if do_remove_background else None
+
+    if do_remove_background:
+        input_image = remove_background(input_image, rembg_session)
+        input_image = resize_foreground(input_image, 0.85)
+
+    return input_image
+
+
+def pipeline_callback(
+    log_queue: SimpleQueue, pipe: Any, step_index: int, timestep: float, callback_kwargs: dict[str, Any]
+) -> dict[str, Any]:
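+    # Decode the intermediate latents to an RGB preview and queue it (together with the
+    # raw latents) so every denoising step can be streamed to the Rerun viewer.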
+    latents = callback_kwargs["latents"]
+    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]  # type: ignore[attr-defined]
+    image = pipe.image_processor.postprocess(image, output_type="np").squeeze()  # type: ignore[attr-defined]
+
+    log_queue.put(("mvs", rr.Image(image)))
+    log_queue.put(("latents", rr.Tensor(latents.squeeze())))
+
+    return callback_kwargs
+
+
+def generate_mvs(log_queue, input_image, sample_steps, sample_seed):
+    seed_everything(sample_seed)
+
+    return pipeline(
+        input_image,
+        num_inference_steps=sample_steps,
+        callback_on_step_end=lambda *args, **kwargs: pipeline_callback(log_queue, *args, **kwargs),
+    ).images[0]
+
+
+def make3d(log_queue, images: Image.Image):
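+    # Reconstruct a mesh from the 3x2 grid of generated views with the InstantMesh LRM
+    # and log the result (vertex positions, vertex colors, faces) to Rerun via the queue.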
+    global model
+    if IS_FLEXICUBES:
+        model.init_flexicubes_geometry(device, use_renderer=False)
+    model = model.eval()
+
+    images = np.asarray(images, dtype=np.float32) / 255.0
+    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()  # (3, 960, 640)
+    images = rearrange(images, "c (n h) (m w) -> (n m) c h w", n=3, m=2)  # (6, 3, 320, 320)
+
+    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
+
+    images = images.unsqueeze(0).to(device)
+    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+
+    with torch.no_grad():
+        # get triplane
+        planes = model.forward_planes(images, input_cameras)
+
+        # get mesh
+        mesh_out = model.extract_mesh(
+            planes,
+            use_texture_map=False,
+            **infer_config,
+        )
+
+        vertices, faces, vertex_colors = mesh_out
+
+        log_queue.put((
+            "mesh",
+            rr.Mesh3D(vertex_positions=vertices, vertex_colors=vertex_colors, triangle_indices=faces),
+        ))
+
+    return mesh_out
+
+
+def generate_blueprint() -> rrb.Blueprint:
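+    # Viewer layout: a 3D view for the reconstructed mesh next to a grid of 2D views for
+    # the intermediate images and a tensor view for the diffusion latents.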
+    return rrb.Blueprint(
+        rrb.Horizontal(
+            rrb.Spatial3DView(origin="mesh"),
+            rrb.Grid(
+                rrb.Spatial2DView(origin="z123image"),
+                rrb.Spatial2DView(origin="preprocessed_image"),
+                rrb.Spatial2DView(origin="mvs"),
+                rrb.TensorView(
+                    origin="latents",
+                ),
+            ),
+            column_shares=[1, 1],
+        ),
+        collapse_panels=True,
+    )
+
+
+def compute(log_queue, input_image, do_remove_background, sample_steps, sample_seed):
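+    # Full pipeline: preprocess -> multi-view diffusion -> mesh reconstruction,
+    # pushing every intermediate result onto the queue for streaming.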
+    preprocessed_image = preprocess(input_image, do_remove_background)
+    log_queue.put(("preprocessed_image", rr.Image(preprocessed_image)))
+
+    z123_image = generate_mvs(log_queue, preprocessed_image, sample_steps, sample_seed)
+    log_queue.put(("z123image", rr.Image(z123_image)))
+
+    _mesh_out = make3d(log_queue, z123_image)
+
+    log_queue.put("done")
+
+
+@spaces.GPU
+@rr.thread_local_stream("InstantMesh")
+def log_to_rr(input_image, do_remove_background, sample_steps, sample_seed):
+    log_queue = SimpleQueue()
+
+    stream = rr.binary_stream()
+
+    blueprint = generate_blueprint()
+    rr.send_blueprint(blueprint)
+    yield stream.read()
+
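+    # Run the pipeline in a worker thread; it pushes (entity_path, entity) tuples onto
+    # the queue, which this generator drains and forwards to the Rerun binary stream.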
+    handle = threading.Thread(
+        target=compute, args=[log_queue, input_image, do_remove_background, sample_steps, sample_seed]
+    )
+    handle.start()
+    while True:
+        msg = log_queue.get()
+        if msg == "done":
+            break
+        else:
+            entity_path, entity = msg
+            rr.log(entity_path, entity)
+            yield stream.read()
+    handle.join()
+
+
+_HEADER_ = """
+<h2><b>Duplicate of the <a href='https://huggingface.co/spaces/TencentARC/InstantMesh'>InstantMesh space</a> that uses <a href='https://rerun.io/'>Rerun</a> for visualization.</b></h2>
+<h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b></a></h2>
+
+**InstantMesh** is a feed-forward framework for efficient 3D mesh generation from a single image based on the LRM/Instant3D architecture.
+
+Technical report: <a href='https://arxiv.org/abs/2404.07191' target='_blank'>arXiv</a>.
+Source code: <a href='https://github.com/rerun-io/hf-example-instant-mesh'>GitHub</a>.
+"""
+
+with gr.Blocks() as demo:
+    gr.Markdown(_HEADER_)
+    with gr.Row(variant="panel"):
+        with gr.Column(scale=1):
+            with gr.Row():
+                input_image = gr.Image(
+                    label="Input Image",
+                    image_mode="RGBA",
+                    sources="upload",
+                    # width=256,
+                    # height=256,
+                    type="pil",
+                    elem_id="content_image",
+                )
+            with gr.Row():
+                with gr.Group():
+                    do_remove_background = gr.Checkbox(label="Remove Background", value=True)
+                    sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
+
+                    sample_steps = gr.Slider(label="Sample Steps", minimum=30, maximum=75, value=75, step=5)
+
+            with gr.Row():
+                submit = gr.Button("Generate", elem_id="generate", variant="primary")
+
+            with gr.Row(variant="panel"):
+                gr.Examples(
+                    examples=[os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))],
+                    inputs=[input_image],
+                    label="Examples",
+                    cache_examples=False,
+                    examples_per_page=16,
+                )
+
+        with gr.Column(scale=2):
+            viewer = Rerun(streaming=True, height=800)
+
+            with gr.Row():
+                gr.Markdown("""Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).""")
+
+    mv_images = gr.State()
+
+    submit.click(fn=check_input_image, inputs=[input_image]).success(
+        fn=log_to_rr, inputs=[input_image, do_remove_background, sample_steps, sample_seed], outputs=[viewer]
+    )
+
+demo.launch()
diff --git a/configs/instant-mesh-base.yaml b/configs/instant-mesh-base.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..ad4f4c0cd0d3c6f4d3038b657a41dab82c048dd1
--- /dev/null
+++ b/configs/instant-mesh-base.yaml
@@ -0,0 +1,22 @@
+model_config:
+  target: src.models.lrm_mesh.InstantMesh
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 12
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 40
+    rendering_samples_per_ray: 96
+    grid_res: 128
+    grid_scale: 2.1
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_mesh_base.ckpt
+  texture_resolution: 1024
+  render_resolution: 512
\ No newline at end of file
diff --git a/configs/instant-mesh-large.yaml b/configs/instant-mesh-large.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..e296bc89f6d0d0649136ba2ce0e34490f76a5e41
--- /dev/null
+++ b/configs/instant-mesh-large.yaml
@@ -0,0 +1,22 @@
+model_config:
+  target: src.models.lrm_mesh.InstantMesh
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 16
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 80
+    rendering_samples_per_ray: 128
+    grid_res: 128
+    grid_scale: 2.1
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_mesh_large.ckpt
+  texture_resolution: 1024
+  render_resolution: 512
\ No newline at end of file
diff --git a/configs/instant-nerf-base.yaml b/configs/instant-nerf-base.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..ded3d484751127d430891fc28eb2de664aecd5e1
--- /dev/null
+++ b/configs/instant-nerf-base.yaml
@@ -0,0 +1,21 @@
+model_config:
+  target: src.models.lrm.InstantNeRF
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 12
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 40
+    rendering_samples_per_ray: 96
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_nerf_base.ckpt
+  mesh_threshold: 10.0
+  mesh_resolution: 256
+  render_resolution: 384
\ No newline at end of file
diff --git a/configs/instant-nerf-large.yaml b/configs/instant-nerf-large.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..57494b69d74ee78dca2e2cead2ef68ddfd0fd531
--- /dev/null
+++ b/configs/instant-nerf-large.yaml
@@ -0,0 +1,21 @@
+model_config:
+  target: src.models.lrm.InstantNeRF
+  params:
+    encoder_feat_dim: 768
+    encoder_freeze: false
+    encoder_model_name: facebook/dino-vitb16
+    transformer_dim: 1024
+    transformer_layers: 16
+    transformer_heads: 16
+    triplane_low_res: 32
+    triplane_high_res: 64
+    triplane_dim: 80
+    rendering_samples_per_ray: 128
+
+
+infer_config:
+  unet_path: ckpts/diffusion_pytorch_model.bin
+  model_path: ckpts/instant_nerf_large.ckpt
+  mesh_threshold: 10.0
+  mesh_resolution: 256
+  render_resolution: 384
\ No newline at end of file
diff --git a/examples/bird.jpg b/examples/bird.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ac70a36ebefb87fb283f3bb95d07fe71700702a3
Binary files /dev/null and b/examples/bird.jpg differ
diff --git a/examples/bubble_mart_blue.png b/examples/bubble_mart_blue.png
new file mode 100644
index 0000000000000000000000000000000000000000..af870322d4a8a2f237546fbea9560bb8e5f50364
Binary files /dev/null and b/examples/bubble_mart_blue.png differ
diff --git a/examples/cake.jpg b/examples/cake.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8dbebb6901e1230405be3451c0165e80458d5542
Binary files /dev/null and b/examples/cake.jpg differ
diff --git a/examples/cartoon_dinosaur.png b/examples/cartoon_dinosaur.png
new file mode 100644
index 0000000000000000000000000000000000000000..598964626b767eb6470a28a68537c091fc5de2f8
Binary files /dev/null and b/examples/cartoon_dinosaur.png differ
diff --git a/examples/chair_armed.png b/examples/chair_armed.png
new file mode 100644
index 0000000000000000000000000000000000000000..2ab67e95ed57fbc5ebcd7d934827fd7fb03ab3ff
Binary files /dev/null and b/examples/chair_armed.png differ
diff --git a/examples/chair_comfort.jpg b/examples/chair_comfort.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..918347fe51773d7ecaa7fb929274db8d7d5d3e19
Binary files /dev/null and b/examples/chair_comfort.jpg differ
diff --git a/examples/chair_wood.jpg b/examples/chair_wood.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bc60569896fb02a46185aabb85086890f0f400d7
Binary files /dev/null and b/examples/chair_wood.jpg differ
diff --git a/examples/chest.jpg b/examples/chest.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..26ae0b145887e43b850d298b94fe54828e909492
Binary files /dev/null and b/examples/chest.jpg differ
diff --git a/examples/cute_horse.jpg b/examples/cute_horse.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ec8807d313b983e3cc34ee89bbf3f312d6ce66eb
Binary files /dev/null and b/examples/cute_horse.jpg differ
diff --git a/examples/cute_tiger.jpg b/examples/cute_tiger.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..82e873258d9f3fd6d569205ab75deb8a26918356
Binary files /dev/null and b/examples/cute_tiger.jpg differ
diff --git a/examples/earphone.jpg b/examples/earphone.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..498e4196b0d68f8809d049e7178b80592a31a0a2
Binary files /dev/null and b/examples/earphone.jpg differ
diff --git a/examples/fox.jpg b/examples/fox.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1f2efc1c3a9c4ad8f36ad93082c124c91a6e9ef7
Binary files /dev/null and b/examples/fox.jpg differ
diff --git a/examples/fruit.jpg b/examples/fruit.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..07034ad3721de0e09c7509b22a7d3bc9679304d0
Binary files /dev/null and b/examples/fruit.jpg differ
diff --git a/examples/fruit_elephant.jpg b/examples/fruit_elephant.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ef8eaf3b88ae0a38272b34802fe40032055afa58
Binary files /dev/null and b/examples/fruit_elephant.jpg differ
diff --git a/examples/genshin_building.png b/examples/genshin_building.png
new file mode 100644
index 0000000000000000000000000000000000000000..00b6a949d01283e1ae30fac4bd6040e13f18a055
Binary files /dev/null and b/examples/genshin_building.png differ
diff --git a/examples/genshin_teapot.png b/examples/genshin_teapot.png
new file mode 100644
index 0000000000000000000000000000000000000000..1f13a6edfe67ced810b4513117279067f0360fae
Binary files /dev/null and b/examples/genshin_teapot.png differ
diff --git a/examples/hatsune_miku.png b/examples/hatsune_miku.png
new file mode 100644
index 0000000000000000000000000000000000000000..2fecf005fdd56a396c4894256fbb98fcc1c4dd8f
Binary files /dev/null and b/examples/hatsune_miku.png differ
diff --git a/examples/house2.jpg b/examples/house2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2eb8d63a6b91d5b16e729710c8b703aa5c11f9e5
Binary files /dev/null and b/examples/house2.jpg differ
diff --git a/examples/mushroom_teapot.jpg b/examples/mushroom_teapot.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a6c767354305f5467a4c0d5f199eee2a120f4501
Binary files /dev/null and b/examples/mushroom_teapot.jpg differ
diff --git a/examples/pikachu.png b/examples/pikachu.png
new file mode 100644
index 0000000000000000000000000000000000000000..e7579c16957a3e13b80d53cf0a41ddfdfd47b92d
Binary files /dev/null and b/examples/pikachu.png differ
diff --git a/examples/plant.jpg b/examples/plant.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3519c1639c3f837d9f1147cba1172e6aaab25a23
Binary files /dev/null and b/examples/plant.jpg differ
diff --git a/examples/robot.jpg b/examples/robot.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..929450fba69a20389f39d46cb51d27facc1bba6d
Binary files /dev/null and b/examples/robot.jpg differ
diff --git a/examples/sea_turtle.png b/examples/sea_turtle.png
new file mode 100644
index 0000000000000000000000000000000000000000..27c3e2a9c7d44cb33914422b410ef41cf6591433
Binary files /dev/null and b/examples/sea_turtle.png differ
diff --git a/examples/skating_shoe.jpg b/examples/skating_shoe.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5f21cb1d43e9d42d2836118963fc1d2874523748
Binary files /dev/null and b/examples/skating_shoe.jpg differ
diff --git a/examples/sorting_board.png b/examples/sorting_board.png
new file mode 100644
index 0000000000000000000000000000000000000000..a40fb8362afce0e323dd4517bba784cc652f5f6c
Binary files /dev/null and b/examples/sorting_board.png differ
diff --git a/examples/sword.png b/examples/sword.png
new file mode 100644
index 0000000000000000000000000000000000000000..3068cb9bdbbd9ed3c0a143fd5c741abbc58508e3
Binary files /dev/null and b/examples/sword.png differ
diff --git a/examples/toy_car.jpg b/examples/toy_car.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ffa72aa6c1510e200e5d640461b779d2e7bf4997
Binary files /dev/null and b/examples/toy_car.jpg differ
diff --git a/examples/watermelon.png b/examples/watermelon.png
new file mode 100644
index 0000000000000000000000000000000000000000..52b39917abcbd2f1eef9b7c8cf9aa602bddde1bf
Binary files /dev/null and b/examples/watermelon.png differ
diff --git a/examples/whitedog.png b/examples/whitedog.png
new file mode 100644
index 0000000000000000000000000000000000000000..16c598a8133643898408ea806b69d5b18c53be7d
Binary files /dev/null and b/examples/whitedog.png differ
diff --git a/examples/x_teapot.jpg b/examples/x_teapot.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4e1cb46c5541dcc4ea544864e2eeebd42dfcb18a
Binary files /dev/null and b/examples/x_teapot.jpg differ
diff --git a/examples/x_toyduck.jpg b/examples/x_toyduck.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5e60d43bd76d7511e44568c4f9bba2a11a1a4f04
Binary files /dev/null and b/examples/x_toyduck.jpg differ
diff --git a/main.py b/main.py
deleted file mode 100755
index 15d44c5122d71bba499003b10d4544135faa1a84..0000000000000000000000000000000000000000
--- a/main.py
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/usr/bin/env python3
-
-from __future__ import annotations
-
-
-def main() -> None:
-    pass
-
-
-if __name__ == "__main__":
-    main()
diff --git a/requirements.txt b/requirements.txt
index 98ca0e71a07dd50afbbacce4d76d004f591371b4..490c71aad8c786b17e85a79fda8d1aec146d392d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,27 @@
-rerun-sdk>=0.15.0,<0.16.0
+spaces
+torch==2.1.0
+torchvision==0.16.0
+torchaudio==2.1.0
+pytorch-lightning==2.1.2
+einops
+omegaconf
+deepspeed
+torchmetrics
+webdataset
+accelerate
+tensorboard
+PyMCubes
+trimesh
+rembg
+transformers
+diffusers==0.28.2
+bitsandbytes
+imageio[ffmpeg]
+xatlas
+plyfile
+xformers==0.0.22.post7
+git+https://github.com/NVlabs/nvdiffrast/
+huggingface-hub
+gradio_client >= 0.12
+rerun-sdk>=0.16.0,<0.17.0
+gradio_rerun
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/data/__init__.py b/src/data/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/data/objaverse.py b/src/data/objaverse.py
new file mode 100755
index 0000000000000000000000000000000000000000..b54c3267f76bcc79cd8b9b45c38cd00cc4f933e4
--- /dev/null
+++ b/src/data/objaverse.py
@@ -0,0 +1,322 @@
+from __future__ import annotations
+
+import json
+import math
+import os
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+import webdataset as wds
+from PIL import Image
+from torch.utils.data import Dataset
+from torch.utils.data.distributed import DistributedSampler
+
+from src.utils.camera_util import (
+    FOV_to_intrinsics,
+    center_looking_at_camera_pose,
+    get_surrounding_views,
+)
+from src.utils.train_util import instantiate_from_config
+
+
+class DataModuleFromConfig(pl.LightningDataModule):
+    def __init__(
+        self,
+        batch_size=8,
+        num_workers=4,
+        train=None,
+        validation=None,
+        test=None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+        self.dataset_configs = dict()
+        if train is not None:
+            self.dataset_configs['train'] = train
+        if validation is not None:
+            self.dataset_configs['validation'] = validation
+        if test is not None:
+            self.dataset_configs['test'] = test
+
+    def setup(self, stage):
+
+        if stage in ['fit']:
+            self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
+        else:
+            raise NotImplementedError
+
+    def train_dataloader(self):
+
+        sampler = DistributedSampler(self.datasets['train'])
+        return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def val_dataloader(self):
+
+        sampler = DistributedSampler(self.datasets['validation'])
+        return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
+
+    def test_dataloader(self):
+
+        return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
+
+
+class ObjaverseData(Dataset):
+    def __init__(self,
+        root_dir='objaverse/',
+        meta_fname='valid_paths.json',
+        input_image_dir='rendering_random_32views',
+        target_image_dir='rendering_random_32views',
+        input_view_num=6,
+        target_view_num=2,
+        total_view_n=32,
+        fov=50,
+        camera_rotation=True,
+        validation=False,
+    ):
+        self.root_dir = Path(root_dir)
+        self.input_image_dir = input_image_dir
+        self.target_image_dir = target_image_dir
+
+        self.input_view_num = input_view_num
+        self.target_view_num = target_view_num
+        self.total_view_n = total_view_n
+        self.fov = fov
+        self.camera_rotation = camera_rotation
+
+        with open(os.path.join(root_dir, meta_fname)) as f:
+            filtered_dict = json.load(f)
+        paths = filtered_dict['good_objs']
+        self.paths = paths
+
+        self.depth_scale = 4.0
+
+        print('============= length of dataset %d =============' % len(self.paths))
+
+    def __len__(self):
+        return len(self.paths)
+
+    def load_im(self, path, color):
+        """Replace background pixel with random color in rendering."""
+        pil_img = Image.open(path)
+
+        image = np.asarray(pil_img, dtype=np.float32) / 255.
+        alpha = image[:, :, 3:]
+        image = image[:, :, :3] * alpha + color * (1 - alpha)
+
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
+        return image, alpha
+
+    def __getitem__(self, index):
+        # load data
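+        # keep retrying with a freshly sampled object if any of its views fails to load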
+        while True:
+            input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
+            target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
+
+            indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
+            input_indices = indices[:self.input_view_num]
+            target_indices = indices[self.input_view_num:]
+
+            '''background color, default: white'''
+            bg_white = [1., 1., 1.]
+            bg_black = [0., 0., 0.]
+
+            image_list = []
+            alpha_list = []
+            depth_list = []
+            normal_list = []
+            pose_list = []
+
+            try:
+                input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
+                for idx in input_indices:
+                    image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
+                    normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
+                    depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
+                    depth = torch.from_numpy(depth).unsqueeze(0)
+                    pose = input_cameras[idx]
+                    pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
+
+                    image_list.append(image)
+                    alpha_list.append(alpha)
+                    depth_list.append(depth)
+                    normal_list.append(normal)
+                    pose_list.append(pose)
+
+                target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
+                for idx in target_indices:
+                    image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
+                    normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
+                    depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
+                    depth = torch.from_numpy(depth).unsqueeze(0)
+                    pose = target_cameras[idx]
+                    pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
+
+                    image_list.append(image)
+                    alpha_list.append(alpha)
+                    depth_list.append(depth)
+                    normal_list.append(normal)
+                    pose_list.append(pose)
+
+            except Exception as e:
+                print(e)
+                index = np.random.randint(0, len(self.paths))
+                continue
+
+            break
+
+        images = torch.stack(image_list, dim=0).float()                 # (6+V, 3, H, W)
+        alphas = torch.stack(alpha_list, dim=0).float()                 # (6+V, 1, H, W)
+        depths = torch.stack(depth_list, dim=0).float()                 # (6+V, 1, H, W)
+        normals = torch.stack(normal_list, dim=0).float()               # (6+V, 3, H, W)
+        w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float()    # (6+V, 4, 4)
+        c2ws = torch.linalg.inv(w2cs).float()
+
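+        # re-normalize the normal maps and blend them to zero in background regions (alpha == 0)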
+        normals = normals * 2.0 - 1.0
+        normals = F.normalize(normals, dim=1)
+        normals = (normals + 1.0) / 2.0
+        normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
+
+        # random rotation along z axis
+        if self.camera_rotation:
+            degree = np.random.uniform(0, math.pi * 2)
+            rot = torch.tensor([
+                [np.cos(degree), -np.sin(degree), 0, 0],
+                [np.sin(degree), np.cos(degree), 0, 0],
+                [0, 0, 1, 0],
+                [0, 0, 0, 1],
+            ]).unsqueeze(0).float()
+            c2ws = torch.matmul(rot, c2ws)
+
+            # rotate normals
+            N, _, H, W = normals.shape
+            normals = normals * 2.0 - 1.0
+            normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
+            normals = F.normalize(normals, dim=1)
+            normals = (normals + 1.0) / 2.0
+            normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
+
+        # random scaling
+        if np.random.rand() < 0.5:
+            scale = np.random.uniform(0.8, 1.0)
+            c2ws[:, :3, 3] *= scale
+            depths *= scale
+
+        # intrinsics of perspective cameras
+        K = FOV_to_intrinsics(self.fov)
+        Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
+
+        data = {
+            'input_images': images[:self.input_view_num],           # (6, 3, H, W)
+            'input_alphas': alphas[:self.input_view_num],           # (6, 1, H, W)
+            'input_depths': depths[:self.input_view_num],           # (6, 1, H, W)
+            'input_normals': normals[:self.input_view_num],         # (6, 3, H, W)
+            'input_c2ws': c2ws[:self.input_view_num],               # (6, 4, 4)
+            'input_Ks': Ks[:self.input_view_num],                   # (6, 3, 3)
+
+            # lrm generator input and supervision
+            'target_images': images[self.input_view_num:],          # (V, 3, H, W)
+            'target_alphas': alphas[self.input_view_num:],          # (V, 1, H, W)
+            'target_depths': depths[self.input_view_num:],          # (V, 1, H, W)
+            'target_normals': normals[self.input_view_num:],        # (V, 3, H, W)
+            'target_c2ws': c2ws[self.input_view_num:],              # (V, 4, 4)
+            'target_Ks': Ks[self.input_view_num:],                  # (V, 3, 3)
+
+            'depth_available': 1,
+        }
+        return data
+
+
+class ValidationData(Dataset):
+    def __init__(self,
+        root_dir='objaverse/',
+        input_view_num=6,
+        input_image_size=256,
+        fov=50,
+    ):
+        self.root_dir = Path(root_dir)
+        self.input_view_num = input_view_num
+        self.input_image_size = input_image_size
+        self.fov = fov
+
+        self.paths = sorted(os.listdir(self.root_dir))
+        print('============= length of dataset %d =============' % len(self.paths))
+
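+        # six fixed input viewpoints on a ring of radius 2.5, alternating between +30 and -20 degrees elevation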
+        cam_distance = 2.5
+        azimuths = np.array([30, 90, 150, 210, 270, 330])
+        elevations = np.array([30, -20, 30, -20, 30, -20])
+        azimuths = np.deg2rad(azimuths)
+        elevations = np.deg2rad(elevations)
+
+        x = cam_distance * np.cos(elevations) * np.cos(azimuths)
+        y = cam_distance * np.cos(elevations) * np.sin(azimuths)
+        z = cam_distance * np.sin(elevations)
+
+        cam_locations = np.stack([x, y, z], axis=-1)
+        cam_locations = torch.from_numpy(cam_locations).float()
+        c2ws = center_looking_at_camera_pose(cam_locations)
+        self.c2ws = c2ws.float()
+        self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
+
+        render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
+        render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
+        self.render_c2ws = render_c2ws.float()
+        self.render_Ks = render_Ks.float()
+
+    def __len__(self):
+        return len(self.paths)
+
+    def load_im(self, path, color):
+        """Replace background pixel with random color in rendering."""
+        pil_img = Image.open(path)
+        pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
+
+        image = np.asarray(pil_img, dtype=np.float32) / 255.
+        if image.shape[-1] == 4:
+            alpha = image[:, :, 3:]
+            image = image[:, :, :3] * alpha + color * (1 - alpha)
+        else:
+            alpha = np.ones_like(image[:, :, :1])
+
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
+        return image, alpha
+
+    def __getitem__(self, index):
+        # load data
+        input_image_path = os.path.join(self.root_dir, self.paths[index])
+
+        '''background color, default: white'''
+        # color = np.random.uniform(0.48, 0.52)
+        bkg_color = [1.0, 1.0, 1.0]
+
+        image_list = []
+        alpha_list = []
+
+        for idx in range(self.input_view_num):
+            image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
+            image_list.append(image)
+            alpha_list.append(alpha)
+
+        images = torch.stack(image_list, dim=0).float()                 # (6, 3, H, W)
+        alphas = torch.stack(alpha_list, dim=0).float()                 # (6, 1, H, W)
+
+        data = {
+            'input_images': images,             # (6, 3, H, W)
+            'input_alphas': alphas,             # (6, 1, H, W)
+            'input_c2ws': self.c2ws,            # (6, 4, 4)
+            'input_Ks': self.Ks,                # (6, 3, 3)
+
+            'render_c2ws': self.render_c2ws,
+            'render_Ks': self.render_Ks,
+        }
+        return data
diff --git a/src/lib.rs b/src/lib.rs
deleted file mode 100644
index ec3d2bdcddbd8d8a1ba68ded580faa6ebb220f1e..0000000000000000000000000000000000000000
--- a/src/lib.rs
+++ /dev/null
@@ -1 +0,0 @@
-//! Example of a Rust library.
diff --git a/src/main.cpp b/src/main.cpp
deleted file mode 100644
index f7aff572003b99f0d7671298611921dc8cb79c91..0000000000000000000000000000000000000000
--- a/src/main.cpp
+++ /dev/null
@@ -1,8 +0,0 @@
-#include <cstdio>
-
-#include <rerun.hpp>
-
-int main(int argc, const char* argv[]) {
-    printf("Hello, World!\n");
-    return 0;
-}
diff --git a/src/main.rs b/src/main.rs
deleted file mode 100644
index eced7e313d5dd9ddf79f97d15838134fb854c66b..0000000000000000000000000000000000000000
--- a/src/main.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-//! Example of a Rust binary.
-
-fn main() {
-    println!("Hello, PROJ_NAME!");
-}
diff --git a/src/model.py b/src/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..7c01afecb1f7da1f956b1420b75a7d3588098f40
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,313 @@
+from __future__ import annotations
+
+import os
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from torchvision.transforms import v2
+from torchvision.utils import make_grid, save_image
+
+from src.utils.train_util import instantiate_from_config
+
+
+class MVRecon(pl.LightningModule):
+    def __init__(
+        self,
+        lrm_generator_config,
+        lrm_path=None,
+        input_size=256,
+        render_size=192,
+    ):
+        super().__init__()
+
+        self.input_size = input_size
+        self.render_size = render_size
+
+        # init modules
+        self.lrm_generator = instantiate_from_config(lrm_generator_config)
+        if lrm_path is not None:
+            lrm_ckpt = torch.load(lrm_path)
+            self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
+
+        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
+
+        self.validation_step_outputs = []
+
+    def on_fit_start(self):
+        if self.global_rank == 0:
+            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
+            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
+
+    def prepare_batch_data(self, batch):
+        lrm_generator_input = {}
+        render_gt = {}   # for supervision
+
+        # input images
+        images = batch['input_images']
+        images = v2.functional.resize(
+            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
+
+        lrm_generator_input['images'] = images.to(self.device)
+
+        # input cameras and render cameras
+        input_c2ws = batch['input_c2ws'].flatten(-2)
+        input_Ks = batch['input_Ks'].flatten(-2)
+        target_c2ws = batch['target_c2ws'].flatten(-2)
+        target_Ks = batch['target_Ks'].flatten(-2)
+        render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
+        render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
+        render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
+
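+        # compact camera conditioning: the first 12 entries of the flattened c2w (3x4 extrinsics)
+        # plus (fx, fy, cx, cy) taken from the row-major flattened 3x3 intrinsics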
+        input_extrinsics = input_c2ws[:, :, :12]
+        input_intrinsics = torch.stack([
+            input_Ks[:, :, 0], input_Ks[:, :, 4],
+            input_Ks[:, :, 2], input_Ks[:, :, 5],
+        ], dim=-1)
+        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
+
+        # add noise to input cameras
+        cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
+
+        lrm_generator_input['cameras'] = cameras.to(self.device)
+        lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
+
+        # target images
+        target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
+        target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
+        target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
+
+        # augmentation: resize targets to a random size in [render_size, 512], then random-crop back to render_size
+        render_size = np.random.randint(self.render_size, 513)
+        target_images = v2.functional.resize(
+            target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
+        target_depths = v2.functional.resize(
+            target_depths, render_size, interpolation=0, antialias=True)
+        target_alphas = v2.functional.resize(
+            target_alphas, render_size, interpolation=0, antialias=True)
+
+        crop_params = v2.RandomCrop.get_params(
+            target_images, output_size=(self.render_size, self.render_size))
+        target_images = v2.functional.crop(target_images, *crop_params)
+        target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
+        target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
+
+        lrm_generator_input['render_size'] = render_size
+        lrm_generator_input['crop_params'] = crop_params
+
+        render_gt['target_images'] = target_images.to(self.device)
+        render_gt['target_depths'] = target_depths.to(self.device)
+        render_gt['target_alphas'] = target_alphas.to(self.device)
+
+        return lrm_generator_input, render_gt
+
+    def prepare_validation_batch_data(self, batch):
+        lrm_generator_input = {}
+
+        # input images
+        images = batch['input_images']
+        images = v2.functional.resize(
+            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
+
+        lrm_generator_input['images'] = images.to(self.device)
+
+        input_c2ws = batch['input_c2ws'].flatten(-2)
+        input_Ks = batch['input_Ks'].flatten(-2)
+
+        input_extrinsics = input_c2ws[:, :, :12]
+        input_intrinsics = torch.stack([
+            input_Ks[:, :, 0], input_Ks[:, :, 4],
+            input_Ks[:, :, 2], input_Ks[:, :, 5],
+        ], dim=-1)
+        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
+
+        lrm_generator_input['cameras'] = cameras.to(self.device)
+
+        render_c2ws = batch['render_c2ws'].flatten(-2)
+        render_Ks = batch['render_Ks'].flatten(-2)
+        render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
+
+        lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
+        lrm_generator_input['render_size'] = 384
+        lrm_generator_input['crop_params'] = None
+
+        return lrm_generator_input
+
+    def forward_lrm_generator(
+        self,
+        images,
+        cameras,
+        render_cameras,
+        render_size=192,
+        crop_params=None,
+        chunk_size=1,
+    ):
+        planes = torch.utils.checkpoint.checkpoint(
+            self.lrm_generator.forward_planes,
+            images,
+            cameras,
+            use_reentrant=False,
+        )
+        frames = []
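+        # render the requested views in chunks to bound memory; each chunk is gradient-checkpointed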
+        for i in range(0, render_cameras.shape[1], chunk_size):
+            frames.append(
+                torch.utils.checkpoint.checkpoint(
+                    self.lrm_generator.synthesizer,
+                    planes,
+                    cameras=render_cameras[:, i:i+chunk_size],
+                    render_size=render_size,
+                    crop_params=crop_params,
+                    use_reentrant=False
+                )
+            )
+        frames = {
+            k: torch.cat([r[k] for r in frames], dim=1)
+            for k in frames[0].keys()
+        }
+        return frames
+
+    def forward(self, lrm_generator_input):
+        images = lrm_generator_input['images']
+        cameras = lrm_generator_input['cameras']
+        render_cameras = lrm_generator_input['render_cameras']
+        render_size = lrm_generator_input['render_size']
+        crop_params = lrm_generator_input['crop_params']
+
+        out = self.forward_lrm_generator(
+            images,
+            cameras,
+            render_cameras,
+            render_size=render_size,
+            crop_params=crop_params,
+            chunk_size=1,
+        )
+        render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
+        render_depths = out['images_depth']
+        render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
+
+        out = {
+            'render_images': render_images,
+            'render_depths': render_depths,
+            'render_alphas': render_alphas,
+        }
+        return out
+
+    def training_step(self, batch, batch_idx):
+        lrm_generator_input, render_gt = self.prepare_batch_data(batch)
+
+        render_out = self.forward(lrm_generator_input)
+
+        loss, loss_dict = self.compute_loss(render_out, render_gt)
+
+        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+
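+        # every 1000 steps (rank 0 only), save a comparison grid of inputs, ground truth vs.
+        # rendered RGB, alphas, and depths for visual inspection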
+        if self.global_step % 1000 == 0 and self.global_rank == 0:
+            B, N, C, H, W = render_gt['target_images'].shape
+            N_in = lrm_generator_input['images'].shape[1]
+
+            input_images = v2.functional.resize(
+                lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
+            input_images = torch.cat(
+                [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
+
+            input_images = rearrange(
+                input_images, 'b n c h w -> b c h (n w)')
+            target_images = rearrange(
+                render_gt['target_images'], 'b n c h w -> b c h (n w)')
+            render_images = rearrange(
+                render_out['render_images'], 'b n c h w -> b c h (n w)')
+            target_alphas = rearrange(
+                repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
+            render_alphas = rearrange(
+                repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
+            target_depths = rearrange(
+                repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
+            render_depths = rearrange(
+                repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
+            MAX_DEPTH = torch.max(target_depths)
+            target_depths = target_depths / MAX_DEPTH * target_alphas
+            render_depths = render_depths / MAX_DEPTH
+
+            grid = torch.cat([
+                input_images,
+                target_images, render_images,
+                target_alphas, render_alphas,
+                target_depths, render_depths,
+            ], dim=-2)
+            grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
+
+            save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
+
+        return loss
+
+    def compute_loss(self, render_out, render_gt):
+        # NOTE: the rgb value range of OpenLRM is [0, 1]
+        render_images = render_out['render_images']
+        target_images = render_gt['target_images'].to(render_images)
+        render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
+        target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
+
+        loss_mse = F.mse_loss(render_images, target_images)
+        loss_lpips = 2.0 * self.lpips(render_images, target_images)
+
+        render_alphas = render_out['render_alphas']
+        target_alphas = render_gt['target_alphas']
+        loss_mask = F.mse_loss(render_alphas, target_alphas)
+
+        loss = loss_mse + loss_lpips + loss_mask
+
+        prefix = 'train'
+        loss_dict = {}
+        loss_dict.update({f'{prefix}/loss_mse': loss_mse})
+        loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
+        loss_dict.update({f'{prefix}/loss_mask': loss_mask})
+        loss_dict.update({f'{prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        lrm_generator_input = self.prepare_validation_batch_data(batch)
+
+        render_out = self.forward(lrm_generator_input)
+        render_images = render_out['render_images']
+        render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
+
+        self.validation_step_outputs.append(render_images)
+
+    def on_validation_epoch_end(self):
+        images = torch.cat(self.validation_step_outputs, dim=-1)
+
+        all_images = self.all_gather(images)
+        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
+
+        if self.global_rank == 0:
+            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
+
+            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
+            save_image(grid, image_path)
+            print(f"Saved image to {image_path}")
+
+        self.validation_step_outputs.clear()
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+
+        params = []
+
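+        # Two parameter groups: the adaLN modulation and camera embedder layers train at
+        # the full learning rate, everything else at lr / 10 (both with weight decay 0.01).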
+        lrm_params_fast, lrm_params_slow = [], []
+        for n, p in self.lrm_generator.named_parameters():
+            if 'adaLN_modulation' in n or 'camera_embedder' in n:
+                lrm_params_fast.append(p)
+            else:
+                lrm_params_slow.append(p)
+        params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
+        params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
+
+        optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
+
+        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
diff --git a/src/model_mesh.py b/src/model_mesh.py
new file mode 100755
index 0000000000000000000000000000000000000000..27737873c96e7c3bbe84bf109473af2f5dd26cb9
--- /dev/null
+++ b/src/model_mesh.py
@@ -0,0 +1,327 @@
+from __future__ import annotations
+
+import os
+
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from torchvision.transforms import v2
+from torchvision.utils import make_grid, save_image
+
+from src.utils.train_util import instantiate_from_config
+
+
+# Regularization loss for FlexiCubes
+def sdf_reg_loss_batch(sdf, all_edges):
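+    # DMTet/FlexiCubes-style SDF regularizer: for every grid edge whose endpoint SDF
+    # values disagree in sign, a symmetric BCE-with-logits term pulls the endpoints
+    # toward the same sign, penalizing spurious surface crossings.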
+    sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
+    mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
+    sdf_f1x6x2 = sdf_f1x6x2[mask]
+    sdf_diff = F.binary_cross_entropy_with_logits(
+        sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
+               F.binary_cross_entropy_with_logits(
+                   sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
+    return sdf_diff
+
+
+class MVRecon(pl.LightningModule):
+    def __init__(
+        self,
+        lrm_generator_config,
+        input_size=256,
+        render_size=512,
+        init_ckpt=None,
+    ):
+        super().__init__()
+
+        self.input_size = input_size
+        self.render_size = render_size
+
+        # init modules
+        self.lrm_generator = instantiate_from_config(lrm_generator_config)
+
+        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
+
+        # Load weights from a pretrained MVRecon model and use its decoder MLP
+        # weights to initialize the SDF and RGB MLPs.
+        if init_ckpt is not None:
+            sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
+            sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
+            sd_fc = {}
+            for k, v in sd.items():
+                if k.startswith('lrm_generator.synthesizer.decoder.net.'):
+                    if k.startswith('lrm_generator.synthesizer.decoder.net.6.'):    # last layer
+                        # Here we assume the density field's isosurface threshold is t;
+                        # we flip the sign of the density field to initialize the SDF field:
+                        # -(w*x + b - t) = (-w)*x + (t - b)
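+                        # i.e. the new weight is -w and the new bias is t - b, with t = 3.0 used below.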
+                        if 'weight' in k:
+                            sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
+                        else:
+                            sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
+                        sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
+                    else:
+                        sd_fc[k.replace('net.', 'net_sdf.')] = v
+                        sd_fc[k.replace('net.', 'net_rgb.')] = v
+                else:
+                    sd_fc[k] = v
+            sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
+            # `net_deformation` and `net_weight` have no counterpart in the checkpoint, hence strict=False below
+            self.lrm_generator.load_state_dict(sd_fc, strict=False)
+            print(f'Loaded weights from {init_ckpt}')
+
+        self.validation_step_outputs = []
+
+    def on_fit_start(self):
+        device = torch.device(f'cuda:{self.global_rank}')
+        self.lrm_generator.init_flexicubes_geometry(device)
+        if self.global_rank == 0:
+            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
+            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
+
+    def prepare_batch_data(self, batch):
+        lrm_generator_input = {}
+        render_gt = {}
+
+        # input images
+        images = batch['input_images']
+        images = v2.functional.resize(
+            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
+
+        lrm_generator_input['images'] = images.to(self.device)
+
+        # input cameras and render cameras
+        input_c2ws = batch['input_c2ws']
+        input_Ks = batch['input_Ks']
+        target_c2ws = batch['target_c2ws']
+
+        render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
+        render_w2cs = torch.linalg.inv(render_c2ws)
+
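+        # Per-view camera conditioning: 12 entries of the flattened 3x4 camera-to-world
+        # extrinsics followed by 4 intrinsics (fx, fy, cx, cy), i.e. 16 values per view.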
+        input_extrinsics = input_c2ws.flatten(-2)
+        input_extrinsics = input_extrinsics[:, :, :12]
+        input_intrinsics = input_Ks.flatten(-2)
+        input_intrinsics = torch.stack([
+            input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
+            input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
+        ], dim=-1)
+        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
+
+        # add noise to input_cameras
+        cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
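+        # (torch.rand_like is uniform in [0, 1), so the noise is uniform in [-0.02, 0.02))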
+
+        lrm_generator_input['cameras'] = cameras.to(self.device)
+        lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
+
+        # target images
+        target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
+        target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
+        target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
+        target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
+
+        render_size = self.render_size
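+        # Resize targets to the render resolution: bicubic (interpolation=3) for RGB and
+        # normals, nearest (interpolation=0) for depths and alphas to avoid blending values
+        # across object boundaries.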
+        target_images = v2.functional.resize(
+            target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
+        target_depths = v2.functional.resize(
+            target_depths, render_size, interpolation=0, antialias=True)
+        target_alphas = v2.functional.resize(
+            target_alphas, render_size, interpolation=0, antialias=True)
+        target_normals = v2.functional.resize(
+            target_normals, render_size, interpolation=3, antialias=True)
+
+        lrm_generator_input['render_size'] = render_size
+
+        render_gt['target_images'] = target_images.to(self.device)
+        render_gt['target_depths'] = target_depths.to(self.device)
+        render_gt['target_alphas'] = target_alphas.to(self.device)
+        render_gt['target_normals'] = target_normals.to(self.device)
+
+        return lrm_generator_input, render_gt
+
+    def prepare_validation_batch_data(self, batch):
+        lrm_generator_input = {}
+
+        # input images
+        images = batch['input_images']
+        images = v2.functional.resize(
+            images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
+
+        lrm_generator_input['images'] = images.to(self.device)
+
+        # input cameras
+        input_c2ws = batch['input_c2ws'].flatten(-2)
+        input_Ks = batch['input_Ks'].flatten(-2)
+
+        input_extrinsics = input_c2ws[:, :, :12]
+        input_intrinsics = torch.stack([
+            input_Ks[:, :, 0], input_Ks[:, :, 4],
+            input_Ks[:, :, 2], input_Ks[:, :, 5],
+        ], dim=-1)
+        cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
+
+        lrm_generator_input['cameras'] = cameras.to(self.device)
+
+        # render cameras
+        render_c2ws = batch['render_c2ws']
+        render_w2cs = torch.linalg.inv(render_c2ws)
+
+        lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
+        lrm_generator_input['render_size'] = 384
+
+        return lrm_generator_input
+
+    def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
+        planes = torch.utils.checkpoint.checkpoint(
+            self.lrm_generator.forward_planes,
+            images,
+            cameras,
+            use_reentrant=False,
+        )
+        out = self.lrm_generator.forward_geometry(
+            planes,
+            render_cameras,
+            render_size,
+        )
+        return out
+
+    def forward(self, lrm_generator_input):
+        images = lrm_generator_input['images']
+        cameras = lrm_generator_input['cameras']
+        render_cameras = lrm_generator_input['render_cameras']
+        render_size = lrm_generator_input['render_size']
+
+        out = self.forward_lrm_generator(
+            images, cameras, render_cameras, render_size=render_size)
+
+        return out
+
+    def training_step(self, batch, batch_idx):
+        lrm_generator_input, render_gt = self.prepare_batch_data(batch)
+
+        render_out = self.forward(lrm_generator_input)
+
+        loss, loss_dict = self.compute_loss(render_out, render_gt)
+
+        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+
+        if self.global_step % 1000 == 0 and self.global_rank == 0:
+            _B, _N, _C, _H, _W = render_gt['target_images'].shape
+
+            target_images = rearrange(
+                render_gt['target_images'], 'b n c h w -> b c h (n w)')
+            render_images = rearrange(
+                render_out['img'], 'b n c h w -> b c h (n w)')
+            target_alphas = rearrange(
+                repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
+            render_alphas = rearrange(
+                repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
+            target_depths = rearrange(
+                repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
+            render_depths = rearrange(
+                repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
+            target_normals = rearrange(
+                render_gt['target_normals'], 'b n c h w -> b c h (n w)')
+            render_normals = rearrange(
+                render_out['normal'], 'b n c h w -> b c h (n w)')
+            MAX_DEPTH = torch.max(target_depths)
+            target_depths = target_depths / MAX_DEPTH * target_alphas
+            render_depths = render_depths / MAX_DEPTH
+
+            grid = torch.cat([
+                target_images, render_images,
+                target_alphas, render_alphas,
+                target_depths, render_depths,
+                target_normals, render_normals,
+            ], dim=-2)
+            grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
+
+            image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
+            save_image(grid, image_path)
+            print(f"Saved image to {image_path}")
+
+        return loss
+
+    def compute_loss(self, render_out, render_gt):
+        # NOTE: the rgb value range of OpenLRM is [0, 1]
+        render_images = render_out['img']
+        target_images = render_gt['target_images'].to(render_images)
+        render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
+        target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
+        loss_mse = F.mse_loss(render_images, target_images)
+        loss_lpips = 2.0 * self.lpips(render_images, target_images)
+
+        render_alphas = render_out['mask']
+        target_alphas = render_gt['target_alphas']
+        loss_mask = F.mse_loss(render_alphas, target_alphas)
+
+        render_depths = render_out['depth']
+        target_depths = render_gt['target_depths']
+        loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
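+        # NOTE: loss_depth is logged below but not added to the total loss.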
+
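+        # Normal consistency: both normal maps are mapped back to [-1, 1] and compared via
+        # the absolute per-pixel dot product inside the foreground mask.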
+        render_normals = render_out['normal'] * 2.0 - 1.0
+        target_normals = render_gt['target_normals'] * 2.0 - 1.0
+        similarity = (render_normals * target_normals).sum(dim=-3).abs()
+        normal_mask = target_alphas.squeeze(-3)
+        loss_normal = 1 - similarity[normal_mask>0].mean()
+        loss_normal = 0.2 * loss_normal
+
+        # flexicubes regularization loss
+        sdf = render_out['sdf']
+        sdf_reg_loss = render_out['sdf_reg_loss']
+        sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
+        _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
+        flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
+        flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
+
+        loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
+
+        loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
+
+        prefix = 'train'
+        loss_dict = {}
+        loss_dict.update({f'{prefix}/loss_mse': loss_mse})
+        loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
+        loss_dict.update({f'{prefix}/loss_mask': loss_mask})
+        loss_dict.update({f'{prefix}/loss_normal': loss_normal})
+        loss_dict.update({f'{prefix}/loss_depth': loss_depth})
+        loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
+        loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
+        loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
+        loss_dict.update({f'{prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        lrm_generator_input = self.prepare_validation_batch_data(batch)
+
+        render_out = self.forward(lrm_generator_input)
+        render_images = render_out['img']
+        render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
+
+        self.validation_step_outputs.append(render_images)
+
+    def on_validation_epoch_end(self):
+        images = torch.cat(self.validation_step_outputs, dim=-1)
+
+        all_images = self.all_gather(images)
+        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
+
+        if self.global_rank == 0:
+            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
+
+            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
+            save_image(grid, image_path)
+            print(f"Saved image to {image_path}")
+
+        self.validation_step_outputs.clear()
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+
+        optimizer = torch.optim.AdamW(
+            self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
+
+        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/decoder/__init__.py b/src/models/decoder/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/decoder/transformer.py b/src/models/decoder/transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..2c383b493737281bd49129ae9f5f319a338dc811
--- /dev/null
+++ b/src/models/decoder/transformer.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2023, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+
+class BasicTransformerBlock(nn.Module):
+    """Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks."""
+
+    # use attention from torch.nn.MultiHeadAttention
+    # Block contains a cross-attention layer, a self-attention layer, and a MLP
+    def __init__(
+        self,
+        inner_dim: int,
+        cond_dim: int,
+        num_heads: int,
+        eps: float,
+        attn_drop: float = 0.,
+        attn_bias: bool = False,
+        mlp_ratio: float = 4.,
+        mlp_drop: float = 0.,
+    ):
+        super().__init__()
+
+        self.norm1 = nn.LayerNorm(inner_dim)
+        self.cross_attn = nn.MultiheadAttention(
+            embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
+            dropout=attn_drop, bias=attn_bias, batch_first=True)
+        self.norm2 = nn.LayerNorm(inner_dim)
+        self.self_attn = nn.MultiheadAttention(
+            embed_dim=inner_dim, num_heads=num_heads,
+            dropout=attn_drop, bias=attn_bias, batch_first=True)
+        self.norm3 = nn.LayerNorm(inner_dim)
+        self.mlp = nn.Sequential(
+            nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
+            nn.GELU(),
+            nn.Dropout(mlp_drop),
+            nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
+            nn.Dropout(mlp_drop),
+        )
+
+    def forward(self, x, cond):
+        # x: [N, L, D]
+        # cond: [N, L_cond, D_cond]
+        x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
+        before_sa = self.norm2(x)
+        x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
+        x = x + self.mlp(self.norm3(x))
+        return x
+
+
+class TriplaneTransformer(nn.Module):
+    """
+    Transformer with condition that generates a triplane representation.
+
+    Reference:
+    Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
+    """
+
+    def __init__(
+        self,
+        inner_dim: int,
+        image_feat_dim: int,
+        triplane_low_res: int,
+        triplane_high_res: int,
+        triplane_dim: int,
+        num_layers: int,
+        num_heads: int,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+
+        # attributes
+        self.triplane_low_res = triplane_low_res
+        self.triplane_high_res = triplane_high_res
+        self.triplane_dim = triplane_dim
+
+        # modules
+        # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
+        self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
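+        # One learnable token per low-resolution triplane cell across the three planes,
+        # i.e. a sequence of length 3 * triplane_low_res**2.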
+        self.layers = nn.ModuleList([
+            BasicTransformerBlock(
+                inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
+            for _ in range(num_layers)
+        ])
+        self.norm = nn.LayerNorm(inner_dim, eps=eps)
+        self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
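+        # 2x learned upsampling of each plane (kernel_size=2, stride=2) while projecting
+        # inner_dim down to triplane_dim output channels.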
+
+    def forward(self, image_feats):
+        # image_feats: [N, L_cond, D_cond]
+
+        N = image_feats.shape[0]
+        H = W = self.triplane_low_res
+
+        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
+        for layer in self.layers:
+            x = layer(x, image_feats)
+        x = self.norm(x)
+
+        # separate each plane and apply deconv
+        x = x.view(N, 3, H, W, -1)
+        x = torch.einsum('nihwd->indhw', x)  # [3, N, D, H, W]
+        x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
+        x = self.deconv(x)  # [3*N, D', H', W']
+        x = x.view(3, N, *x.shape[-3:])  # [3, N, D', H', W']
+        x = torch.einsum('indhw->nidhw', x)  # [N, 3, D', H', W']
+        x = x.contiguous()
+
+        return x
diff --git a/src/models/encoder/__init__.py b/src/models/encoder/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/models/encoder/dino.py b/src/models/encoder/dino.py
new file mode 100755
index 0000000000000000000000000000000000000000..3bc1b195a74e792375ff0ee95433087ae0271e5a
--- /dev/null
+++ b/src/models/encoder/dino.py
@@ -0,0 +1,546 @@
+# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ViT model."""
+from __future__ import annotations
+
+import collections.abc
+import math
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from transformers import PreTrainedModel, ViTConfig
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+)
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+
+
+class ViTEmbeddings(nn.Module):
+    """Construct the CLS token, position and patch embeddings. Optionally, also the mask token."""
+
+    def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
+        super().__init__()
+
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+        self.patch_embeddings = ViTPatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.config = config
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method interpolates the pre-trained position encodings so that the model can be
+        used on higher-resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        batch_size, _num_channels, height, width = pixel_values.shape
+        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+        if bool_masked_pos is not None:
+            seq_length = embeddings.shape[1]
+            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        # add the [CLS] token to the embedded patch tokens
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+        # add positional encoding to each token
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embeddings
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+class ViTPatchEmbeddings(nn.Module):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+        _batch_size, num_channels, height, width = pixel_values.shape
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+                f" Expected {self.num_channels} but got {num_channels}."
+            )
+        if not interpolate_pos_encoding:
+            if height != self.image_size[0] or width != self.image_size[1]:
+                raise ValueError(
+                    f"Input image size ({height}*{width}) doesn't match model"
+                    f" ({self.image_size[0]}*{self.image_size[1]})."
+                )
+        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return embeddings
+
+
+class ViTSelfAttention(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+                f"heads {config.num_attention_heads}."
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+class ViTSelfOutput(nn.Module):
+    """
+    The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
+    layernorm applied before each block.
+    """
+
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+class ViTAttention(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.attention = ViTSelfAttention(config)
+        self.output = ViTSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads: set[int]) -> None:
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.attention.query = prune_linear_layer(self.attention.query, index)
+        self.attention.key = prune_linear_layer(self.attention.key, index)
+        self.attention.value = prune_linear_layer(self.attention.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
+        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+        attention_output = self.output(self_outputs[0], hidden_states)
+
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+class ViTIntermediate(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+
+        return hidden_states
+
+
+class ViTOutput(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        hidden_states = hidden_states + input_tensor
+
+        return hidden_states
+
+
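+# adaLN-style modulation: a per-sample scale and shift, predicted from a conditioning
+# vector, are broadcast over the token dimension.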
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class ViTLayer(nn.Module):
+    """This corresponds to the Block class in the timm implementation."""
+
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = ViTAttention(config)
+        self.intermediate = ViTIntermediate(config)
+        self.output = ViTOutput(config)
+        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
+        )
+        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
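+        # Zero-init the modulation projection so every block starts with an identity
+        # modulation (scale = 0, shift = 0).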
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        adaln_input: torch.Tensor = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
+        shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
+
+        self_attention_outputs = self.attention(
+            modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa),  # in ViT, layernorm is applied before self-attention
+            head_mask,
+            output_attentions=output_attentions,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        # first residual connection
+        hidden_states = attention_output + hidden_states
+
+        # in ViT, layernorm is also applied after self-attention
+        layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
+        layer_output = self.intermediate(layer_output)
+
+        # second residual connection is done here
+        layer_output = self.output(layer_output, hidden_states)
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+
+class ViTEncoder(nn.Module):
+    def __init__(self, config: ViTConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        adaln_input: torch.Tensor = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[tuple, BaseModelOutput]:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states,
+                    adaln_input,
+                    layer_head_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+class ViTPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ViTConfig
+    base_model_prefix = "vit"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
+
+    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+        """Initialize the weights."""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, ViTEmbeddings):
+            module.position_embeddings.data = nn.init.trunc_normal_(
+                module.position_embeddings.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.position_embeddings.dtype)
+
+            module.cls_token.data = nn.init.trunc_normal_(
+                module.cls_token.data.to(torch.float32),
+                mean=0.0,
+                std=self.config.initializer_range,
+            ).to(module.cls_token.dtype)
+
+
+class ViTModel(ViTPreTrainedModel):
+    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
+        self.encoder = ViTEncoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.pooler = ViTPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> ViTPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel.
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        adaln_input: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, BaseModelOutputWithPooling]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+        if pixel_values.dtype != expected_dtype:
+            pixel_values = pixel_values.to(expected_dtype)
+
+        embedding_output = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            adaln_input=adaln_input,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class ViTPooler(nn.Module):
+    def __init__(self, config: ViTConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
diff --git a/src/models/encoder/dino_wrapper.py b/src/models/encoder/dino_wrapper.py
new file mode 100755
index 0000000000000000000000000000000000000000..7ebd13b304a15bb43d28003bc378a2394bb8e167
--- /dev/null
+++ b/src/models/encoder/dino_wrapper.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2023, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import torch.nn as nn
+from einops import rearrange
+from transformers import ViTImageProcessor
+
+from .dino import ViTModel
+
+
+class DinoWrapper(nn.Module):
+    """Dino v1 wrapper using huggingface transformer implementation."""
+
+    def __init__(self, model_name: str, freeze: bool = True):
+        super().__init__()
+        self.model, self.processor = self._build_dino(model_name)
+        self.camera_embedder = nn.Sequential(
+            nn.Linear(16, self.model.config.hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
+        )
+        if freeze:
+            self._freeze()
+
+    def forward(self, image, camera):
+        # image: [B, N, C, H, W]
+        # camera: [B, N, D]
+        # RGB image with [0,1] scale and properly sized
+        if image.ndim == 5:
+            image = rearrange(image, 'b n c h w -> (b n) c h w')
+        dtype = image.dtype
+        inputs = self.processor(
+            images=image.float(),
+            return_tensors="pt",
+            do_rescale=False,
+            do_resize=False,
+        ).to(self.model.device).to(dtype)
+        # embed camera
+        camera_embeddings = self.camera_embedder(camera)
+        camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
+        embeddings = camera_embeddings
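+        # The camera embedding conditions every ViT block through adaLN modulation.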
+        # This resampling of positional embedding uses bicubic interpolation
+        outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
+        last_hidden_states = outputs.last_hidden_state
+        return last_hidden_states
+
+    def _freeze(self):
+        print("======== Freezing DinoWrapper ========")
+        self.model.eval()
+        for name, param in self.model.named_parameters():
+            param.requires_grad = False
+
+    @staticmethod
+    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
+        import requests
+        try:
+            model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
+            processor = ViTImageProcessor.from_pretrained(model_name)
+            return model, processor
+        except requests.exceptions.ProxyError as err:
+            if proxy_error_retries > 0:
+                print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
+                import time
+                time.sleep(proxy_error_cooldown)
+                return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
+            else:
+                raise err
diff --git a/src/models/geometry/__init__.py b/src/models/geometry/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..89e9a6c2fffe82a55693885dae78c1a630924389
--- /dev/null
+++ b/src/models/geometry/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
diff --git a/src/models/geometry/camera/__init__.py b/src/models/geometry/camera/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..68bf983d37cd56c6435cfebf6625590b6052a10c
--- /dev/null
+++ b/src/models/geometry/camera/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+
+class Camera(nn.Module):
+    def __init__(self):
+        super().__init__()
+        pass
diff --git a/src/models/geometry/camera/perspective_camera.py b/src/models/geometry/camera/perspective_camera.py
new file mode 100755
index 0000000000000000000000000000000000000000..7a5b78815941894b9924d5a602e141c5f5730e4f
--- /dev/null
+++ b/src/models/geometry/camera/perspective_camera.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from . import Camera
+
+
+def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
+    if near_plane is None:
+        near_plane = n
+    return np.array(
+        [[n / x, 0, 0, 0],
+         [0, n / -x, 0, 0],
+         [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
+         [0, 0, -1, 0]]).astype(np.float32)
+
+
+class PerspectiveCamera(Camera):
+    def __init__(self, fovy=49.0, device='cuda'):
+        super().__init__()
+        self.device = device
+        focal = np.tan(fovy / 180.0 * np.pi * 0.5)
+        self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)
+
+    def project(self, points_bxnx4):
+        out = torch.matmul(
+            points_bxnx4,
+            torch.transpose(self.proj_mtx, 1, 2))
+        return out
diff --git a/src/models/geometry/render/__init__.py b/src/models/geometry/render/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..7ee9c3843d694f43452db63b58e9fad64f906dbd
--- /dev/null
+++ b/src/models/geometry/render/__init__.py
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+import torch
+
+
+class Renderer:
+    def __init__(self):
+        pass
+
+    def forward(self):
+        pass
diff --git a/src/models/geometry/render/neural_render.py b/src/models/geometry/render/neural_render.py
new file mode 100755
index 0000000000000000000000000000000000000000..9d8da2246a59a20173753fdba8433e93606f17be
--- /dev/null
+++ b/src/models/geometry/render/neural_render.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import nvdiffrast.torch as dr
+import torch
+import torch.nn.functional as F
+
+from . import Renderer
+
+_FG_LUT = None
+
+
+def interpolate(attr, rast, attr_idx, rast_db=None):
+    return dr.interpolate(
+        attr.contiguous(), rast, attr_idx, rast_db=rast_db,
+        diff_attrs=None if rast_db is None else 'all')
+
+
+def xfm_points(points, matrix, use_python=True):
+    """
+    Transform points.
+
+    Args:
+    ----
+        points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
+        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
+        use_python: Use PyTorch's torch.matmul (for validation)
+
+    Returns:
+    -------
+        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
+
+    """
+    out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
+    if torch.is_anomaly_enabled():
+        assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
+    return out
+
+
+def dot(x, y):
+    return torch.sum(x * y, -1, keepdim=True)
+
+
+def compute_vertex_normal(v_pos, t_pos_idx):
+    i0 = t_pos_idx[:, 0]
+    i1 = t_pos_idx[:, 1]
+    i2 = t_pos_idx[:, 2]
+
+    v0 = v_pos[i0, :]
+    v1 = v_pos[i1, :]
+    v2 = v_pos[i2, :]
+
+    face_normals = torch.cross(v1 - v0, v2 - v0)
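+    # The un-normalized cross product weights each face normal by twice the face area,
+    # so larger faces contribute more when accumulated onto shared vertices.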
+
+    # Splat face normals to vertices
+    v_nrm = torch.zeros_like(v_pos)
+    v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
+    v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
+    v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
+
+    # Normalize; replace zero (degenerate) normals with a default value of (0, 0, 1)
+    v_nrm = torch.where(
+        dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
+    )
+    v_nrm = F.normalize(v_nrm, dim=1)
+    assert torch.all(torch.isfinite(v_nrm))
+
+    return v_nrm
+
+
+class NeuralRender(Renderer):
+    def __init__(self, device='cuda', camera_model=None):
+        super().__init__()
+        self.device = device
+        self.ctx = dr.RasterizeCudaContext(device=device)
+        self.projection_mtx = None
+        self.camera = camera_model
+
+    def render_mesh(
+            self,
+            mesh_v_pos_bxnx3,
+            mesh_t_pos_idx_fx3,
+            camera_mv_bx4x4,
+            mesh_v_feat_bxnxd,
+            resolution=256,
+            spp=1,
+            device='cuda',
+            hierarchical_mask=False
+    ):
+        assert not hierarchical_mask
+
+        mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
+        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)  # Rotate it to camera coordinates
+        v_pos_clip = self.camera.project(v_pos)  # Projection in the camera
+
+        v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long())  # vertex normals in world coordinates
+
+        # Render the image.
+        # Only the per-pixel feature (the 3D location) is returned here; it is later used as input to the neural renderer.
+        num_layers = 1
+        mask_pyramid = None
+        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes
+        mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1)  # Concatenate the pos
+
+        with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
+            for _ in range(num_layers):
+                rast, _db = peeler.rasterize_next_layer()
+                gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
+
+        hard_mask = torch.clamp(rast[..., -1:], 0, 1)
+        antialias_mask = dr.antialias(
+            hard_mask.clone().contiguous(), rast, v_pos_clip,
+            mesh_t_pos_idx_fx3)
+
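+        # The last four channels of gb_feat hold the interpolated camera-space position (x, y, z, w):
+        # its z component is used as depth, and the remaining leading channels are the original per-vertex features.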
+        depth = gb_feat[..., -2:-1]
+        ori_mesh_feature = gb_feat[..., :-4]
+
+        normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
+        normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
+        normal = F.normalize(normal, dim=-1)
+        normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float())      # black background
+
+        return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
diff --git a/src/models/geometry/rep_3d/__init__.py b/src/models/geometry/rep_3d/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..1ba8ed833982cc5a8a8a72ca45b83fe09ce4de5e
--- /dev/null
+++ b/src/models/geometry/rep_3d/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+
+class Geometry:
+    def __init__(self):
+        pass
+
+    def forward(self):
+        pass
diff --git a/src/models/geometry/rep_3d/dmtet.py b/src/models/geometry/rep_3d/dmtet.py
new file mode 100755
index 0000000000000000000000000000000000000000..c1ba3686fdd9c20d6ad6fbd7dfbc2a873fae67bc
--- /dev/null
+++ b/src/models/geometry/rep_3d/dmtet.py
@@ -0,0 +1,504 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from . import Geometry
+from .dmtet_utils import get_center_boundary_index
+
+
+###############################################################################
+# DMTet utility functions
+###############################################################################
+def create_mt_variable(device):
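+    # Marching Tetrahedra lookup tables:
+    #   triangle_table      maps each of the 16 corner-sign configurations of a tet to up to two triangles (edge indices),
+    #   num_triangles_table gives the number of triangles emitted per configuration,
+    #   base_tet_edges      lists the six tet edges as vertex-index pairs,
+    #   v_id                encodes corner occupancy as powers of two to build the configuration index.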
+    triangle_table = torch.tensor(
+        [
+            [-1, -1, -1, -1, -1, -1],
+            [1, 0, 2, -1, -1, -1],
+            [4, 0, 3, -1, -1, -1],
+            [1, 4, 2, 1, 3, 4],
+            [3, 1, 5, -1, -1, -1],
+            [2, 3, 0, 2, 5, 3],
+            [1, 4, 0, 1, 5, 4],
+            [4, 2, 5, -1, -1, -1],
+            [4, 5, 2, -1, -1, -1],
+            [4, 1, 0, 4, 5, 1],
+            [3, 2, 0, 3, 5, 2],
+            [1, 3, 5, -1, -1, -1],
+            [4, 1, 2, 4, 3, 1],
+            [3, 0, 4, -1, -1, -1],
+            [2, 0, 1, -1, -1, -1],
+            [-1, -1, -1, -1, -1, -1]
+        ], dtype=torch.long, device=device)
+
+    num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
+    base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
+    v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
+    return triangle_table, num_triangles_table, base_tet_edges, v_id
+
+
+def sort_edges(edges_ex2):
+    with torch.no_grad():
+        order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
+        order = order.unsqueeze(dim=1)
+        a = torch.gather(input=edges_ex2, index=order, dim=1)
+        b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
+    return torch.stack([a, b], -1)
+
+
+###############################################################################
+# marching tetrahedrons (differentiable)
+###############################################################################
+
+def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id):
+    with torch.no_grad():
+        occ_n = sdf_n > 0
+        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
+        occ_sum = torch.sum(occ_fx4, -1)
+        valid_tets = (occ_sum > 0) & (occ_sum < 4)
+        occ_sum = occ_sum[valid_tets]
+
+        # find all vertices
+        all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
+        all_edges = sort_edges(all_edges)
+        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
+
+        unique_edges = unique_edges.long()
+        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
+        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1
+        mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device)
+        idx_map = mapping[idx_map]  # map edges to verts
+
+        interp_v = unique_edges[mask_edges]  # .long()
+    edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
+    edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
+    edges_to_interp_sdf[:, -1] *= -1
+
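+    # Linearly interpolate each edge to the SDF zero crossing: for endpoint values (s0, s1)
+    # and positions (x0, x1), the lines below evaluate (s0 * x1 - s1 * x0) / (s0 - s1).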
+    denominator = edges_to_interp_sdf.sum(1, keepdim=True)
+
+    edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
+    verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
+
+    idx_map = idx_map.reshape(-1, 6)
+
+    tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
+    num_triangles = num_triangles_table[tetindex]
+
+    # Generate triangle indices
+    faces = torch.cat(
+        (
+            torch.gather(
+                input=idx_map[num_triangles == 1], dim=1,
+                index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
+            torch.gather(
+                input=idx_map[num_triangles == 2], dim=1,
+                index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
+        ), dim=0)
+    return verts, faces
+
+
+def create_tetmesh_variables(device='cuda'):
+    tet_table = torch.tensor(
+        [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
+         [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1],
+         [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1],
+         [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8],
+         [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1],
+         [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9],
+         [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9],
+         [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9],
+         [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1],
+         [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9],
+         [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9],
+         [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9],
+         [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8],
+         [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8],
+         [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6],
+         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device)
+    num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device)
+    return tet_table, num_tets_table
+
+
+def marching_tets_tetmesh(
+        pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
+        return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
+    with torch.no_grad():
+        occ_n = sdf_n > 0
+        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
+        occ_sum = torch.sum(occ_fx4, -1)
+        valid_tets = (occ_sum > 0) & (occ_sum < 4)
+        occ_sum = occ_sum[valid_tets]
+
+        # find all vertices
+        all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2)
+        all_edges = sort_edges(all_edges)
+        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
+
+        unique_edges = unique_edges.long()
+        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
+        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1
+        mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device)
+        idx_map = mapping[idx_map]  # map edges to verts
+
+        interp_v = unique_edges[mask_edges]  # .long()
+    edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
+    edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
+    edges_to_interp_sdf[:, -1] *= -1
+
+    denominator = edges_to_interp_sdf.sum(1, keepdim=True)
+
+    edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
+    verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
+
+    idx_map = idx_map.reshape(-1, 6)
+
+    tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
+    num_triangles = num_triangles_table[tetindex]
+
+    # Generate triangle indices
+    faces = torch.cat(
+        (
+            torch.gather(
+                input=idx_map[num_triangles == 1], dim=1,
+                index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
+            torch.gather(
+                input=idx_map[num_triangles == 2], dim=1,
+                index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
+        ), dim=0)
+    if not return_tet_mesh:
+        return verts, faces
+    occupied_verts = ori_v[occ_n]
+    mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1
+    mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda")
+    tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4))
+
+    idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1)  # t x 10
+    tet_verts = torch.cat([verts, occupied_verts], 0)
+    num_tets = num_tets_table[tetindex]
+
+    tets = torch.cat(
+        (
+            torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape(
+                -1,
+                4),
+            torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape(
+                -1,
+                4),
+        ), dim=0)
+    # add fully occupied tets
+    fully_occupied = occ_fx4.sum(-1) == 4
+    tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0]
+    tets = torch.cat([tets, tet_fully_occupied])
+
+    return verts, faces, tet_verts, tets
+
+
+###############################################################################
+# Compact tet grid
+###############################################################################
+
+def compact_tets(pos_nx3, sdf_n, tet_fx4):
+    with torch.no_grad():
+        # Find surface tets
+        occ_n = sdf_n > 0
+        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
+        occ_sum = torch.sum(occ_fx4, -1)
+        valid_tets = (occ_sum > 0) & (occ_sum < 4)  # one value per tet, these are the surface tets
+
+        valid_vtx = tet_fx4[valid_tets].reshape(-1)
+        unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True)
+        new_pos = pos_nx3[unique_vtx]
+        new_sdf = sdf_n[unique_vtx]
+        new_tets = idx_map.reshape(-1, 4)
+        return new_pos, new_sdf, new_tets
+
+
+###############################################################################
+# Subdivide volume
+###############################################################################
+
+def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
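+    # 1-to-8 tetrahedron subdivision: a midpoint (with averaged position and SDF value) is inserted
+    # on every edge, and each tet is split into eight sub-tets from its corners and edge midpoints.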
+    device = tet_pos_bxnx3.device
+    # get new verts
+    tet_fx4 = tet_bxfx4[0]
+    edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]
+    all_edges = tet_fx4[:, edges].reshape(-1, 2)
+    all_edges = sort_edges(all_edges)
+    unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
+    idx_map = idx_map + tet_pos_bxnx3.shape[1]
+    all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1)
+    mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape(
+        all_values.shape[0], -1, 2,
+        all_values.shape[-1]).mean(2)
+    new_v = torch.cat([all_values, mid_points_pos], 1)
+    new_v, new_sdf = new_v[..., :3], new_v[..., 3]
+
+    # get new tets
+
+    idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3]
+    idx_ab = idx_map[0::6]
+    idx_ac = idx_map[1::6]
+    idx_ad = idx_map[2::6]
+    idx_bc = idx_map[3::6]
+    idx_bd = idx_map[4::6]
+    idx_cd = idx_map[5::6]
+
+    tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1)
+    tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1)
+    tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1)
+    tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1)
+    tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1)
+    tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1)
+    tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1)
+    tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1)
+
+    tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0)
+    tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1)
+    tet = tet_np.long().to(device)
+
+    return new_v, tet, new_sdf
+
+
+###############################################################################
+# Adjacency
+###############################################################################
+def tet_to_tet_adj_sparse(tet_tx4):
+    # Note: the returned adjacency includes self-connections.
+    with torch.no_grad():
+        t = tet_tx4.shape[0]
+        device = tet_tx4.device
+        idx_array = torch.LongTensor(
+            [0, 1, 2,
+             1, 0, 3,
+             2, 3, 0,
+             3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1)  # (t, 4, 3)
+
+        # get all faces
+        all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape(
+            -1,
+            3)  # (tx4, 3)
+        all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1)
+        # sort and group
+        all_faces_sorted, _ = torch.sort(all_faces, dim=1)
+
+        all_faces_unique, inverse_indices, counts = torch.unique(
+            all_faces_sorted, dim=0, return_counts=True,
+            return_inverse=True)
+        all_faces_unique[counts == 2]
+        counts = counts[inverse_indices]  # tx4
+        valid = (counts == 2)
+
+        group = inverse_indices[valid]
+        # print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape)
+        _, indices = torch.sort(group)
+        all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices]
+        tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1)
+
+        tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])])
+        adj_self = torch.arange(t, device=tet_tx4.device)
+        adj_self = torch.stack([adj_self, adj_self], -1)
+        tet_adj_idx = torch.cat([tet_adj_idx, adj_self])
+
+        tet_adj_idx = torch.unique(tet_adj_idx, dim=0)
+        values = torch.ones(
+            tet_adj_idx.shape[0], device=tet_tx4.device).float()
+        adj_sparse = torch.sparse.FloatTensor(
+            tet_adj_idx.t(), values, torch.Size([t, t]))
+
+        # normalization
+        neighbor_num = 1.0 / torch.sparse.sum(
+            adj_sparse, dim=1).to_dense()
+        values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0])
+        adj_sparse = torch.sparse.FloatTensor(
+            tet_adj_idx.t(), values, torch.Size([t, t]))
+    return adj_sparse
+
+
+###############################################################################
+# Compact grid
+###############################################################################
+
+def get_tet_bxfx4x3(bxnxz, bxfx4):
+    n_batch, z = bxnxz.shape[0], bxnxz.shape[2]
+    gather_input = bxnxz.unsqueeze(2).expand(
+        n_batch, bxnxz.shape[1], 4, z)
+    gather_index = bxfx4.unsqueeze(-1).expand(
+        n_batch, bxfx4.shape[1], 4, z).long()
+    tet_bxfx4xz = torch.gather(
+        input=gather_input, dim=1, index=gather_index)
+
+    return tet_bxfx4xz
+
+
+def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf):
+    with torch.no_grad():
+        assert tet_pos_bxnx3.shape[0] == 1
+
+        occ = grid_sdf[0] > 0
+        occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1)
+        mask = (occ_sum > 0) & (occ_sum < 4)
+
+        # build connectivity graph
+        adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0])
+        mask = mask.float().unsqueeze(-1)
+
+        # Include a one-ring of neighbors
+        for _ in range(1):
+            mask = torch.sparse.mm(adj_matrix, mask)
+        mask = mask.squeeze(-1) > 0
+
+        mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long)
+        new_tet_bxfx4 = tet_bxfx4[:, mask].long()
+        selected_verts_idx = torch.unique(new_tet_bxfx4)
+        new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx]
+        mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device)
+        new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape)
+        new_grid_sdf = grid_sdf[:, selected_verts_idx]
+        return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf
+
+
+###############################################################################
+# Regularizer
+###############################################################################
+
+def sdf_reg_loss(sdf, all_edges):
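+    # SDF regularizer: for every edge whose two endpoint SDF values disagree in sign, apply a
+    # symmetric binary cross-entropy that pushes each endpoint toward the sign of the other,
+    # discouraging spurious sign flips.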
+    sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2)
+    mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
+    sdf_f1x6x2 = sdf_f1x6x2[mask]
+    sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(
+        sdf_f1x6x2[..., 0],
+        (sdf_f1x6x2[..., 1] > 0).float()) + \
+               torch.nn.functional.binary_cross_entropy_with_logits(
+                   sdf_f1x6x2[..., 1],
+                   (sdf_f1x6x2[..., 0] > 0).float())
+    return sdf_diff
+
+
+def sdf_reg_loss_batch(sdf, all_edges):
+    sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
+    mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
+    sdf_f1x6x2 = sdf_f1x6x2[mask]
+    sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
+               torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
+    return sdf_diff
+
+
+###############################################################################
+#  Geometry interface
+###############################################################################
+class DMTetGeometry(Geometry):
+    def __init__(
+            self, grid_res=64, scale=2.0, device='cuda', renderer=None,
+            render_type='neural_render', args=None):
+        super().__init__()
+        self.grid_res = grid_res
+        self.device = device
+        self.args = args
+        tets = np.load('data/tets/%d_compress.npz' % (grid_res))
+        self.verts = torch.from_numpy(tets['vertices']).float().to(self.device)
+        # Center the tet grid at the origin and scale its longest side to unit length
+        length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0]
+        length = length.max()
+        mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0
+        self.verts = (self.verts - mid.unsqueeze(dim=0)) / length
+        if isinstance(scale, list):
+            self.verts[:, 0] = self.verts[:, 0] * scale[0]
+            self.verts[:, 1] = self.verts[:, 1] * scale[1]
+            self.verts[:, 2] = self.verts[:, 2] * scale[1]
+        else:
+            self.verts = self.verts * scale
+        self.indices = torch.from_numpy(tets['tets']).long().to(self.device)
+        self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device)
+        self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device)
+        # Parameters for regularization computation
+        edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device)
+        all_edges = self.indices[:, edges].reshape(-1, 2)
+        all_edges_sorted = torch.sort(all_edges, dim=1)[0]
+        self.all_edges = torch.unique(all_edges_sorted, dim=0)
+
+        # Parameters used for fix boundary sdf
+        self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts)
+        self.renderer = renderer
+        self.render_type = render_type
+
+    def getAABB(self):
+        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
+
+    def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
+        if indices is None:
+            indices = self.indices
+        verts, faces = marching_tets(
+            v_deformed_nx3, sdf_n, indices, self.triangle_table,
+            self.num_triangles_table, self.base_tet_edges, self.v_id)
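+        # Swap the last two vertex indices to flip the triangle winding.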
+        faces = torch.cat(
+            [faces[:, 0:1],
+             faces[:, 2:3],
+             faces[:, 1:2], ], dim=-1)
+        return verts, faces
+
+    def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None):
+        if indices is None:
+            indices = self.indices
+        verts, faces, tet_verts, tets = marching_tets_tetmesh(
+            v_deformed_nx3, sdf_n, indices, self.triangle_table,
+            self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True,
+            num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3)
+        faces = torch.cat(
+            [faces[:, 0:1],
+             faces[:, 2:3],
+             faces[:, 1:2], ], dim=-1)
+        return verts, faces, tet_verts, tets
+
+    def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
+        return_value = dict()
+        if self.render_type == 'neural_render':
+            tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh(
+                mesh_v_nx3.unsqueeze(dim=0),
+                mesh_f_fx3.int(),
+                camera_mv_bx4x4,
+                mesh_v_nx3.unsqueeze(dim=0),
+                resolution=resolution,
+                device=self.device,
+                hierarchical_mask=hierarchical_mask
+            )
+
+            return_value['tex_pos'] = tex_pos
+            return_value['mask'] = mask
+            return_value['hard_mask'] = hard_mask
+            return_value['rast'] = rast
+            return_value['v_pos_clip'] = v_pos_clip
+            return_value['mask_pyramid'] = mask_pyramid
+            return_value['depth'] = depth
+            return_value['normal'] = normal
+        else:
+            raise NotImplementedError
+
+        return return_value
+
+    def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
+        # Assumes a batch of meshes (each element may have a different mesh and geometry); each element is rendered separately with a batch size of 1.
+        v_list = []
+        f_list = []
+        n_batch = v_deformed_bxnx3.shape[0]
+        all_render_output = []
+        for i_batch in range(n_batch):
+            verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
+            v_list.append(verts_nx3)
+            f_list.append(faces_fx3)
+            render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
+            all_render_output.append(render_output)
+
+        # Concatenate all render output
+        return_keys = all_render_output[0].keys()
+        return_value = dict()
+        for k in return_keys:
+            value = [v[k] for v in all_render_output]
+            return_value[k] = value
+            # Concatenation across the batch, if needed, is left to the caller.
+        return return_value
diff --git a/src/models/geometry/rep_3d/dmtet_utils.py b/src/models/geometry/rep_3d/dmtet_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..8e8b45a086fa77f11f34ded0eaf634ac02aba270
--- /dev/null
+++ b/src/models/geometry/rep_3d/dmtet_utils.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import torch
+
+
+def get_center_boundary_index(verts):
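+    # center_idx: the vertex closest to the origin; boundary_idx: vertices with any coordinate equal
+    # to the global min or max of `verts`, i.e. vertices on the boundary of the tet grid.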
+    length_ = torch.sum(verts ** 2, dim=-1)
+    center_idx = torch.argmin(length_)
+    boundary_neg = verts == verts.max()
+    boundary_pos = verts == verts.min()
+    boundary = torch.bitwise_or(boundary_pos, boundary_neg)
+    boundary = torch.sum(boundary.float(), dim=-1)
+    boundary_idx = torch.nonzero(boundary)
+    return center_idx, boundary_idx.squeeze(dim=-1)
diff --git a/src/models/geometry/rep_3d/extract_texture_map.py b/src/models/geometry/rep_3d/extract_texture_map.py
new file mode 100755
index 0000000000000000000000000000000000000000..bc11d20642130a1f0d00ef463fb8d748e8826fa1
--- /dev/null
+++ b/src/models/geometry/rep_3d/extract_texture_map.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import numpy as np
+import nvdiffrast.torch as dr
+import torch
+import xatlas
+
+
+# ==============================================================================================
+def interpolate(attr, rast, attr_idx, rast_db=None):
+    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
+
+
+def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
+    _vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy())
+
+    # Convert to tensors
+    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
+
+    uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
+    mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
+    # Map the UVs from [0, 1] to clip space [-1, 1] for rasterization
+    uv_clip = uvs[None, ...] * 2.0 - 1.0
+
+    # pad to four component coordinate
+    uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
+
+    # rasterize
+    rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
+
+    # Interpolate world space position
+    gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
+    mask = rast[..., 3:4] > 0
+    return uvs, mesh_tex_idx, gb_pos, mask
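+
+
+# A minimal usage sketch (names below are illustrative, not part of this module): given a CUDA mesh
+# (verts_nx3, faces_fx3),
+#     ctx = dr.RasterizeCudaContext()
+#     uvs, mesh_tex_idx, gb_pos, mask = xatlas_uvmap(ctx, verts_nx3, faces_fx3, resolution=1024)
+# gb_pos then holds the interpolated world-space position per texel, and mask marks covered texels.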
diff --git a/src/models/geometry/rep_3d/flexicubes.py b/src/models/geometry/rep_3d/flexicubes.py
new file mode 100755
index 0000000000000000000000000000000000000000..9ffc8243fdc7d5ca1b32aa9c1c213ef074921cbf
--- /dev/null
+++ b/src/models/geometry/rep_3d/flexicubes.py
@@ -0,0 +1,580 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import torch
+
+from .tables import *
+
+__all__ = [
+    'FlexiCubes'
+]
+
+
+class FlexiCubes:
+    """
+    This class implements the FlexiCubes method for extracting meshes from scalar fields.
+    It maintains a series of lookup tables and indices to support the mesh extraction process.
+    FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
+    the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
+    the surface representation through gradient-based optimization.
+
+    During instantiation, the class loads DMC tables from a file and transforms them into
+    PyTorch tensors on the specified device.
+
+    Attributes
+    ----------
+        device (str): Specifies the computational device (default is "cuda").
+        dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
+            associated with each dual vertex in 256 Marching Cubes (MC) configurations.
+        num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
+            the 256 MC configurations.
+        check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
+            of the DMC configurations.
+        tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
+        quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
+            along one diagonal.
+        quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
+            two triangles along the other diagonal.
+        quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
+            during training by connecting all edges to their midpoints.
+        cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
+            eight corners in 3D space, ordered starting from the origin (0,0,0),
+            moving along the x-axis, then y-axis, and finally z-axis.
+            Used as a blueprint for generating a voxel grid.
+        cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
+            to retrieve the case id.
+        cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
+            Used to retrieve edge vertices in DMC.
+        edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
+            their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
+            first edge is oriented along the x-axis.
+        dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
+            across four adjacent cubes to the shared faces of these cubes. For instance,
+            dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
+            the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
+            This tensor is only utilized during isosurface tetrahedralization.
+        adj_pairs (torch.Tensor):
+            A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
+        qef_reg_scale (float):
+            The scaling factor applied to the regularization loss to prevent issues with singularity
+            when solving the QEF. This parameter is only used when a 'grad_func' is specified.
+        weight_scale (float):
+            The scale of weights in FlexiCubes. Should be between 0 and 1.
+
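+    Example
+    -------
+        A minimal sketch (the sphere SDF below is purely illustrative):
+
+        >>> fc = FlexiCubes(device="cuda")
+        >>> x_nx3, cube_fx8 = fc.construct_voxel_grid(32)
+        >>> sdf_n = x_nx3.norm(dim=-1) - 0.35  # negative inside a sphere of radius 0.35
+        >>> verts, faces, L_dev = fc(x_nx3, sdf_n, cube_fx8, 32)
+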
+    """
+
+    def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
+
+        self.device = device
+        self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
+        self.num_vd_table = torch.tensor(num_vd_table,
+                                         dtype=torch.long, device=device, requires_grad=False)
+        self.check_table = torch.tensor(
+            check_table,
+            dtype=torch.long, device=device, requires_grad=False)
+
+        self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
+        self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
+        self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
+        self.quad_split_train = torch.tensor(
+            [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
+
+        self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
+                                         1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
+        self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
+        self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
+                                       2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
+
+        self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
+                                           dtype=torch.long, device=device)
+        self.dir_faces_table = torch.tensor([
+            [[5, 4], [3, 2], [4, 5], [2, 3]],
+            [[5, 4], [1, 0], [4, 5], [0, 1]],
+            [[3, 2], [1, 0], [2, 3], [0, 1]]
+        ], dtype=torch.long, device=device)
+        self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
+        self.qef_reg_scale = qef_reg_scale
+        self.weight_scale = weight_scale
+
+    def construct_voxel_grid(self, res):
+        """
+        Generates a voxel grid based on the specified resolution.
+
+        Args:
+        ----
+            res (int or list[int]): The resolution of the voxel grid. If an integer
+                is provided, it is used for all three dimensions. If a list or tuple
+                of 3 integers is provided, they define the resolution for the x,
+                y, and z dimensions respectively.
+
+        Returns:
+        -------
+            (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
+                cube corners (index into vertices) of the constructed voxel grid.
+                The vertices are centered at the origin, with the length of each
+                dimension in the grid being one.
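+                For example, ``construct_voxel_grid(2)`` returns 27 unique vertices (a 3x3x3
+                lattice spanning [-0.5, 0.5] in each axis) and 8 cubes of 8 corner indices each.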
+
+        """
+        base_cube_f = torch.arange(8).to(self.device)
+        if isinstance(res, int):
+            res = (res, res, res)
+        voxel_grid_template = torch.ones(res, device=self.device)
+
+        res = torch.tensor([res], dtype=torch.float, device=self.device)
+        coords = torch.nonzero(voxel_grid_template).float() / res  # N, 3
+        verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
+        cubes = (base_cube_f.unsqueeze(0) +
+                 torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)
+
+        verts_rounded = torch.round(verts * 10**5) / (10**5)
+        verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
+        cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
+
+        return verts_unique - 0.5, cubes
+
+    def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
+                 gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
+        r"""
+        Main function for mesh extraction from scalar field using FlexiCubes. This function converts
+        discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
+        to triangle or tetrahedral meshes using a differentiable operation as described in
+        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
+        mesh quality and geometric fidelity by adjusting the surface representation based on gradient
+        optimization. The output surface is differentiable with respect to the input vertex positions,
+        scalar field values, and weight parameters.
+
+        If you intend to extract a surface mesh from a fixed Signed Distance Field without the
+        optimization of parameters, it is suggested to provide the "grad_func" which should
+        return the surface gradient at any given 3D position. When grad_func is provided, the process
+        to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
+        described in the `Manifold Dual Contouring`_ paper, and employs a smart splitting strategy.
+        Please note, this approach is non-differentiable.
+
+        For more details and example usage in optimization, refer to the
+        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
+
+        Args:
+        ----
+            x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
+            s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
+                denote that the corresponding vertex resides inside the isosurface. This affects
+                the directions of the extracted triangle faces and volume to be tetrahedralized.
+            cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
+            res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
+                is used for all three dimensions. If a list or tuple of 3 integers is provided, they
+                specify the resolution for the x, y, and z dimensions respectively.
+            beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
+                vertices positioning. Defaults to uniform value for all edges.
+            alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
+                vertices positioning. Defaults to uniform value for all vertices.
+            gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
+                quadrilaterals into triangles. Defaults to uniform value for all cubes.
+            training (bool, optional): If set to True, applies differentiable quad splitting for
+                training. Defaults to False.
+            output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
+                outputs a triangular mesh. Defaults to False.
+            grad_func (callable, optional): A function to compute the surface gradient at specified
+                3D positions (input: Nx3 positions). The function should return gradients as an Nx3
+                tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
+
+        Returns:
+        -------
+            (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
+                - Vertices for the extracted triangular/tetrahedral mesh.
+                - Faces for the extracted triangular/tetrahedral mesh.
+                - Regularizer L_dev, computed per dual vertex.
+
+        .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
+            https://research.nvidia.com/labs/toronto-ai/flexicubes/
+        .. _Manifold Dual Contouring:
+            https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
+
+        """
+
+        surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
+        if surf_cubes.sum() == 0:
+            return (
+                torch.zeros((0, 3), device=self.device),
+                torch.zeros((0, 4), dtype=torch.long, device=self.device)
+                if output_tetmesh
+                else torch.zeros((0, 3), dtype=torch.long, device=self.device),
+                torch.zeros((0,), device=self.device),
+            )
+        beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
+
+        case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
+
+        surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
+
+        vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
+            x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
+        vertices, faces, s_edges, edge_indices = self._triangulate(
+            s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
+        if not output_tetmesh:
+            return vertices, faces, L_dev
+        else:
+            vertices, tets = self._tetrahedralize(
+                x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
+                surf_cubes, training)
+            return vertices, tets, L_dev
+
+    def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
+        """Regularizer L_dev as in Equation 8."""
+        dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
+        mean_l2 = torch.zeros_like(vd[:, 0])
+        mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
+        mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
+        return mad
+
+    def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
+        """Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones."""
+        n_cubes = surf_cubes.shape[0]
+
+        if beta_fx12 is not None:
+            beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
+        else:
+            beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
+
+        if alpha_fx8 is not None:
+            alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
+        else:
+            alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
+
+        if gamma_f is not None:
+            gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
+        else:
+            gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
+
+        return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
+
+    @torch.no_grad()
+    def _get_case_id(self, occ_fx8, surf_cubes, res):
+        """
+        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
+        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
+        supplementary material. It should be noted that this function assumes a regular grid.
+        """
+        case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
+
+        problem_config = self.check_table.to(self.device)[case_ids]
+        to_check = problem_config[..., 0] == 1
+        problem_config = problem_config[to_check]
+        if not isinstance(res, (list, tuple)):
+            res = [res, res, res]
+
+        # 'problem_config' only contains configurations for surface cubes. Next, we construct a 3D array,
+        # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
+        # This allows efficient checking on adjacent cubes.
+        problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
+        vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3
+        vol_idx_problem = vol_idx[surf_cubes][to_check]
+        problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
+        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
+
+        within_range = (
+            vol_idx_problem_adj[..., 0] >= 0) & (
+            vol_idx_problem_adj[..., 0] < res[0]) & (
+            vol_idx_problem_adj[..., 1] >= 0) & (
+            vol_idx_problem_adj[..., 1] < res[1]) & (
+            vol_idx_problem_adj[..., 2] >= 0) & (
+            vol_idx_problem_adj[..., 2] < res[2])
+
+        vol_idx_problem = vol_idx_problem[within_range]
+        vol_idx_problem_adj = vol_idx_problem_adj[within_range]
+        problem_config = problem_config[within_range]
+        problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
+                                                 vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
+        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
+        to_invert = (problem_config_adj[..., 0] == 1)
+        idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
+        case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
+        return case_ids
+
+    @torch.no_grad()
+    def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
+        """
+        Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
+        can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
+        and marks the cube edges with this index.
+        """
+        occ_n = s_n < 0
+        all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
+        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
+
+        unique_edges = unique_edges.long()
+        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
+
+        surf_edges_mask = mask_edges[_idx_map]
+        counts = counts[_idx_map]
+
+        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
+        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
+        # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
+        # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
+        idx_map = mapping[_idx_map]
+        surf_edges = unique_edges[mask_edges]
+        return surf_edges, idx_map, counts, surf_edges_mask
+
+    @torch.no_grad()
+    def _identify_surf_cubes(self, s_n, cube_fx8):
+        """
+        Identifies grid cubes that intersect with the underlying surface by checking if the signs at
+        all corners are not identical.
+        """
+        occ_n = s_n < 0
+        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
+        _occ_sum = torch.sum(occ_fx8, -1)
+        surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
+        return surf_cubes, occ_fx8
+
+    def _linear_interp(self, edges_weight, edges_x):
+        """Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'."""
+        edge_dim = edges_weight.dim() - 2
+        assert edges_weight.shape[edge_dim] == 2
+        edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
+                                 torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
+        denominator = edges_weight.sum(edge_dim)
+        ue = (edges_x * edges_weight).sum(edge_dim) / denominator
+        return ue
+
+    def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
+        p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
+        norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
+        c_bx3 = c_bx3.reshape(-1, 3)
+        A = norm_bxnx3
+        B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
+
+        A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
+        B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
+        A = torch.cat([A, A_reg], 1)
+        B = torch.cat([B, B_reg], 1)
+        dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
+        return dual_verts
+
+    def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
+        """Computes the location of dual vertices as described in Section 4.2."""
+        alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
+        surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
+        surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
+        zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
+
+        idx_map = idx_map.reshape(-1, 12)
+        num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
+        edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
+
+        total_num_vd = 0
+        vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
+        if grad_func is not None:
+            normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
+            vd = []
+        for num in torch.unique(num_vd):
+            cur_cubes = (num_vd == num)  # consider cubes emitting the same number of dual vertices (for batching)
+            curr_num_vd = cur_cubes.sum() * num
+            curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
+            curr_edge_group_to_vd = torch.arange(
+                curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
+            total_num_vd += curr_num_vd
+            curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
+                cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
+
+            curr_mask = (curr_edge_group != -1)
+            edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
+            edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
+            edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
+            vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
+            vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
+
+            if grad_func is not None:
+                with torch.no_grad():
+                    cube_e_verts_idx = idx_map[cur_cubes]
+                    curr_edge_group[~curr_mask] = 0
+
+                    verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
+                    verts_group_idx[verts_group_idx == -1] = 0
+                    verts_group_pos = torch.index_select(
+                        input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
+                    v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
+                    curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
+                    verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
+
+                    normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
+                        -1, num.item(), 7,
+                        3)
+                    curr_mask = curr_mask.squeeze(2)
+                    vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
+                                                 verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
+        edge_group = torch.cat(edge_group)
+        edge_group_to_vd = torch.cat(edge_group_to_vd)
+        edge_group_to_cube = torch.cat(edge_group_to_cube)
+        vd_num_edges = torch.cat(vd_num_edges)
+        vd_gamma = torch.cat(vd_gamma)
+
+        if grad_func is not None:
+            vd = torch.cat(vd)
+            L_dev = torch.zeros([1], device=self.device)
+        else:
+            vd = torch.zeros((total_num_vd, 3), device=self.device)
+            beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
+
+            idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
+
+            x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
+            s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
+
+            zero_crossing_group = torch.index_select(
+                input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
+
+            alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
+                                             index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
+            ue_group = self._linear_interp(s_group * alpha_group, x_group)
+
+            beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
+                                      index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
+            beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
+            vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
+            L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
+
+        v_idx = torch.arange(vd.shape[0], device=self.device)  # + total_num_vd
+
+        vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
+                                                      12 + edge_group, src=v_idx[edge_group_to_vd])
+
+        return vd, L_dev, vd_gamma, vd_idx_map
+
+    def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
+        """
+        Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
+        triangles based on the gamma parameter, as described in Section 4.3.
+        """
+        with torch.no_grad():
+            group_mask = (edge_counts == 4) & surf_edges_mask  # surface edges shared by 4 cubes.
+            group = idx_map.reshape(-1)[group_mask]
+            vd_idx = vd_idx_map[group_mask]
+            edge_indices, indices = torch.sort(group, stable=True)
+            quad_vd_idx = vd_idx[indices].reshape(-1, 4)
+
+            # Ensure all face directions point towards the positive SDF to maintain consistent winding.
+            s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
+            flip_mask = s_edges[:, 0] > 0
+            quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
+                                     quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
+        if grad_func is not None:
+            # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
+            with torch.no_grad():
+                vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
+                quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
+                gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
+                gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
+        else:
+            quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
+            gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
+                0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
+            gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
+                1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
+        if not training:
+            mask = (gamma_02 > gamma_13).squeeze(1)
+            faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
+            faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
+            faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
+            faces = faces.reshape(-1, 3)
+        else:
+            vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
+            vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
+                     torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
+            vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
+                     torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
+            weight_sum = (gamma_02 + gamma_13) + 1e-8
+            vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
+                         weight_sum.unsqueeze(-1)).squeeze(1)
+            vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
+            vd = torch.cat([vd, vd_center])
+            faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
+            faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
+        return vd, faces, s_edges, edge_indices
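+        # Split rule used above: at inference, each quad is split along the diagonal whose
+        # pair of dual vertices has the larger weight (gamma_02 vs. gamma_13); during
+        # training, an extra center vertex is placed at the gamma-weighted average of the
+        # two diagonal midpoints and each quad yields four triangles sharing it, keeping
+        # the choice of diagonal differentiable.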
+
+    def _tetrahedralize(
+            self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
+            surf_cubes, training):
+        """Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5."""
+        occ_n = s_n < 0
+        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
+        occ_sum = torch.sum(occ_fx8, -1)
+
+        inside_verts = x_nx3[occ_n]
+        mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
+        mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
+        """
+        For each grid edge connecting two grid vertices with different
+        signs, we first form a four-sided pyramid by connecting one
+        of the grid vertices with four mesh vertices that correspond
+        to the grid edge and then subdivide the pyramid into two tetrahedra
+        """
+        inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
+            s_edges < 0]]
+        if not training:
+            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
+        else:
+            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
+
+        tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
+        """
+        For each grid edge connecting two grid vertices with the
+        same sign, the tetrahedron is formed by the two grid vertices
+        and two vertices in consecutive adjacent cells
+        """
+        inside_cubes = (occ_sum == 8)
+        inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
+        inside_cubes_center_idx = torch.arange(
+            inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
+
+        surface_n_inside_cubes = surf_cubes | inside_cubes
+        edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
+                                            dtype=torch.long, device=x_nx3.device) * -1
+        surf_cubes = surf_cubes[surface_n_inside_cubes]
+        inside_cubes = inside_cubes[surface_n_inside_cubes]
+        edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
+        edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
+
+        all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
+        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
+        unique_edges = unique_edges.long()
+        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
+        mask = mask_edges[_idx_map]
+        counts = counts[_idx_map]
+        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
+        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
+        idx_map = mapping[_idx_map]
+
+        group_mask = (counts == 4) & mask
+        group = idx_map.reshape(-1)[group_mask]
+        edge_indices, indices = torch.sort(group)
+        cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
+                                device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
+        edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
+            0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
+        # Identify the face shared by the adjacent cells.
+        cube_idx_4 = cube_idx[indices].reshape(-1, 4)
+        edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
+        shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
+        cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
+        # Identify an edge of the face with different signs and
+        # select the mesh vertex corresponding to the identified edge.
+        case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
+        case_ids_expand[surf_cubes] = case_ids
+        cases = case_ids_expand[cube_idx_4x2]
+        quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
+        mask = (quad_edge == -1).sum(-1) == 0
+        inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
+        tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
+
+        tets = torch.cat([tets_surface, tets_inside])
+        vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
+        return vertices, tets
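+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, not part of the upstream FlexiCubes code: extract a triangle
+    # mesh from an analytic sphere SDF. The positional arguments and the 21-channel
+    # weight layout (12 beta + 8 alpha + 1 gamma) mirror the call made in
+    # flexicubes_geometry.py; the grid extent and the sphere radius are illustrative guesses.
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    fc = FlexiCubes(device, weight_scale=0.5)
+    grid_res = 32
+    x_nx3, cube_fx8 = fc.construct_voxel_grid(grid_res)
+    sdf_n = x_nx3.norm(dim=-1) - 0.35  # negative inside the sphere, matching occ_n = s_n < 0
+    weights = torch.ones((cube_fx8.shape[0], 21), device=device)
+    verts, faces, _ = fc(x_nx3, sdf_n, cube_fx8, grid_res,
+                         beta_fx12=weights[:, :12], alpha_fx8=weights[:, 12:20],
+                         gamma_f=weights[:, 20], training=False)
+    print(verts.shape, faces.shape)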
diff --git a/src/models/geometry/rep_3d/flexicubes_geometry.py b/src/models/geometry/rep_3d/flexicubes_geometry.py
new file mode 100755
index 0000000000000000000000000000000000000000..6b4801479c5f4dd598a8fdafd79ac569ad021f51
--- /dev/null
+++ b/src/models/geometry/rep_3d/flexicubes_geometry.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import torch
+
+from . import Geometry
+from .flexicubes import FlexiCubes  # replace later
+
+
+def get_center_boundary_index(grid_res, device):
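+    """Return the flat index of a vertex near the center of the (grid_res + 1)^3 vertex
+    grid, plus the flat indices of all vertices in its two outermost layers (used for the
+    boundary-SDF fix noted in FlexiCubesGeometry.__init__ below)."""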
+    v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
+    v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
+    center_indices = torch.nonzero(v.reshape(-1))
+
+    v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
+    v[:2, ...] = True
+    v[-2:, ...] = True
+    v[:, :2, ...] = True
+    v[:, -2:, ...] = True
+    v[:, :, :2] = True
+    v[:, :, -2:] = True
+    boundary_indices = torch.nonzero(v.reshape(-1))
+    return center_indices, boundary_indices
+
+###############################################################################
+#  Geometry interface
+###############################################################################
+class FlexiCubesGeometry(Geometry):
+    def __init__(
+            self, grid_res=64, scale=2.0, device='cuda', renderer=None,
+            render_type='neural_render', args=None):
+        super().__init__()
+        self.grid_res = grid_res
+        self.device = device
+        self.args = args
+        self.fc = FlexiCubes(device, weight_scale=0.5)
+        self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
+        if isinstance(scale, list):
+            self.verts[:, 0] = self.verts[:, 0] * scale[0]
+            self.verts[:, 1] = self.verts[:, 1] * scale[1]
+            self.verts[:, 2] = self.verts[:, 2] * scale[2]
+        else:
+            self.verts = self.verts * scale
+
+        all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
+        self.all_edges = torch.unique(all_edges, dim=0)
+
+        # Parameters used to fix the boundary SDF
+        self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
+        self.renderer = renderer
+        self.render_type = render_type
+
+    def getAABB(self):
+        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
+
+    def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
+        if indices is None:
+            indices = self.indices
+
+        verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
+                                            beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
+                                            gamma_f=weight_n[:, 20], training=is_training
+                                            )
+        return verts, faces, v_reg_loss
+
+
+    def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
+        return_value = dict()
+        if self.render_type == 'neural_render':
+            tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh(
+                mesh_v_nx3.unsqueeze(dim=0),
+                mesh_f_fx3.int(),
+                camera_mv_bx4x4,
+                mesh_v_nx3.unsqueeze(dim=0),
+                resolution=resolution,
+                device=self.device,
+                hierarchical_mask=hierarchical_mask
+            )
+
+            return_value['tex_pos'] = tex_pos
+            return_value['mask'] = mask
+            return_value['hard_mask'] = hard_mask
+            return_value['rast'] = rast
+            return_value['v_pos_clip'] = v_pos_clip
+            return_value['mask_pyramid'] = mask_pyramid
+            return_value['depth'] = depth
+            return_value['normal'] = normal
+        else:
+            raise NotImplementedError
+
+        return return_value
+
+    def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
+        # Assumes a batch of meshes (each entry may be a different mesh and geometry); for the other shapes the batch size is 1.
+        v_list = []
+        f_list = []
+        n_batch = v_deformed_bxnx3.shape[0]
+        all_render_output = []
+        for i_batch in range(n_batch):
+            verts_nx3, faces_fx3, _ = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])  # get_mesh also returns the regularization loss
+            v_list.append(verts_nx3)
+            f_list.append(faces_fx3)
+            render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
+            all_render_output.append(render_output)
+
+        # Gather the render outputs, grouped by key
+        return_keys = all_render_output[0].keys()
+        return_value = dict()
+        for k in return_keys:
+            value = [v[k] for v in all_render_output]
+            return_value[k] = value
+            # We can do concatenation outside of the render
+        return return_value
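+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the upstream code): build the grid geometry and run
+    # get_mesh with uniform FlexiCubes weights. The 21-channel weight layout follows the
+    # slicing in get_mesh above: beta (12), alpha (8), gamma (1); the sphere radius is a
+    # placeholder value.
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    geo = FlexiCubesGeometry(grid_res=64, scale=2.0, device=device)
+    sdf = geo.verts.norm(dim=-1) - 0.5  # toy sphere SDF, negative inside
+    weights = torch.ones((geo.indices.shape[0], 21), device=device)
+    verts, faces, reg_loss = geo.get_mesh(geo.verts, sdf, weight_n=weights, is_training=False)
+    print(verts.shape, faces.shape)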
diff --git a/src/models/geometry/rep_3d/tables.py b/src/models/geometry/rep_3d/tables.py
new file mode 100755
index 0000000000000000000000000000000000000000..332116efc75c82552c266a7e351b41bcf21d7eb5
--- /dev/null
+++ b/src/models/geometry/rep_3d/tables.py
@@ -0,0 +1,793 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
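+# Lookup tables for the FlexiCubes / Dual Marching Cubes implementation in flexicubes.py.
+# The summaries below are descriptive and inferred from how the tables are indexed there:
+#   dmc_table    - for each of the 256 cube sign configurations, up to four dual vertices,
+#                  each listed as up to seven crossing-edge indices (0-11), padded with -1.
+#   num_vd_table - number of dual vertices per configuration (0-4).
+#   check_table  - per configuration: a flag, a 3-vector direction and an alternative case
+#                  id, consulted only for the ambiguous (non-zero) configurations.
+#   tet_table    - per configuration and per cube face, the slot (edge 0-11 or cell center
+#                  12) of the mesh vertex adjacent to that face, used when forming the
+#                  interior tetrahedra; -1 means none.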
+dmc_table = [
+[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
+]
+num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
+2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
+1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
+1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
+2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
+3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
+2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
+1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
+1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
+1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
+check_table = [
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 194],
+[1, -1, 0, 0, 193],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 164],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 161],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 152],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 145],
+[1, 0, 0, 1, 144],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 137],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 133],
+[1, 0, 1, 0, 132],
+[1, 1, 0, 0, 131],
+[1, 1, 0, 0, 130],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 100],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 98],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 96],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 88],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 82],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 74],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 72],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 70],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 67],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 65],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 56],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 52],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 44],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 40],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 38],
+[1, 0, -1, 0, 37],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 33],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 28],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 26],
+[1, 0, 0, -1, 25],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 20],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 18],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 9],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 6],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0]
+]
+tet_table = [
+[-1, -1, -1, -1, -1, -1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, -1],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[0, 4, 0, 4, 4, -1],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[5, 5, 5, 5, 5, 5],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, -1, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, -1, 2, 4, 4, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 4, 4, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 4, 2, 4, 4, 2],
+[0, 4, 0, 4, 4, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 5, 2, 5, 5, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, -1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[4, 1, 1, 4, 4, 1],
+[0, 1, 1, 0, 0, 1],
+[4, 0, 0, 4, 4, 0],
+[2, 2, 2, 2, 2, 2],
+[-1, 1, 1, 4, 4, 1],
+[0, 1, 1, 4, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[5, 1, 1, 5, 5, 1],
+[0, 1, 1, 0, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[8, 8, 8, 8, 8, 8],
+[1, 1, 1, 4, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 4, 4, 1],
+[0, 4, 0, 4, 4, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 5, 5, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[6, 6, 6, 6, 6, 6],
+[6, -1, 0, 6, 0, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 1, 1, 6, 1, 6],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[6, 4, -1, 6, 4, 6],
+[6, 4, 0, 6, 4, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 1, 1, 6, 1, 6],
+[5, 5, 5, 5, 5, 5],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 4, 2, 2, 4, 2],
+[0, 4, 0, 4, 4, 0],
+[2, 0, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[6, 1, 1, 6, -1, 6],
+[6, 1, 1, 6, 0, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 2, 2, 6, 2, 6],
+[4, 1, 1, 4, 4, 1],
+[0, 1, 1, 0, 0, 1],
+[4, 0, 0, 4, 4, 4],
+[2, 2, 2, 2, 2, 2],
+[6, 1, 1, 6, 4, 6],
+[6, 1, 1, 6, 4, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 2, 2, 6, 2, 6],
+[5, 1, 1, 5, 5, 1],
+[0, 1, 1, 0, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[6, 6, 6, 6, 6, 6],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 4, 1],
+[0, 4, 0, 4, 4, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 5, 0, 5, 0, 5],
+[5, 5, 5, 5, 5, 5],
+[5, 5, 5, 5, 5, 5],
+[0, 5, 0, 5, 0, 5],
+[-1, 5, 0, 5, 0, 5],
+[1, 5, 1, 5, 1, 5],
+[4, 5, -1, 5, 4, 5],
+[0, 5, 0, 5, 0, 5],
+[4, 5, 0, 5, 4, 5],
+[1, 5, 1, 5, 1, 5],
+[4, 4, 4, 4, 4, 4],
+[0, 4, 0, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[6, 6, 6, 6, 6, 6],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 5, 2, 5, -1, 5],
+[0, 5, 0, 5, 0, 5],
+[2, 5, 2, 5, 0, 5],
+[1, 5, 1, 5, 1, 5],
+[2, 5, 2, 5, 4, 5],
+[0, 5, 0, 5, 0, 5],
+[2, 5, 2, 5, 4, 5],
+[1, 5, 1, 5, 1, 5],
+[2, 4, 2, 4, 4, 2],
+[0, 4, 0, 4, 4, 4],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 6, 2, 6, 6, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, 1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[4, 1, 1, 1, 4, 1],
+[0, 1, 1, 1, 0, 1],
+[4, 0, 0, 4, 4, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[1, 1, 1, 1, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[6, 0, 0, 6, 0, 6],
+[0, 0, 0, 0, 0, 0],
+[6, 6, 6, 6, 6, 6],
+[5, 5, 5, 5, 5, 5],
+[5, 5, 0, 5, 0, 5],
+[5, 5, 0, 5, 0, 5],
+[5, 5, 1, 5, 1, 5],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 0, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[4, 4, 0, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[8, 8, 8, 8, 8, 8],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 1, 1, 4, 4, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 4, 2, 4, 4, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[12, 12, 12, 12, 12, 12]
+]
diff --git a/src/models/lrm.py b/src/models/lrm.py
new file mode 100755
index 0000000000000000000000000000000000000000..4a5e6fbea5a082885047e618a6eddd6a19fe8d78
--- /dev/null
+++ b/src/models/lrm.py
@@ -0,0 +1,196 @@
+# Copyright (c) 2023, Zexin He
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import mcubes
+import numpy as np
+import nvdiffrast.torch as dr
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from ..utils.mesh_util import xatlas_uvmap
+from .decoder.transformer import TriplaneTransformer
+from .encoder.dino_wrapper import DinoWrapper
+from .renderer.synthesizer import TriplaneSynthesizer
+
+
+class InstantNeRF(nn.Module):
+    """Full model of the large reconstruction model."""
+
+    def __init__(
+        self,
+        encoder_freeze: bool = False,
+        encoder_model_name: str = 'facebook/dino-vitb16',
+        encoder_feat_dim: int = 768,
+        transformer_dim: int = 1024,
+        transformer_layers: int = 16,
+        transformer_heads: int = 16,
+        triplane_low_res: int = 32,
+        triplane_high_res: int = 64,
+        triplane_dim: int = 80,
+        rendering_samples_per_ray: int = 128,
+    ):
+        super().__init__()
+
+        # modules
+        self.encoder = DinoWrapper(
+            model_name=encoder_model_name,
+            freeze=encoder_freeze,
+        )
+
+        self.transformer = TriplaneTransformer(
+            inner_dim=transformer_dim,
+            num_layers=transformer_layers,
+            num_heads=transformer_heads,
+            image_feat_dim=encoder_feat_dim,
+            triplane_low_res=triplane_low_res,
+            triplane_high_res=triplane_high_res,
+            triplane_dim=triplane_dim,
+        )
+
+        self.synthesizer = TriplaneSynthesizer(
+            triplane_dim=triplane_dim,
+            samples_per_ray=rendering_samples_per_ray,
+        )
+
+    def forward_planes(self, images, cameras):
+        # images: [B, V, C_img, H_img, W_img]
+        # cameras: [B, V, 16]
+        B = images.shape[0]
+
+        # encode images
+        image_feats = self.encoder(images, cameras)
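+        # fold the view axis into the token axis so the triplane transformer
+        # attends over the tokens of all V input views jointly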
+        image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
+
+        # transformer generating planes
+        planes = self.transformer(image_feats)
+
+        return planes
+
+    def forward(self, images, cameras, render_cameras, render_size: int):
+        # images: [B, V, C_img, H_img, W_img]
+        # cameras: [B, V, 16]
+        # render_cameras: [B, M, D_cam_render]
+        # render_size: int
+        _B, _M = render_cameras.shape[:2]
+
+        planes = self.forward_planes(images, cameras)
+
+        # render target views
+        render_results = self.synthesizer(planes, render_cameras, render_size)
+
+        return {
+            'planes': planes,
+            **render_results,
+        }
+
+    def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
+        """
+        Predict texture colors for the given triplanes.
+        :param planes: the triplane feature map
+        :param tex_pos: positions at which to query the texture field
+        :param hard_mask: 2D silhouette of the rendered image.
+        """
+        tex_pos = torch.cat(tex_pos, dim=0)
+        if hard_mask is not None:
+            tex_pos = tex_pos * hard_mask.float()
+        batch_size = tex_pos.shape[0]
+        tex_pos = tex_pos.reshape(batch_size, -1, 3)
+        ###################
+        # We use the mask to gather only the covered texture locations (to save memory)
+        if hard_mask is not None:
+            n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
+            sample_tex_pose_list = []
+            max_point = n_point_list.max()
+            expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
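+            # gather only the masked positions of each sample and zero-pad them to the
+            # largest count in the batch so they can be queried as a single tensor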
+            for i in range(tex_pos.shape[0]):
+                tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
+                if tex_pos_one_shape.shape[1] < max_point:
+                    tex_pos_one_shape = torch.cat(
+                        [tex_pos_one_shape, torch.zeros(
+                            1, max_point - tex_pos_one_shape.shape[1], 3,
+                            device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
+                sample_tex_pose_list.append(tex_pos_one_shape)
+            tex_pos = torch.cat(sample_tex_pose_list, dim=0)
+
+        tex_feat = self.synthesizer.forward_points(planes, tex_pos)['rgb']
+
+        if hard_mask is not None:
+            final_tex_feat = torch.zeros(
+                planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
+            expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
+            for i in range(planes.shape[0]):
+                final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
+            tex_feat = final_tex_feat
+
+        return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
+
+    def extract_mesh(
+        self,
+        planes: torch.Tensor,
+        mesh_resolution: int = 256,
+        mesh_threshold: float = 10.0,
+        use_texture_map: bool = False,
+        texture_resolution: int = 1024,
+        **kwargs,
+    ):
+        """
+        Extract a 3D mesh from the triplane NeRF. Only supports batch_size 1.
+        :param planes: triplane features
+        :param mesh_resolution: marching cubes resolution
+        :param mesh_threshold: iso-surface threshold
+        :param use_texture_map: use texture map or vertex color
+        :param texture_resolution: the resolution of the texture map.
+        """
+        assert planes.shape[0] == 1
+        device = planes.device
+
+        grid_out = self.synthesizer.forward_grid(
+            planes=planes,
+            grid_size=mesh_resolution,
+        )
+
+        vertices, faces = mcubes.marching_cubes(
+            grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
+            mesh_threshold,
+        )
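+        # map the vertex coordinates from the [0, mesh_resolution - 1] grid index range
+        # back into the [-1, 1] box used for triplane sampling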
+        vertices = vertices / (mesh_resolution - 1) * 2 - 1
+
+        if not use_texture_map:
+            # query vertex colors
+            vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0)
+            vertices_colors = self.synthesizer.forward_points(
+                planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy()
+            vertices_colors = (vertices_colors * 255).astype(np.uint8)
+
+            return vertices, faces, vertices_colors
+
+        # use x-atlas to get uv mapping for the mesh
+        vertices = torch.tensor(vertices, dtype=torch.float32, device=device)
+        faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device)
+
+        ctx = dr.RasterizeCudaContext(device=device)
+        uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
+            ctx, vertices, faces, resolution=texture_resolution)
+        tex_hard_mask = tex_hard_mask.float()
+
+        # query the texture field to get the RGB color for texture map
+        tex_feat = self.get_texture_prediction(
+            planes, [gb_pos], tex_hard_mask)
+        background_feature = torch.zeros_like(tex_feat)
+        img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
+        texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
+
+        return vertices, faces, uvs, mesh_tex_idx, texture_map
diff --git a/src/models/lrm_mesh.py b/src/models/lrm_mesh.py
new file mode 100755
index 0000000000000000000000000000000000000000..0eeef318875fc1441bf0562a0dd997bffffee980
--- /dev/null
+++ b/src/models/lrm_mesh.py
@@ -0,0 +1,385 @@
+# Copyright (c) 2023, Tencent Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import numpy as np
+import nvdiffrast.torch as dr
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from ..utils.mesh_util import xatlas_uvmap
+from .decoder.transformer import TriplaneTransformer
+from .encoder.dino_wrapper import DinoWrapper
+from .geometry.camera.perspective_camera import PerspectiveCamera
+from .geometry.render.neural_render import NeuralRender
+from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry
+from .renderer.synthesizer_mesh import TriplaneSynthesizer
+
+
+class InstantMesh(nn.Module):
+    """Full model of the large reconstruction model."""
+
+    def __init__(
+        self,
+        encoder_freeze: bool = False,
+        encoder_model_name: str = 'facebook/dino-vitb16',
+        encoder_feat_dim: int = 768,
+        transformer_dim: int = 1024,
+        transformer_layers: int = 16,
+        transformer_heads: int = 16,
+        triplane_low_res: int = 32,
+        triplane_high_res: int = 64,
+        triplane_dim: int = 80,
+        rendering_samples_per_ray: int = 128,
+        grid_res: int = 128,
+        grid_scale: float = 2.0,
+    ):
+        super().__init__()
+
+        # attributes
+        self.grid_res = grid_res
+        self.grid_scale = grid_scale
+        self.deformation_multiplier = 4.0
+
+        # modules
+        self.encoder = DinoWrapper(
+            model_name=encoder_model_name,
+            freeze=encoder_freeze,
+        )
+
+        self.transformer = TriplaneTransformer(
+            inner_dim=transformer_dim,
+            num_layers=transformer_layers,
+            num_heads=transformer_heads,
+            image_feat_dim=encoder_feat_dim,
+            triplane_low_res=triplane_low_res,
+            triplane_high_res=triplane_high_res,
+            triplane_dim=triplane_dim,
+        )
+
+        self.synthesizer = TriplaneSynthesizer(
+            triplane_dim=triplane_dim,
+            samples_per_ray=rendering_samples_per_ray,
+        )
+
+    def init_flexicubes_geometry(self, device, fovy=50.0, use_renderer=True):
+        camera = PerspectiveCamera(fovy=fovy, device=device)
+        if use_renderer:
+            renderer = NeuralRender(device, camera_model=camera)
+        else:
+            renderer = None
+        self.geometry = FlexiCubesGeometry(
+            grid_res=self.grid_res,
+            scale=self.grid_scale,
+            renderer=renderer,
+            render_type='neural_render',
+            device=device,
+        )
+
+    def forward_planes(self, images, cameras):
+        # images: [B, V, C_img, H_img, W_img]
+        # cameras: [B, V, 16]
+        B = images.shape[0]
+
+        # encode images
+        image_feats = self.encoder(images, cameras)
+        image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B)
+
+        # decode triplanes
+        planes = self.transformer(image_feats)
+
+        return planes
+
+    def get_sdf_deformation_prediction(self, planes):
+        """
+        Predict SDF and deformation for tetrahedron vertices
+        :param planes: triplane feature map for the geometry.
+        """
+        init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1)
+
+        # Step 1: predict the SDF and deformation
+        sdf, deformation, weight = torch.utils.checkpoint.checkpoint(
+            self.synthesizer.get_geometry_prediction,
+            planes,
+            init_position,
+            self.geometry.indices,
+            use_reentrant=False,
+        )
+
+        # Step 2: Normalize the deformation to avoid the flipped triangles.
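+        # tanh bounds every per-vertex offset to a small fraction of a grid cell so that
+        # deformed cells cannot fold over each other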
+        deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation)
+        sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)
+
+        ####
+        # Step 3: Fix some sdf if we observe empty shape (full positive or full negative)
+        sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1))
+        sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
+        pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1)
+        neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1)
+        zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)
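+        # a shape whose interior SDF is entirely positive or entirely negative has no
+        # iso-surface; inject an artificial sign change below so a mesh is still produced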
+        if torch.sum(zero_surface).item() > 0:
+            update_sdf = torch.zeros_like(sdf[0:1])
+            max_sdf = sdf.max()
+            min_sdf = sdf.min()
+            update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf)  # greater than zero
+            update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf)  # smaller than zero
+            new_sdf = torch.zeros_like(sdf)
+            for i_batch in range(zero_surface.shape[0]):
+                if zero_surface[i_batch]:
+                    new_sdf[i_batch:i_batch + 1] += update_sdf
+            update_mask = (new_sdf == 0).float()
+            # Regularization here is used to push the sdf toward having both signs (so it is not fully positive or fully negative)
+            sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
+            sdf_reg_loss = sdf_reg_loss * zero_surface.float()
+            sdf = sdf * update_mask + new_sdf * (1 - update_mask)
+
+        # Step 4: Here we remove the gradient for the bad sdf (full positive or full negative)
+        final_sdf = []
+        final_def = []
+        for i_batch in range(zero_surface.shape[0]):
+            if zero_surface[i_batch]:
+                final_sdf.append(sdf[i_batch: i_batch + 1].detach())
+                final_def.append(deformation[i_batch: i_batch + 1].detach())
+            else:
+                final_sdf.append(sdf[i_batch: i_batch + 1])
+                final_def.append(deformation[i_batch: i_batch + 1])
+        sdf = torch.cat(final_sdf, dim=0)
+        deformation = torch.cat(final_def, dim=0)
+        return sdf, deformation, sdf_reg_loss, weight
+
+    def get_geometry_prediction(self, planes=None):
+        """
+        Generate a mesh from the given triplanes.
+        :param planes: triplane features.
+        """
+        # Step 1: predict the SDF and deformation value for each vertex of the grid.
+        sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes)
+        v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation
+        tets = self.geometry.indices
+        n_batch = planes.shape[0]
+        v_list = []
+        f_list = []
+        flexicubes_surface_reg_list = []
+
+        # Step 2: Extract the mesh from the SDF grid with the differentiable iso-surface extractor (FlexiCubes)
+        for i_batch in range(n_batch):
+            verts, faces, flexicubes_surface_reg = self.geometry.get_mesh(
+                v_deformed[i_batch],
+                sdf[i_batch].squeeze(dim=-1),
+                with_uv=False,
+                indices=tets,
+                weight_n=weight[i_batch].squeeze(dim=-1),
+                is_training=self.training,
+            )
+            flexicubes_surface_reg_list.append(flexicubes_surface_reg)
+            v_list.append(verts)
+            f_list.append(faces)
+
+        flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean()
+        flexicubes_weight_reg = (weight ** 2).mean()
+
+        return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
+
+    def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
+        """
+        Predict texture colors for the given triplanes.
+        :param planes: the triplane feature map
+        :param tex_pos: positions at which to query the texture field
+        :param hard_mask: 2D silhouette of the rendered image.
+        """
+        tex_pos = torch.cat(tex_pos, dim=0)
+        if hard_mask is not None:
+            tex_pos = tex_pos * hard_mask.float()
+        batch_size = tex_pos.shape[0]
+        tex_pos = tex_pos.reshape(batch_size, -1, 3)
+        ###################
+        # We use the mask to gather only the covered texture locations (to save memory)
+        if hard_mask is not None:
+            n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
+            sample_tex_pose_list = []
+            max_point = n_point_list.max()
+            expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
+            for i in range(tex_pos.shape[0]):
+                tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
+                if tex_pos_one_shape.shape[1] < max_point:
+                    tex_pos_one_shape = torch.cat(
+                        [tex_pos_one_shape, torch.zeros(
+                            1, max_point - tex_pos_one_shape.shape[1], 3,
+                            device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
+                sample_tex_pose_list.append(tex_pos_one_shape)
+            tex_pos = torch.cat(sample_tex_pose_list, dim=0)
+
+        tex_feat = torch.utils.checkpoint.checkpoint(
+            self.synthesizer.get_texture_prediction,
+            planes,
+            tex_pos,
+            use_reentrant=False,
+        )
+
+        if hard_mask is not None:
+            final_tex_feat = torch.zeros(
+                planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
+            expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
+            for i in range(planes.shape[0]):
+                final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
+            tex_feat = final_tex_feat
+
+        return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
+
+    def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
+        """
+        Render generated meshes with nvdiffrast.
+        :param mesh_v: list of vertex tensors, one per mesh
+        :param mesh_f: list of face tensors, one per mesh
+        :param cam_mv: 4x4 camera matrices, one set of views per mesh
+        :return: per-view mask, hard_mask, tex_pos, depth and normal
+        """
+        return_value_list = []
+        for i_mesh in range(len(mesh_v)):
+            return_value = self.geometry.render_mesh(
+                mesh_v[i_mesh],
+                mesh_f[i_mesh].int(),
+                cam_mv[i_mesh],
+                resolution=render_size,
+                hierarchical_mask=False
+            )
+            return_value_list.append(return_value)
+
+        return_keys = return_value_list[0].keys()
+        return_value = dict()
+        for k in return_keys:
+            value = [v[k] for v in return_value_list]
+            return_value[k] = value
+
+        mask = torch.cat(return_value['mask'], dim=0)
+        hard_mask = torch.cat(return_value['hard_mask'], dim=0)
+        tex_pos = return_value['tex_pos']
+        depth = torch.cat(return_value['depth'], dim=0)
+        normal = torch.cat(return_value['normal'], dim=0)
+        return mask, hard_mask, tex_pos, depth, normal
+
+    def forward_geometry(self, planes, render_cameras, render_size=256):
+        """
+        Main function of our generator. It first generates a 3D mesh, then renders it into 2D images
+        with the given `render_cameras`.
+        :param planes: triplane features
+        :param render_cameras: cameras used to render the generated 3D shape.
+        """
+        B, NV = render_cameras.shape[:2]
+
+        # Generate 3D mesh first
+        mesh_v, mesh_f, sdf, _deformation, _v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
+
+        # Render the mesh into 2D image (get 3d position of each image plane)
+        cam_mv = render_cameras
+        run_n_view = cam_mv.shape[1]
+        antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size)
+
+        tex_hard_mask = hard_mask
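+        # concatenate the per-view position maps and masks side by side (along the width
+        # axis) so one texture query covers all rendered views of a sample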
+        tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos]
+        tex_hard_mask = torch.cat(
+            [torch.cat(
+                [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1]
+                 for i_view in range(run_n_view)], dim=2)
+                for i in range(planes.shape[0])], dim=0)
+
+        # Querying the texture field to predict the texture feature for each pixel on the image
+        tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask)
+        background_feature = torch.ones_like(tex_feat)      # white background
+
+        # Merge them together
+        img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask)
+
+        # We should split it back to the original image shape
+        img_feat = torch.cat(
+            [torch.cat(
+                [img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)]
+                 for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0)
+
+        img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV))
+        antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV))
+        depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV))        # transform negative depth to positive
+        normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV))
+
+        out = {
+            'img': img,
+            'mask': antilias_mask,
+            'depth': depth,
+            'normal': normal,
+            'sdf': sdf,
+            'mesh_v': mesh_v,
+            'mesh_f': mesh_f,
+            'sdf_reg_loss': sdf_reg_loss,
+        }
+        return out
+
+    def forward(self, images, cameras, render_cameras, render_size: int):
+        # images: [B, V, C_img, H_img, W_img]
+        # cameras: [B, V, 16]
+        # render_cameras: [B, M, D_cam_render]
+        # render_size: int
+        _B, _M = render_cameras.shape[:2]
+
+        planes = self.forward_planes(images, cameras)
+        out = self.forward_geometry(planes, render_cameras, render_size=render_size)
+
+        return {
+            'planes': planes,
+            **out
+        }
+
+    def extract_mesh(
+        self,
+        planes: torch.Tensor,
+        use_texture_map: bool = False,
+        texture_resolution: int = 1024,
+        **kwargs,
+    ):
+        """
+        Extract a 3D mesh from FlexiCubes. Only supports batch_size 1.
+        :param planes: triplane features
+        :param use_texture_map: use texture map or vertex color
+        :param texture_resolution: the resolution of the texture map.
+        """
+        assert planes.shape[0] == 1
+        device = planes.device
+
+        # predict geometry first
+        mesh_v, mesh_f, _sdf, _deformation, _v_deformed, _sdf_reg_loss = self.get_geometry_prediction(planes)
+        vertices, faces = mesh_v[0], mesh_f[0]
+
+        if not use_texture_map:
+            # query vertex colors
+            vertices_tensor = vertices.unsqueeze(0)
+            vertices_colors = self.synthesizer.get_texture_prediction(
+                planes, vertices_tensor).clamp(0, 1).squeeze(0).cpu().numpy()
+            vertices_colors = (vertices_colors * 255).astype(np.uint8)
+
+            return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors
+
+        # use x-atlas to get uv mapping for the mesh
+        dr.RasterizeCudaContext(device=device)
+        uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
+            self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution)
+        tex_hard_mask = tex_hard_mask.float()
+
+        # query the texture field to get the RGB color for texture map
+        tex_feat = self.get_texture_prediction(
+            planes, [gb_pos], tex_hard_mask)
+        background_feature = torch.zeros_like(tex_feat)
+        img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask)
+        texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
+
+        return vertices, faces, uvs, mesh_tex_idx, texture_map
diff --git a/src/models/renderer/__init__.py b/src/models/renderer/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34
--- /dev/null
+++ b/src/models/renderer/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
diff --git a/src/models/renderer/synthesizer.py b/src/models/renderer/synthesizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..99fee917b881be868d954946ba881c78893d5faf
--- /dev/null
+++ b/src/models/renderer/synthesizer.py
@@ -0,0 +1,205 @@
+# ORIGINAL LICENSE
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# Modified by Jiale Xu
+# The modifications are subject to the same license as the original.
+from __future__ import annotations
+
+import itertools
+
+import torch
+import torch.nn as nn
+
+from .utils.ray_sampler import RaySampler
+from .utils.renderer import ImportanceRenderer
+
+
+class OSGDecoder(nn.Module):
+    """
+    Triplane decoder that gives RGB and sigma values from sampled features.
+    Using ReLU here instead of Softplus in the original implementation.
+
+    Reference:
+    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
+    """
+
+    def __init__(self, n_features: int,
+                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(*[[
+                nn.Linear(hidden_dim, hidden_dim),
+                activation(),
+            ] for _ in range(num_layers - 2)]),
+            nn.Linear(hidden_dim, 1 + 3),
+        )
+        # init all bias to zero
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.zeros_(m.bias)
+
+    def forward(self, sampled_features, ray_directions):
+        # Aggregate features by mean
+        # sampled_features = sampled_features.mean(1)
+        # Aggregate features by concatenation
+        _N, n_planes, _M, _C = sampled_features.shape
+        sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
+        x = sampled_features
+
+        N, M, C = x.shape
+        x = x.contiguous().view(N*M, C)
+
+        x = self.net(x)
+        x = x.view(N, M, -1)
+        rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
+        sigma = x[..., 0:1]
+
+        return {'rgb': rgb, 'sigma': sigma}
+
+
+class TriplaneSynthesizer(nn.Module):
+    """
+    Synthesizer that renders a triplane volume with planes and a camera.
+
+    Reference:
+    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
+    """
+
+    DEFAULT_RENDERING_KWARGS = {
+        'ray_start': 'auto',
+        'ray_end': 'auto',
+        'box_warp': 2.,
+        'white_back': True,
+        'disparity_space_sampling': False,
+        'clamp_mode': 'softplus',
+        'sampler_bbox_min': -1.,
+        'sampler_bbox_max': 1.,
+    }
+
+    def __init__(self, triplane_dim: int, samples_per_ray: int):
+        super().__init__()
+
+        # attributes
+        self.triplane_dim = triplane_dim
+        self.rendering_kwargs = {
+            **self.DEFAULT_RENDERING_KWARGS,
+            'depth_resolution': samples_per_ray // 2,
+            'depth_resolution_importance': samples_per_ray // 2,
+        }
+
+        # renderings
+        self.renderer = ImportanceRenderer()
+        self.ray_sampler = RaySampler()
+
+        # modules
+        self.decoder = OSGDecoder(n_features=triplane_dim)
+
+    def forward(self, planes, cameras, render_size=128, crop_params=None):
+        # planes: (N, 3, D', H', W')
+        # cameras: (N, M, D_cam)
+        # render_size: int
+        assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras"
+        N, M = cameras.shape[:2]
+
+        cam2world_matrix = cameras[..., :16].view(N, M, 4, 4)
+        intrinsics = cameras[..., 16:25].view(N, M, 3, 3)
+
+        # Create a batch of rays for volume rendering
+        ray_origins, ray_directions = self.ray_sampler(
+            cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4),
+            intrinsics=intrinsics.reshape(-1, 3, 3),
+            render_size=render_size,
+        )
+        assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins"
+        assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"
+
+        # Crop rays if crop_params is available
+        if crop_params is not None:
+            ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3)
+            ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3)
+            i, j, h, w = crop_params
+            ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3)
+            ray_directions = ray_directions[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3)
+
+        # Perform volume rendering
+        rgb_samples, depth_samples, weights_samples = self.renderer(
+            planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs,
+        )
+
+        # Reshape into 'raw' neural-rendered image
+        if crop_params is not None:
+            Himg, Wimg = crop_params[2:]
+        else:
+            Himg = Wimg = render_size
+        rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous()
+        depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
+        weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg)
+
+        out = {
+            'images_rgb': rgb_images,
+            'images_depth': depth_images,
+            'images_weight': weight_images,
+        }
+        return out
+
+    def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
+        # planes: (N, 3, D', H', W')
+        # grid_size: int
+        # aabb: (N, 2, 3)
+        if aabb is None:
+            aabb = torch.tensor([
+                [self.rendering_kwargs['sampler_bbox_min']] * 3,
+                [self.rendering_kwargs['sampler_bbox_max']] * 3,
+            ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
+        assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
+        N = planes.shape[0]
+
+        # create grid points for triplane query
+        grid_points = []
+        for i in range(N):
+            grid_points.append(torch.stack(torch.meshgrid(
+                torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
+                torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
+                torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
+                indexing='ij',
+            ), dim=-1).reshape(-1, 3))
+        cube_grid = torch.stack(grid_points, dim=0).to(planes.device)
+
+        features = self.forward_points(planes, cube_grid)
+
+        # reshape into grid
+        features = {
+            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
+            for k, v in features.items()
+        }
+        return features
+
+    def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
+        # planes: (N, 3, D', H', W')
+        # points: (N, P, 3)
+        _N, _P = points.shape[:2]
+
+        # query triplane in chunks
+        outs = []
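+        # query the triplanes in chunks along the point dimension to bound peak memory;
+        # the chunk outputs are concatenated back together below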
+        for i in range(0, points.shape[1], chunk_size):
+            chunk_points = points[:, i:i+chunk_size]
+
+            # query triplane
+            chunk_out = self.renderer.run_model_activated(
+                planes=planes,
+                decoder=self.decoder,
+                sample_coordinates=chunk_points,
+                sample_directions=torch.zeros_like(chunk_points),
+                options=self.rendering_kwargs,
+            )
+            outs.append(chunk_out)
+
+        # concatenate the outputs
+        point_features = {
+            k: torch.cat([out[k] for out in outs], dim=1)
+            for k in outs[0].keys()
+        }
+        return point_features
diff --git a/src/models/renderer/synthesizer_mesh.py b/src/models/renderer/synthesizer_mesh.py
new file mode 100755
index 0000000000000000000000000000000000000000..a98430572450e78d3356e661d32778dfcce1cd63
--- /dev/null
+++ b/src/models/renderer/synthesizer_mesh.py
@@ -0,0 +1,144 @@
+# ORIGINAL LICENSE
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# Modified by Jiale Xu
+# The modifications are subject to the same license as the original.
+from __future__ import annotations
+
+import itertools
+
+import torch
+import torch.nn as nn
+
+from .utils.renderer import generate_planes, sample_from_planes
+
+
+class OSGDecoder(nn.Module):
+    """
+    Triplane decoder that gives RGB and sigma values from sampled features.
+    Using ReLU here instead of Softplus in the original implementation.
+
+    Reference:
+    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
+    """
+
+    def __init__(self, n_features: int,
+                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
+        super().__init__()
+
+        self.net_sdf = nn.Sequential(
+            nn.Linear(3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(*[[
+                nn.Linear(hidden_dim, hidden_dim),
+                activation(),
+            ] for _ in range(num_layers - 2)]),
+            nn.Linear(hidden_dim, 1),
+        )
+        self.net_rgb = nn.Sequential(
+            nn.Linear(3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(*[[
+                nn.Linear(hidden_dim, hidden_dim),
+                activation(),
+            ] for _ in range(num_layers - 2)]),
+            nn.Linear(hidden_dim, 3),
+        )
+        self.net_deformation = nn.Sequential(
+            nn.Linear(3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(*[[
+                nn.Linear(hidden_dim, hidden_dim),
+                activation(),
+            ] for _ in range(num_layers - 2)]),
+            nn.Linear(hidden_dim, 3),
+        )
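+        # weight head: features of the 8 corner vertices of each cell are concatenated
+        # (hence 8 * 3 * n_features inputs) and mapped to the 21 per-cell interpolation
+        # weights used by FlexiCubes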
+        self.net_weight = nn.Sequential(
+            nn.Linear(8 * 3 * n_features, hidden_dim),
+            activation(),
+            *itertools.chain(*[[
+                nn.Linear(hidden_dim, hidden_dim),
+                activation(),
+            ] for _ in range(num_layers - 2)]),
+            nn.Linear(hidden_dim, 21),
+        )
+
+        # init all bias to zero
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.zeros_(m.bias)
+
+    def get_geometry_prediction(self, sampled_features, flexicubes_indices):
+        _N, n_planes, _M, _C = sampled_features.shape
+        sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
+
+        sdf = self.net_sdf(sampled_features)
+        deformation = self.net_deformation(sampled_features)
+
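+        # gather the per-vertex features of every cell corner listed in flexicubes_indices
+        # and flatten them per cell for the weight head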
+        grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1)
+        grid_features = grid_features.reshape(
+            sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1])
+        weight = self.net_weight(grid_features) * 0.1
+
+        return sdf, deformation, weight
+
+    def get_texture_prediction(self, sampled_features):
+        _N, n_planes, _M, _C = sampled_features.shape
+        sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
+
+        rgb = self.net_rgb(sampled_features)
+        rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
+
+        return rgb
+
+
+class TriplaneSynthesizer(nn.Module):
+    """
+    Synthesizer that renders a triplane volume with planes and a camera.
+
+    Reference:
+    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
+    """
+
+    DEFAULT_RENDERING_KWARGS = {
+        'ray_start': 'auto',
+        'ray_end': 'auto',
+        'box_warp': 2.,
+        'white_back': True,
+        'disparity_space_sampling': False,
+        'clamp_mode': 'softplus',
+        'sampler_bbox_min': -1.,
+        'sampler_bbox_max': 1.,
+    }
+
+    def __init__(self, triplane_dim: int, samples_per_ray: int):
+        super().__init__()
+
+        # attributes
+        self.triplane_dim = triplane_dim
+        self.rendering_kwargs = {
+            **self.DEFAULT_RENDERING_KWARGS,
+            'depth_resolution': samples_per_ray // 2,
+            'depth_resolution_importance': samples_per_ray // 2,
+        }
+
+        # modules
+        self.plane_axes = generate_planes()
+        self.decoder = OSGDecoder(n_features=triplane_dim)
+
+    def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices):
+        plane_axes = self.plane_axes.to(planes.device)
+        sampled_features = sample_from_planes(
+            plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp'])
+
+        sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices)
+        return sdf, deformation, weight
+
+    def get_texture_prediction(self, planes, sample_coordinates):
+        plane_axes = self.plane_axes.to(planes.device)
+        sampled_features = sample_from_planes(
+            plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp'])
+
+        rgb = self.decoder.get_texture_prediction(sampled_features)
+        return rgb
diff --git a/src/models/renderer/utils/__init__.py b/src/models/renderer/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34
--- /dev/null
+++ b/src/models/renderer/utils/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
diff --git a/src/models/renderer/utils/math_utils.py b/src/models/renderer/utils/math_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..7770a493cdf64ba10071e1a82a6b322f8b3e3b0a
--- /dev/null
+++ b/src/models/renderer/utils/math_utils.py
@@ -0,0 +1,114 @@
+# MIT License
+
+# Copyright (c) 2022 Petr Kellnhofer
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+from __future__ import annotations
+
+import torch
+
+
+def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
+    """Left-multiplies MxM @ NxM. Returns NxM."""
+    res = torch.matmul(vectors4, matrix.T)
+    return res
+
+
+def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
+    """Normalize vector lengths."""
+    return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
+
+def torch_dot(x: torch.Tensor, y: torch.Tensor):
+    """Dot product of two tensors."""
+    return (x * y).sum(-1)
+
+
+def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
+    """
+    Author: Petr Kellnhofer
+    Intersects rays with the [-1, 1] NDC volume.
+    Returns min and max distance of entry.
+    Returns -1 for no intersection.
+    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection.
+    """
+    o_shape = rays_o.shape
+    rays_o = rays_o.detach().reshape(-1, 3)
+    rays_d = rays_d.detach().reshape(-1, 3)
+
+
+    bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
+    bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
+    bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
+    is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
+
+    # Precompute inverse for stability.
+    invdir = 1 / rays_d
+    sign = (invdir < 0).long()
+
+    # Intersect with YZ plane.
+    tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+    tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+
+    # Intersect with XZ plane.
+    tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+    tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+
+    # Resolve parallel rays.
+    is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
+
+    # Use the shortest intersection.
+    tmin = torch.max(tmin, tymin)
+    tmax = torch.min(tmax, tymax)
+
+    # Intersect with XY plane.
+    tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+    tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+
+    # Resolve parallel rays.
+    is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
+
+    # Use the shortest intersection.
+    tmin = torch.max(tmin, tzmin)
+    tmax = torch.min(tmax, tzmax)
+
+    # Mark invalid.
+    tmin[torch.logical_not(is_valid)] = -1
+    tmax[torch.logical_not(is_valid)] = -2
+
+    return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
+
+
+def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
+    """
+    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
+    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
+    """
+    # create a tensor of 'num' steps from 0 to 1
+    steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
+
+    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
+    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
+    #   "cannot statically infer the expected size of a list in this context", hence the code below
+    for i in range(start.ndim):
+        steps = steps.unsqueeze(-1)
+
+    # the output starts at 'start' and increments until 'stop' in each dimension
+    out = start[None] + steps * (stop - start)[None]
+
+    return out
diff --git a/src/models/renderer/utils/ray_marcher.py b/src/models/renderer/utils/ray_marcher.py
new file mode 100755
index 0000000000000000000000000000000000000000..853ccc7495ed592147293c9aee3616ff5d27328d
--- /dev/null
+++ b/src/models/renderer/utils/ray_marcher.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Jiale Xu
+# The modifications are subject to the same license as the original.
+
+
+"""
+The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
+Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!).
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+
+class MipRayMarcher2(nn.Module):
+    def __init__(self, activation_factory):
+        super().__init__()
+        self.activation_factory = activation_factory
+
+    def run_forward(self, colors, densities, depths, rendering_options, normals=None):
+        dtype = colors.dtype
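+        # MipNeRF-style quadrature: colors, densities and depths are averaged to the
+        # midpoints of adjacent samples before alpha compositing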
+        deltas = depths[:, :, 1:] - depths[:, :, :-1]
+        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
+        densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
+        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
+
+        # using factory mode for better usability
+        densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype)
+
+        density_delta = densities_mid * deltas
+
+        alpha = 1 - torch.exp(-density_delta).to(dtype)
+
+        alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
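+        # the cumulative product of (1 - alpha) is the transmittance up to each sample;
+        # alpha * transmittance gives the usual volume-rendering weights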
+        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
+        weights = weights.to(dtype)
+
+        composite_rgb = torch.sum(weights * colors_mid, -2)
+        weight_total = weights.sum(2)
+        # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
+        composite_depth = torch.sum(weights * depths_mid, -2)
+
+        # clip the composite to min/max range of depths
+        composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype)
+        composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
+
+        if rendering_options.get('white_back', False):
+            composite_rgb = composite_rgb + 1 - weight_total
+
+        # rendered value scale is 0-1, comment out original mipnerf scaling
+        # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
+
+        return composite_rgb, composite_depth, weights
+
+
+    def forward(self, colors, densities, depths, rendering_options, normals=None):
+        if normals is not None:
+            composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals)
+            return composite_rgb, composite_depth, composite_normals, weights
+
+        composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
+        return composite_rgb, composite_depth, weights
diff --git a/src/models/renderer/utils/ray_sampler.py b/src/models/renderer/utils/ray_sampler.py
new file mode 100755
index 0000000000000000000000000000000000000000..014e099ef94001dfca5a9253742af6bb6ad45507
--- /dev/null
+++ b/src/models/renderer/utils/ray_sampler.py
@@ -0,0 +1,143 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Jiale Xu
+# The modifications are subject to the same license as the original.
+
+
+"""
+The ray sampler is a module that takes in camera matrices and a resolution and returns batches of rays.
+Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
+"""
+from __future__ import annotations
+
+import torch
+
+
+class RaySampler(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
+
+
+    def forward(self, cam2world_matrix, intrinsics, render_size):
+        """
+        Create batches of rays and return origins and directions.
+
+        cam2world_matrix: (N, 4, 4)
+        intrinsics: (N, 3, 3)
+        render_size: int
+
+        ray_origins: (N, M, 3)
+        ray_dirs: (N, M, 3)
+        """
+
+        dtype = cam2world_matrix.dtype
+        device = cam2world_matrix.device
+        N, M = cam2world_matrix.shape[0], render_size**2
+        cam_locs_world = cam2world_matrix[:, :3, 3]
+        fx = intrinsics[:, 0, 0]
+        fy = intrinsics[:, 1, 1]
+        cx = intrinsics[:, 0, 2]
+        cy = intrinsics[:, 1, 2]
+        sk = intrinsics[:, 0, 1]
+
+        uv = torch.stack(torch.meshgrid(
+            torch.arange(render_size, dtype=dtype, device=device),
+            torch.arange(render_size, dtype=dtype, device=device),
+            indexing='ij',
+        ))
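+        # meshgrid with indexing='ij' yields (row, col); flipping gives (x, y) pixel
+        # coordinates, flattened to one entry per pixel and repeated for each camera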
+        uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
+        uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
+
+        x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
+        y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
+        z_cam = torch.ones((N, M), dtype=dtype, device=device)
+
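+        # x_cam/y_cam are pixel centers normalized to [0, 1] (note the half-pixel offset);
+        # the lift below unprojects them through the intrinsics at unit depth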
+        x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
+        y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
+
+        cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype)
+
+        _opencv2blender = torch.tensor([
+            [1, 0, 0, 0],
+            [0, -1, 0, 0],
+            [0, 0, -1, 0],
+            [0, 0, 0, 1],
+        ], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1)
+
+        cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
+
+        world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
+
+        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
+        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype)
+
+        ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
+
+        return ray_origins, ray_dirs
+
+
+class OrthoRaySampler(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
+
+
+    def forward(self, cam2world_matrix, ortho_scale, render_size):
+        """
+        Create batches of rays and return origins and directions.
+
+        cam2world_matrix: (N, 4, 4)
+        ortho_scale: float
+        render_size: int
+
+        ray_origins: (N, M, 3)
+        ray_dirs: (N, M, 3)
+        """
+
+        N, M = cam2world_matrix.shape[0], render_size**2
+
+        uv = torch.stack(torch.meshgrid(
+            torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
+            torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
+            indexing='ij',
+        ))
+        uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
+        uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
+
+        x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
+        y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
+        z_cam = torch.zeros((N, M), device=cam2world_matrix.device)
+
+        x_lift = (x_cam - 0.5) * ortho_scale
+        y_lift = (y_cam - 0.5) * ortho_scale
+
+        cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
+
+        _opencv2blender = torch.tensor([
+            [1, 0, 0, 0],
+            [0, -1, 0, 0],
+            [0, 0, -1, 0],
+            [0, 0, 0, 1],
+        ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
+
+        cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
+
+        ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
+
+        ray_dirs_cam = torch.stack([
+            torch.zeros((N, M), device=cam2world_matrix.device),
+            torch.zeros((N, M), device=cam2world_matrix.device),
+            torch.ones((N, M), device=cam2world_matrix.device),
+        ], dim=-1)
+        ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1)
+
+        return ray_origins, ray_dirs
diff --git a/src/models/renderer/utils/renderer.py b/src/models/renderer/utils/renderer.py
new file mode 100755
index 0000000000000000000000000000000000000000..bc7074888875fa1fb074bf376de056bd194758ea
--- /dev/null
+++ b/src/models/renderer/utils/renderer.py
@@ -0,0 +1,321 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+#
+# Modified by Jiale Xu
+# The modifications are subject to the same license as the original.
+
+
+"""
+The renderer is a module that takes in rays, decides where to sample along each
+ray, and computes pixel colors using the volume rendering equation.
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import math_utils
+from .ray_marcher import MipRayMarcher2
+
+
+def generate_planes():
+    """
+    Defines planes by the three vectors that form the "axes" of the
+    plane. Should work with arbitrary number of planes and planes of
+    arbitrary orientation.
+
+    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
+    """
+    return torch.tensor([[[1, 0, 0],
+                            [0, 1, 0],
+                            [0, 0, 1]],
+                            [[1, 0, 0],
+                            [0, 0, 1],
+                            [0, 1, 0]],
+                            [[0, 0, 1],
+                            [0, 1, 0],
+                            [1, 0, 0]]], dtype=torch.float32)
+
+def project_onto_planes(planes, coordinates):
+    """
+    Does a projection of a 3D point onto a batch of 2D planes,
+    returning 2D plane coordinates.
+
+    Takes plane axes of shape (n_planes, 3, 3)
+    Takes coordinates of shape (N, M, 3)
+    Returns projections of shape (N*n_planes, M, 2)
+    """
+    N, M, _C = coordinates.shape
+    n_planes, _, _ = planes.shape
+    coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
+    inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
+    projections = torch.bmm(coordinates, inv_planes)
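+    # expressing each point in a plane's basis makes its first two components the
+    # in-plane (u, v) coordinates used later by grid_sample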
+    return projections[..., :2]
+
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
+    assert padding_mode == 'zeros'
+    N, n_planes, C, H, W = plane_features.shape
+    _, M, _ = coordinates.shape
+    plane_features = plane_features.view(N*n_planes, C, H, W)
+    dtype = plane_features.dtype
+
+    coordinates = (2/box_warp) * coordinates # rescale from the [-box_warp/2, box_warp/2] box to the [-1, 1] range expected by grid_sample
+
+    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
+    output_features = torch.nn.functional.grid_sample(
+        plane_features,
+        projected_coordinates.to(dtype),
+        mode=mode,
+        padding_mode=padding_mode,
+        align_corners=False,
+    ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+    return output_features
+
+def sample_from_3dgrid(grid, coordinates):
+    """
+    Expects coordinates in shape (batch_size, num_points_per_batch, 3)
+    Expects grid in shape (1, channels, H, W, D)
+    (Also works if grid has batch size)
+    Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels).
+    """
+    batch_size, _n_coords, n_dims = coordinates.shape
+    sampled_features = torch.nn.functional.grid_sample(
+        grid.expand(batch_size, -1, -1, -1, -1),
+        coordinates.reshape(batch_size, 1, 1, -1, n_dims),
+        mode='bilinear',
+        padding_mode='zeros',
+        align_corners=False,
+    )
+    N, C, H, W, D = sampled_features.shape
+    sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
+    return sampled_features
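+
+# Illustrative only: expected shapes for `sample_from_3dgrid` on a toy feature grid.
+def _demo_sample_from_3dgrid_shapes():
+    grid = torch.rand(1, 8, 4, 4, 4)              # (1, channels=8, H, W, D)
+    coords = torch.rand(2, 100, 3) * 2 - 1        # (batch_size=2, num_points=100, 3) in [-1, 1]
+    feats = sample_from_3dgrid(grid, coords)      # (batch_size, num_points, channels)
+    assert feats.shape == (2, 100, 8)
+    return feats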
+
+class ImportanceRenderer(torch.nn.Module):
+    """
+    Modified original version to filter out-of-box samples as TensoRF does.
+
+    Reference:
+    TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.activation_factory = self._build_activation_factory()
+        self.ray_marcher = MipRayMarcher2(self.activation_factory)
+        self.plane_axes = generate_planes()
+
+    def _build_activation_factory(self):
+        def activation_factory(options: dict):
+            if options['clamp_mode'] == 'softplus':
+                return lambda x: F.softplus(x - 1)  # activation bias of -1 makes things initialize better
+            else:
+                assert False, "Renderer only supports `clamp_mode`=`softplus`!"
+        return activation_factory
+
+    def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
+                        planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
+        """
+        Additional filtering is applied to filter out-of-box samples.
+        Modifications made by Zexin He.
+        """
+
+        # context related variables
+        batch_size, num_rays, samples_per_ray, _ = depths.shape
+        device = depths.device
+
+        # define sample points with depths
+        sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
+        sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
+
+        # filter out-of-box samples
+        mask_inbox = (
+            (rendering_options['sampler_bbox_min'] <= sample_coordinates)
+            & (sample_coordinates <= rendering_options['sampler_bbox_max'])
+        )
+        mask_inbox = mask_inbox.all(-1)
+
+        # forward model according to all samples
+        _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
+
+        # set out-of-box samples to zeros (rgb) & a large negative value (sigma);
+        # nan_to_num maps -inf to the most negative finite value of the dtype, and
+        # dividing by SAFE_GUARD keeps headroom so later arithmetic does not overflow
+        SAFE_GUARD = 3
+        DATA_TYPE = _out['sigma'].dtype
+        colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
+        densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
+        colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]
+
+        # reshape back
+        colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
+        densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])
+
+        return colors_pass, densities_pass
+
+    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
+        # self.plane_axes = self.plane_axes.to(ray_origins.device)
+
+        if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
+            ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
+            is_ray_valid = ray_end > ray_start
+            if torch.any(is_ray_valid).item():
+                ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
+                ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
+            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+        else:
+            # Create stratified depth samples
+            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+
+        # Coarse Pass
+        colors_coarse, densities_coarse = self._forward_pass(
+            depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
+            planes=planes, decoder=decoder, rendering_options=rendering_options)
+
+        # Fine Pass
+        N_importance = rendering_options['depth_resolution_importance']
+        if N_importance > 0:
+            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
+
+            depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
+
+            colors_fine, densities_fine = self._forward_pass(
+                depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
+                planes=planes, decoder=decoder, rendering_options=rendering_options)
+
+            all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
+                                                                depths_fine, colors_fine, densities_fine)
+
+            rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
+        else:
+            rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
+
+        return rgb_final, depth_final, weights.sum(2)
+
+    def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
+        plane_axes = self.plane_axes.to(planes.device)
+        sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
+
+        out = decoder(sampled_features, sample_directions)
+        if options.get('density_noise', 0) > 0:
+            out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
+        return out
+
+    def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
+        out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
+        out['sigma'] = self.activation_factory(options)(out['sigma'])
+        return out
+
+    def sort_samples(self, all_depths, all_colors, all_densities):
+        _, indices = torch.sort(all_depths, dim=-2)
+        all_depths = torch.gather(all_depths, -2, indices)
+        all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+        all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+        return all_depths, all_colors, all_densities
+
+    def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None):
+        all_depths = torch.cat([depths1, depths2], dim = -2)
+        all_colors = torch.cat([colors1, colors2], dim = -2)
+        all_densities = torch.cat([densities1, densities2], dim = -2)
+
+        if normals1 is not None and normals2 is not None:
+            all_normals = torch.cat([normals1, normals2], dim = -2)
+        else:
+            all_normals = None
+
+        _, indices = torch.sort(all_depths, dim=-2)
+        all_depths = torch.gather(all_depths, -2, indices)
+        all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+        all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+
+        if all_normals is not None:
+            all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1]))
+            return all_depths, all_colors, all_normals, all_densities
+
+        return all_depths, all_colors, all_densities
+
+    def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
+        """Return depths of approximately uniformly spaced samples along rays."""
+        N, M, _ = ray_origins.shape
+        if disparity_space_sampling:
+            depths_coarse = torch.linspace(0,
+                                    1,
+                                    depth_resolution,
+                                    device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+            depth_delta = 1/(depth_resolution - 1)
+            depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+            depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
+        else:
+            if isinstance(ray_start, torch.Tensor):
+                depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
+                depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
+                depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
+            else:
+                depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+                depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
+                depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+
+        return depths_coarse
+
+    def sample_importance(self, z_vals, weights, N_importance):
+        """Return depths of importance sampled points along rays. See NeRF importance sampling for more."""
+        with torch.no_grad():
+            batch_size, num_rays, samples_per_ray, _ = z_vals.shape
+
+            z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
+            weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
+
+            # smooth weights
+            weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1)
+            weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
+            weights = weights + 0.01
+
+            z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
+            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
+                                             N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
+        return importance_z_vals
+
+    def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
+        """
+        Sample @N_importance samples from @bins with distribution defined by @weights.
+        Inputs:
+            bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
+            weights: (N_rays, N_samples_)
+            N_importance: the number of samples to draw from the distribution
+            det: deterministic or not
+            eps: a small number to prevent division by zero
+        Outputs:
+            samples: the sampled samples.
+        """
+        N_rays, N_samples_ = weights.shape
+        weights = weights + eps # prevent division by zero (don't do inplace op!)
+        pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
+        cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples_), cumulative distribution function
+        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (N_rays, N_samples_+1), padded to cover 0~1 inclusive
+
+        if det:
+            u = torch.linspace(0, 1, N_importance, device=bins.device)
+            u = u.expand(N_rays, N_importance)
+        else:
+            u = torch.rand(N_rays, N_importance, device=bins.device)
+        u = u.contiguous()
+
+        inds = torch.searchsorted(cdf, u, right=True)
+        below = torch.clamp_min(inds-1, 0)
+        above = torch.clamp_max(inds, N_samples_)
+
+        inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
+        cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
+        bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
+
+        denom = cdf_g[...,1]-cdf_g[...,0]
+        denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
+                             # anyway, therefore any value for it is fine (set to 1 here)
+
+        samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
+        return samples
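+
+
+# Illustrative only: inverse-CDF sampling with `sample_pdf` on a toy ray batch.
+# With `det=True` the samples are deterministic quantiles of the weight histogram.
+def _demo_sample_pdf():
+    renderer = ImportanceRenderer()
+    bins = torch.linspace(0.0, 1.0, 9).unsqueeze(0)    # (N_rays=1, N_samples_ + 1 = 9)
+    weights = torch.ones(1, 8)                         # uniform weights over the 8 bins
+    samples = renderer.sample_pdf(bins, weights, N_importance=16, det=True)
+    assert samples.shape == (1, 16)
+    return samples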
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/utils/camera_util.py b/src/utils/camera_util.py
new file mode 100755
index 0000000000000000000000000000000000000000..5984a1b59dc78ef93de50cd2cc3296e6d9a8d558
--- /dev/null
+++ b/src/utils/camera_util.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def pad_camera_extrinsics_4x4(extrinsics):
+    if extrinsics.shape[-2] == 4:
+        return extrinsics
+    padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
+    if extrinsics.ndim == 3:
+        padding = padding.unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
+    extrinsics = torch.cat([extrinsics, padding], dim=-2)
+    return extrinsics
+
+
+def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor | None = None, up_world: torch.Tensor | None = None):
+    """
+    Create OpenGL camera extrinsics (camera-to-world) from camera locations and a look-at position.
+
+    camera_position: (M, 3) or (3,)
+    look_at: (3,)
+    up_world: (3,)
+    return: (M, 4, 4) or (4, 4)
+    """
+    # by default, looking at the origin and world up is z-axis
+    if look_at is None:
+        look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
+    if up_world is None:
+        up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
+    if camera_position.ndim == 2:
+        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
+        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
+
+    # OpenGL camera: z-backward, x-right, y-up
+    z_axis = camera_position - look_at
+    z_axis = F.normalize(z_axis, dim=-1).float()
+    x_axis = torch.linalg.cross(up_world, z_axis, dim=-1)
+    x_axis = F.normalize(x_axis, dim=-1).float()
+    y_axis = torch.linalg.cross(z_axis, x_axis, dim=-1)
+    y_axis = F.normalize(y_axis, dim=-1).float()
+
+    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
+    extrinsics = pad_camera_extrinsics_4x4(extrinsics)
+    return extrinsics
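+
+
+# Illustrative only: a quick sanity check. A camera on the -y axis looking at the
+# origin yields a 4x4 camera-to-world matrix whose last column holds the camera
+# position in homogeneous coordinates.
+def _demo_center_looking_at_camera_pose():
+    cam_pos = torch.tensor([0.0, -2.0, 0.0])
+    c2w = center_looking_at_camera_pose(cam_pos)
+    assert c2w.shape == (4, 4)
+    assert torch.allclose(c2w[:3, 3], cam_pos)
+    return c2w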
+
+
+def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
+    azimuths = np.deg2rad(azimuths)
+    elevations = np.deg2rad(elevations)
+
+    xs = radius * np.cos(elevations) * np.cos(azimuths)
+    ys = radius * np.cos(elevations) * np.sin(azimuths)
+    zs = radius * np.sin(elevations)
+
+    cam_locations = np.stack([xs, ys, zs], axis=-1)
+    cam_locations = torch.from_numpy(cam_locations).float()
+
+    c2ws = center_looking_at_camera_pose(cam_locations)
+    return c2ws
+
+
+def get_circular_camera_poses(M=120, radius=2.5, elevation=30.0):
+    """
+    M: number of circular views
+    radius: camera distance to the center
+    elevation: camera elevation in degrees
+    return: (M, 4, 4)
+    """
+    assert M > 0 and radius > 0
+
+    elevation = np.deg2rad(elevation)
+
+    camera_positions = []
+    for i in range(M):
+        azimuth = 2 * np.pi * i / M
+        x = radius * np.cos(elevation) * np.cos(azimuth)
+        y = radius * np.cos(elevation) * np.sin(azimuth)
+        z = radius * np.sin(elevation)
+        camera_positions.append([x, y, z])
+    camera_positions = np.array(camera_positions)
+    camera_positions = torch.from_numpy(camera_positions).float()
+    extrinsics = center_looking_at_camera_pose(camera_positions)
+    return extrinsics
+
+
+def FOV_to_intrinsics(fov, device='cpu'):
+    """
+    Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
+    Note the intrinsics are returned as normalized by image size, rather than in pixel units.
+    Assumes principal point is at image center.
+    """
+    focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5)
+    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+    return intrinsics
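+
+
+# Illustrative only: for fov=30 degrees the normalized focal length is
+# 0.5 / tan(15 deg) ~= 1.866, with the principal point at (0.5, 0.5).
+def _demo_fov_to_intrinsics():
+    K = FOV_to_intrinsics(30.0)
+    assert K.shape == (3, 3)
+    assert abs(K[0, 0].item() - 0.5 / np.tan(np.deg2rad(15.0))) < 1e-5
+    return K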
+
+
+def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
+    """Get the input camera parameters."""
+    azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
+    elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
+
+    c2ws = spherical_camera_pose(azimuths, elevations, radius)
+    c2ws = c2ws.float().flatten(-2)
+
+    Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2)
+
+    extrinsics = c2ws[:, :12]
+    intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1)
+    cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+
+    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
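+
+
+# Illustrative only: each row packs the first 12 entries of the flattened 4x4
+# camera-to-world matrix followed by the normalized intrinsics (fx, fy, cx, cy).
+def _demo_zero123plus_input_cameras():
+    cameras = get_zero123plus_input_cameras(batch_size=2)
+    assert cameras.shape == (2, 6, 16)    # 2 batches, 6 input views, 12 + 4 values per view
+    return cameras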
diff --git a/src/utils/infer_util.py b/src/utils/infer_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1bc4b75243ba21d74112ae6ba72e3f2ddfa6bd6
--- /dev/null
+++ b/src/utils/infer_util.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import imageio
+import numpy as np
+import PIL.Image
+import rembg
+import torch
+
+
+def remove_background(
+    image: PIL.Image.Image,
+    rembg_session: Any = None,
+    force: bool = False,
+    **rembg_kwargs,
+) -> PIL.Image.Image:
+    do_remove = True
+    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+        do_remove = False
+    do_remove = do_remove or force
+    if do_remove:
+        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+    return image
+
+
+def resize_foreground(
+    image: PIL.Image.Image,
+    ratio: float,
+) -> PIL.Image.Image:
+    image = np.array(image)
+    assert image.shape[-1] == 4
+    alpha = np.where(image[..., 3] > 0)
+    y1, y2, x1, x2 = (
+        alpha[0].min(),
+        alpha[0].max(),
+        alpha[1].min(),
+        alpha[1].max(),
+    )
+    # crop the foreground
+    fg = image[y1:y2, x1:x2]
+    # pad to square
+    size = max(fg.shape[0], fg.shape[1])
+    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
+    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
+    new_image = np.pad(
+        fg,
+        ((ph0, ph1), (pw0, pw1), (0, 0)),
+        mode="constant",
+        constant_values=((0, 0), (0, 0), (0, 0)),
+    )
+
+    # compute padding according to the ratio
+    new_size = int(new_image.shape[0] / ratio)
+    # pad to size, double side
+    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
+    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
+    new_image = np.pad(
+        new_image,
+        ((ph0, ph1), (pw0, pw1), (0, 0)),
+        mode="constant",
+        constant_values=((0, 0), (0, 0), (0, 0)),
+    )
+    new_image = PIL.Image.fromarray(new_image)
+    return new_image
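+
+
+# Illustrative only: crop a synthetic RGBA image around its alpha mask and re-pad
+# so the foreground occupies roughly `ratio` of the output side length.
+def _demo_resize_foreground():
+    rgba = np.zeros((64, 64, 4), dtype=np.uint8)
+    rgba[16:48, 16:48] = 255    # opaque white square in the middle
+    out = resize_foreground(PIL.Image.fromarray(rgba), ratio=0.85)
+    assert out.mode == "RGBA" and out.size[0] == out.size[1]
+    return out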
+
+
+def images_to_video(
+    images: torch.Tensor,
+    output_path: str,
+    fps: int = 30,
+) -> None:
+    # images: (N, C, H, W)
+    video_dir = os.path.dirname(output_path)
+    if video_dir:  # output_path may be a bare filename with no directory component
+        os.makedirs(video_dir, exist_ok=True)
+
+    frames = []
+    for i in range(len(images)):
+        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
+            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+        assert frame.min() >= 0 and frame.max() <= 255, \
+            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
+        frames.append(frame)
+    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
diff --git a/src/utils/mesh_util.py b/src/utils/mesh_util.py
new file mode 100755
index 0000000000000000000000000000000000000000..76314850e819c134d7b44880c244ef272ab40c3e
--- /dev/null
+++ b/src/utils/mesh_util.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+from __future__ import annotations
+
+import cv2
+import numpy as np
+import nvdiffrast.torch as dr
+import torch
+import trimesh
+import xatlas
+from PIL import Image
+
+
+def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath):
+
+    pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
+    facenp_fx3 = facenp_fx3[:, [2, 1, 0]]
+
+    mesh = trimesh.Trimesh(
+        vertices=pointnp_px3,
+        faces=facenp_fx3,
+        vertex_colors=colornp_px3,
+    )
+    mesh.export(fpath, 'obj')
+
+
+def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath):
+
+    pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]])
+
+    mesh = trimesh.Trimesh(
+        vertices=pointnp_px3,
+        faces=facenp_fx3,
+        vertex_colors=colornp_px3,
+    )
+    mesh.export(fpath, 'glb')
+
+
+def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname):
+    import os
+    fol, na = os.path.split(fname)
+    na, _ = os.path.splitext(na)
+
+    # write the material file
+    matname = f'{fol}/{na}.mtl'
+    with open(matname, 'w') as fid:
+        fid.write('newmtl material_0\n')
+        fid.write('Kd 1 1 1\n')
+        fid.write('Ka 0 0 0\n')
+        fid.write('Ks 0.4 0.4 0.4\n')
+        fid.write('Ns 10\n')
+        fid.write('illum 2\n')
+        fid.write(f'map_Kd {na}.png\n')
+
+    # write the OBJ file: vertices, texture coordinates, and faces
+    with open(fname, 'w') as fid:
+        fid.write(f'mtllib {na}.mtl\n')
+
+        for p in pointnp_px3:
+            fid.write(f'v {p[0]:f} {p[1]:f} {p[2]:f}\n')
+
+        for p in tcoords_px2:
+            fid.write(f'vt {p[0]:f} {p[1]:f}\n')
+
+        fid.write('usemtl material_0\n')
+        for i, f in enumerate(facenp_fx3):
+            f1 = f + 1
+            f2 = facetex_fx3[i] + 1
+            fid.write(f'f {f1[0]:d}/{f2[0]:d} {f1[1]:d}/{f2[1]:d} {f1[2]:d}/{f2[2]:d}\n')
+
+    # save texture map
+    lo, hi = 0, 1
+    img = np.asarray(texmap_hxwx3, dtype=np.float32)
+    img = (img - lo) * (255 / (hi - lo))
+    img = img.clip(0, 255)
+    mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True)
+    mask = (mask <= 3.0).astype(np.float32)
+    kernel = np.ones((3, 3), 'uint8')
+    dilate_img = cv2.dilate(img, kernel, iterations=1)
+    img = img * (1 - mask) + dilate_img * mask
+    img = img.clip(0, 255).astype(np.uint8)
+    Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png')
+
+
+def loadobj(meshfile):
+    v = []
+    f = []
+    meshfp = open(meshfile)
+    for line in meshfp.readlines():
+        data = line.strip().split(' ')
+        data = [da for da in data if len(da) > 0]
+        if len(data) != 4:
+            continue
+        if data[0] == 'v':
+            v.append([float(d) for d in data[1:]])
+        if data[0] == 'f':
+            data = [da.split('/')[0] for da in data]
+            f.append([int(d) for d in data[1:]])
+    meshfp.close()
+
+    # torch need int64
+    facenp_fx3 = np.array(f, dtype=np.int64) - 1
+    pointnp_px3 = np.array(v, dtype=np.float32)
+    return pointnp_px3, facenp_fx3
+
+
+def loadobjtex(meshfile):
+    v = []
+    vt = []
+    f = []
+    ft = []
+    meshfp = open(meshfile)
+    for line in meshfp.readlines():
+        data = line.strip().split(' ')
+        data = [da for da in data if len(da) > 0]
+        if len(data) not in (3, 4, 5):
+            continue
+        if data[0] == 'v':
+            assert len(data) == 4
+
+            v.append([float(d) for d in data[1:]])
+        if data[0] == 'vt':
+            if len(data) == 3 or len(data) == 4:
+                vt.append([float(d) for d in data[1:3]])
+        if data[0] == 'f':
+            data = [da.split('/') for da in data]
+            if len(data) == 4:
+                f.append([int(d[0]) for d in data[1:]])
+                ft.append([int(d[1]) for d in data[1:]])
+            elif len(data) == 5:
+                idx1 = [1, 2, 3]
+                data1 = [data[i] for i in idx1]
+                f.append([int(d[0]) for d in data1])
+                ft.append([int(d[1]) for d in data1])
+                idx2 = [1, 3, 4]
+                data2 = [data[i] for i in idx2]
+                f.append([int(d[0]) for d in data2])
+                ft.append([int(d[1]) for d in data2])
+    meshfp.close()
+
+    # torch need int64
+    facenp_fx3 = np.array(f, dtype=np.int64) - 1
+    ftnp_fx3 = np.array(ft, dtype=np.int64) - 1
+    pointnp_px3 = np.array(v, dtype=np.float32)
+    uvs = np.array(vt, dtype=np.float32)
+    return pointnp_px3, facenp_fx3, uvs, ftnp_fx3
+
+
+# ==============================================================================================
+def interpolate(attr, rast, attr_idx, rast_db=None):
+    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
+
+
+def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution):
+    _vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy())
+
+    # Convert to tensors
+    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
+
+    uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
+    mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
+    # map UVs from [0, 1] to clip space [-1, 1]
+    uv_clip = uvs[None, ...] * 2.0 - 1.0
+
+    # pad to four component coordinate
+    uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
+
+    # rasterize
+    rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution))
+
+    # Interpolate world space position
+    gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
+    mask = rast[..., 3:4] > 0
+    return uvs, mesh_tex_idx, gb_pos, mask
diff --git a/src/utils/train_util.py b/src/utils/train_util.py
new file mode 100755
index 0000000000000000000000000000000000000000..3762cf0ab025227c313ce8d3d85417397c9346aa
--- /dev/null
+++ b/src/utils/train_util.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+import importlib
+
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+    return total_params
+
+
+def instantiate_from_config(config):
+    if "target" not in config:
+        if config == '__is_first_stage__':
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
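+
+
+# Illustrative only: `instantiate_from_config` resolves the dotted `target` path and
+# forwards `params` as keyword arguments. The config below is a made-up example.
+def _demo_instantiate_from_config():
+    config = {"target": "collections.OrderedDict", "params": {}}
+    return instantiate_from_config(config)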
diff --git a/zero123plus/pipeline.py b/zero123plus/pipeline.py
new file mode 100755
index 0000000000000000000000000000000000000000..3ace2662cb1ee1e7cf43a85395cba4d449d18f82
--- /dev/null
+++ b/zero123plus/pipeline.py
@@ -0,0 +1,407 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+from typing import Any, Optional
+
+import diffusers
+import numpy
+import torch
+import torch.distributed
+import torch.nn as nn
+import torch.utils.checkpoint
+import transformers
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DiffusionPipeline,
+    EulerAncestralDiscreteScheduler,
+    ImagePipelineOutput,
+    UNet2DConditionModel,
+)
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils.import_utils import is_xformers_available
+from PIL import Image
+from torchvision import transforms
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+
+def to_rgb_image(maybe_rgba: Image.Image):
+    if maybe_rgba.mode == 'RGB':
+        return maybe_rgba
+    elif maybe_rgba.mode == 'RGBA':
+        rgba = maybe_rgba
+        img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)  # all-white background
+        img = Image.fromarray(img, 'RGB')
+        img.paste(rgba, mask=rgba.getchannel('A'))
+        return img
+    else:
+        raise ValueError("Unsupported image type.", maybe_rgba.mode)
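+
+
+# Illustrative only: an RGBA input is composited onto a white background, while an
+# RGB input is returned unchanged.
+def _demo_to_rgb_image():
+    rgba = Image.new('RGBA', (8, 8), (255, 0, 0, 128))
+    rgb = to_rgb_image(rgba)
+    assert rgb.mode == 'RGB'
+    return rgb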
+
+
+class ReferenceOnlyAttnProc(torch.nn.Module):
+    def __init__(
+        self,
+        chained_proc,
+        enabled=False,
+        name=None
+    ) -> None:
+        super().__init__()
+        self.enabled = enabled
+        self.chained_proc = chained_proc
+        self.name = name
+
+    def __call__(
+        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+        mode="w", ref_dict: dict = None, is_cfg_guidance = False
+    ) -> Any:
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        if self.enabled and is_cfg_guidance:
+            res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
+            hidden_states = hidden_states[1:]
+            encoder_hidden_states = encoder_hidden_states[1:]
+        if self.enabled:
+            if mode == 'w':
+                ref_dict[self.name] = encoder_hidden_states
+            elif mode == 'r':
+                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
+            elif mode == 'm':
+                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
+            else:
+                assert False, mode
+        res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
+        if self.enabled and is_cfg_guidance:
+            res = torch.cat([res0, res])
+        return res
+
+
+class RefOnlyNoisedUNet(torch.nn.Module):
+    def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
+        super().__init__()
+        self.unet = unet
+        self.train_sched = train_sched
+        self.val_sched = val_sched
+
+        unet_lora_attn_procs = dict()
+        for name, _ in unet.attn_processors.items():
+            if torch.__version__ >= '2.0':
+                default_attn_proc = AttnProcessor2_0()
+            elif is_xformers_available():
+                default_attn_proc = XFormersAttnProcessor()
+            else:
+                default_attn_proc = AttnProcessor()
+            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
+                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
+            )
+        unet.set_attn_processor(unet_lora_attn_procs)
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.unet, name)
+
+    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
+        if is_cfg_guidance:
+            encoder_hidden_states = encoder_hidden_states[1:]
+            class_labels = class_labels[1:]
+        self.unet(
+            noisy_cond_lat, timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            class_labels=class_labels,
+            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
+            **kwargs
+        )
+
+    def forward(
+        self, sample, timestep, encoder_hidden_states, class_labels=None,
+        *args, cross_attention_kwargs,
+        down_block_res_samples=None, mid_block_res_sample=None,
+        **kwargs
+    ):
+        cond_lat = cross_attention_kwargs['cond_lat']
+        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
+        noise = torch.randn_like(cond_lat)
+        if self.training:
+            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
+            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
+        else:
+            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
+            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
+        ref_dict = {}
+        self.forward_cond(
+            noisy_cond_lat, timestep,
+            encoder_hidden_states, class_labels,
+            ref_dict, is_cfg_guidance, **kwargs
+        )
+        weight_dtype = self.unet.dtype
+        return self.unet(
+            sample, timestep,
+            encoder_hidden_states, *args,
+            class_labels=class_labels,
+            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
+            down_block_additional_residuals=[
+                sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+            ] if down_block_res_samples is not None else None,
+            mid_block_additional_residual=(
+                mid_block_res_sample.to(dtype=weight_dtype)
+                if mid_block_res_sample is not None else None
+            ),
+            **kwargs
+        )
+
+
+def scale_latents(latents):
+    latents = (latents - 0.22) * 0.75
+    return latents
+
+
+def unscale_latents(latents):
+    latents = latents / 0.75 + 0.22
+    return latents
+
+
+def scale_image(image):
+    image = image * 0.5 / 0.8
+    return image
+
+
+def unscale_image(image):
+    image = image / 0.5 * 0.8
+    return image
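+
+
+# Illustrative only: the latent/image scalings are affine maps that invert each
+# other, so a round trip recovers the input up to floating-point error.
+def _demo_scaling_roundtrip():
+    x = torch.randn(1, 4, 8, 8)
+    assert torch.allclose(unscale_latents(scale_latents(x)), x, atol=1e-5)
+    assert torch.allclose(unscale_image(scale_image(x)), x, atol=1e-5)
+    return x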
+
+
+class DepthControlUNet(torch.nn.Module):
+    def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:
+        super().__init__()
+        self.unet = unet
+        if controlnet is None:
+            self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
+        else:
+            self.controlnet = controlnet
+        DefaultAttnProc = AttnProcessor2_0
+        if is_xformers_available():
+            DefaultAttnProc = XFormersAttnProcessor
+        self.controlnet.set_attn_processor(DefaultAttnProc())
+        self.conditioning_scale = conditioning_scale
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.unet, name)
+
+    def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):
+        cross_attention_kwargs = dict(cross_attention_kwargs)
+        control_depth = cross_attention_kwargs.pop('control_depth')
+        down_block_res_samples, mid_block_res_sample = self.controlnet(
+            sample,
+            timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            controlnet_cond=control_depth,
+            conditioning_scale=self.conditioning_scale,
+            return_dict=False,
+        )
+        return self.unet(
+            sample,
+            timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            down_block_res_samples=down_block_res_samples,
+            mid_block_res_sample=mid_block_res_sample,
+            cross_attention_kwargs=cross_attention_kwargs
+        )
+
+
+class ModuleListDict(torch.nn.Module):
+    def __init__(self, procs: dict) -> None:
+        super().__init__()
+        self.keys = sorted(procs.keys())
+        self.values = torch.nn.ModuleList(procs[k] for k in self.keys)
+
+    def __getitem__(self, key):
+        return self.values[self.keys.index(key)]
+
+
+class SuperNet(torch.nn.Module):
+    def __init__(self, state_dict: dict[str, torch.Tensor]):
+        super().__init__()
+        state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
+        self.layers = torch.nn.ModuleList(state_dict.values())
+        self.mapping = dict(enumerate(state_dict.keys()))
+        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
+
+        # .processor for unet, .self_attn for text encoder
+        self.split_keys = [".processor", ".self_attn"]
+
+        # we add a hook to state_dict() and load_state_dict() so that the
+        # naming fits with `unet.attn_processors`
+        def map_to(module, state_dict, *args, **kwargs):
+            new_state_dict = {}
+            for key, value in state_dict.items():
+                num = int(key.split(".")[1])  # 0 is always "layers"
+                new_key = key.replace(f"layers.{num}", module.mapping[num])
+                new_state_dict[new_key] = value
+
+            return new_state_dict
+
+        def remap_key(key, state_dict):
+            for k in self.split_keys:
+                if k in key:
+                    return key.split(k)[0] + k
+            return key.split('.')[0]
+
+        def map_from(module, state_dict, *args, **kwargs):
+            all_keys = list(state_dict.keys())
+            for key in all_keys:
+                replace_key = remap_key(key, state_dict)
+                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
+                state_dict[new_key] = state_dict[key]
+                del state_dict[key]
+
+        self._register_state_dict_hook(map_to)
+        self._register_load_state_dict_pre_hook(map_from, with_module=True)
+
+
+class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
+    tokenizer: transformers.CLIPTokenizer
+    text_encoder: transformers.CLIPTextModel
+    vision_encoder: transformers.CLIPVisionModelWithProjection
+
+    feature_extractor_clip: transformers.CLIPImageProcessor
+    unet: UNet2DConditionModel
+    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers
+
+    vae: AutoencoderKL
+    ramping: nn.Linear
+
+    feature_extractor_vae: transformers.CLIPImageProcessor
+
+    depth_transforms_multi = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5])
+    ])
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        vision_encoder: transformers.CLIPVisionModelWithProjection,
+        feature_extractor_clip: CLIPImageProcessor,
+        feature_extractor_vae: CLIPImageProcessor,
+        ramping_coefficients: Optional[list] = None,
+        safety_checker=None,
+    ):
+        DiffusionPipeline.__init__(self)
+
+        self.register_modules(
+            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
+            unet=unet, scheduler=scheduler, safety_checker=None,
+            vision_encoder=vision_encoder,
+            feature_extractor_clip=feature_extractor_clip,
+            feature_extractor_vae=feature_extractor_vae
+        )
+        self.register_to_config(ramping_coefficients=ramping_coefficients)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+    def prepare(self):
+        train_sched = DDPMScheduler.from_config(self.scheduler.config)
+        if isinstance(self.unet, UNet2DConditionModel):
+            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()
+
+    def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
+        self.prepare()
+        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
+        return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))
+
+    def encode_condition_image(self, image: torch.Tensor):
+        image = self.vae.encode(image).latent_dist.sample()
+        return image
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        image: Image.Image = None,
+        prompt = "",
+        *args,
+        num_images_per_prompt: Optional[int] = 1,
+        guidance_scale=4.0,
+        depth_image: Image.Image = None,
+        output_type: Optional[str] = "pil",
+        width=640,
+        height=960,
+        num_inference_steps=28,
+        return_dict=True,
+        **kwargs
+    ):
+        self.prepare()
+        if image is None:
+            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
+        assert not isinstance(image, torch.Tensor)
+        image = to_rgb_image(image)
+        image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
+        image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
+        if depth_image is not None and hasattr(self.unet, "controlnet"):
+            depth_image = to_rgb_image(depth_image)
+            depth_image = self.depth_transforms_multi(depth_image).to(
+                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
+            )
+        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
+        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
+        cond_lat = self.encode_condition_image(image)
+        if guidance_scale > 1:
+            negative_lat = self.encode_condition_image(torch.zeros_like(image))
+            cond_lat = torch.cat([negative_lat, cond_lat])
+        encoded = self.vision_encoder(image_2, output_hidden_states=False)
+        global_embeds = encoded.image_embeds
+        global_embeds = global_embeds.unsqueeze(-2)
+
+        if hasattr(self, "encode_prompt"):
+            encoder_hidden_states = self.encode_prompt(
+                prompt,
+                self.device,
+                num_images_per_prompt,
+                False
+            )[0]
+        else:
+            encoder_hidden_states = self._encode_prompt(
+                prompt,
+                self.device,
+                num_images_per_prompt,
+                False
+            )
+        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
+        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
+        cak = dict(cond_lat=cond_lat)
+        if hasattr(self.unet, "controlnet"):
+            cak['control_depth'] = depth_image
+        latents: torch.Tensor = super().__call__(
+            None,
+            *args,
+            cross_attention_kwargs=cak,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=num_images_per_prompt,
+            prompt_embeds=encoder_hidden_states,
+            num_inference_steps=num_inference_steps,
+            output_type='latent',
+            width=width,
+            height=height,
+            **kwargs
+        ).images
+        latents = unscale_latents(latents)
+        if output_type != "latent":
+            image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
+        else:
+            image = latents
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)