LuLing commited on
Commit
1ff1642
·
verified ·
1 Parent(s): a54d77c

initial zerogpu demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +8 -0
  2. CITATION.cff +32 -0
  3. LICENSE +21 -0
  4. NOTICE +25 -0
  5. README.md +13 -7
  6. app.py +59 -0
  7. examples/DL3DV/DL3DV-garden-rgb.png +3 -0
  8. examples/DL3DV/DL3DV-garden-seg.png +0 -0
  9. examples/DL3DV/DL3DV-table-chair-set-rgb.png +3 -0
  10. examples/DL3DV/DL3DV-table-chair-set-seg.png +0 -0
  11. examples/DL3DV/DL3DV-tables-rgb.png +3 -0
  12. examples/DL3DV/DL3DV-tables-seg.png +0 -0
  13. examples/Gen3DSR/Gen3DSR_scene1_rgb.png +3 -0
  14. examples/Gen3DSR/Gen3DSR_scene1_seg.png +0 -0
  15. examples/MIDI-example/cartoon_style_07_rgb.png +3 -0
  16. examples/MIDI-example/cartoon_style_07_seg.png +0 -0
  17. examples/Scenethesis/SAM-3D-testing-case_rgb.png +3 -0
  18. examples/Scenethesis/SAM-3D-testing-case_seg.png +0 -0
  19. examples/Scenethesis/children_playroom2_rgb.png +3 -0
  20. examples/Scenethesis/children_playroom2_seg.png +0 -0
  21. examples/Scenethesis/scenethesis-reading-corner-rgb.png +0 -0
  22. examples/Scenethesis/scenethesis-reading-corner-seg.png +0 -0
  23. examples/outdoor/scene_beach2_rgb.png +3 -0
  24. examples/outdoor/scene_beach2_seg.png +0 -0
  25. interactive_demo.py +585 -0
  26. iscene/inference/__init__.py +0 -0
  27. iscene/inference/inferencer.py +503 -0
  28. iscene/inference/segmentation_utils.py +77 -0
  29. iscene/trellis/__init__.py +7 -0
  30. iscene/trellis/models/__init__.py +55 -0
  31. iscene/trellis/models/image_conditioner.py +134 -0
  32. iscene/trellis/models/sparse_structure_flow.py +201 -0
  33. iscene/trellis/models/sparse_structure_sc_flow.py +111 -0
  34. iscene/trellis/models/sparse_structure_vae.py +306 -0
  35. iscene/trellis/models/structured_latent_flow.py +267 -0
  36. iscene/trellis/models/structured_latent_vae/__init__.py +4 -0
  37. iscene/trellis/models/structured_latent_vae/base.py +117 -0
  38. iscene/trellis/models/structured_latent_vae/decoder_gs.py +122 -0
  39. iscene/trellis/models/structured_latent_vae/decoder_mesh.py +167 -0
  40. iscene/trellis/modules/attention/__init__.py +36 -0
  41. iscene/trellis/modules/attention/full_attn.py +140 -0
  42. iscene/trellis/modules/attention/modules.py +342 -0
  43. iscene/trellis/modules/attention_resample.py +77 -0
  44. iscene/trellis/modules/norm.py +24 -0
  45. iscene/trellis/modules/sparse/__init__.py +102 -0
  46. iscene/trellis/modules/sparse/attention/__init__.py +4 -0
  47. iscene/trellis/modules/sparse/attention/full_attn.py +215 -0
  48. iscene/trellis/modules/sparse/attention/modules.py +139 -0
  49. iscene/trellis/modules/sparse/attention/serialized_attn.py +193 -0
  50. iscene/trellis/modules/sparse/attention/windowed_attn.py +150 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/DL3DV/DL3DV-garden-rgb.png filter=lfs diff=lfs merge=lfs -text
37
+ examples/DL3DV/DL3DV-table-chair-set-rgb.png filter=lfs diff=lfs merge=lfs -text
38
+ examples/DL3DV/DL3DV-tables-rgb.png filter=lfs diff=lfs merge=lfs -text
39
+ examples/Gen3DSR/Gen3DSR_scene1_rgb.png filter=lfs diff=lfs merge=lfs -text
40
+ examples/MIDI-example/cartoon_style_07_rgb.png filter=lfs diff=lfs merge=lfs -text
41
+ examples/Scenethesis/SAM-3D-testing-case_rgb.png filter=lfs diff=lfs merge=lfs -text
42
+ examples/Scenethesis/children_playroom2_rgb.png filter=lfs diff=lfs merge=lfs -text
43
+ examples/outdoor/scene_beach2_rgb.png filter=lfs diff=lfs merge=lfs -text
CITATION.cff ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cff-version: 1.2.0
2
+ title: "I-Scene: 3D Instance Models are Implicit Generalizable Spatial Learners"
3
+ message: "If you use I-Scene, please cite the I-Scene paper and the TRELLIS paper."
4
+ url: "https://luling06.github.io/I-Scene-web-page/"
5
+ repository-code: "https://github.com/LuLing06/I-Scene-project"
6
+ authors:
7
+ - family-names: "Ling"
8
+ given-names: "Lu"
9
+ - family-names: "Ge"
10
+ given-names: "Yunhao"
11
+ - family-names: "Sheng"
12
+ given-names: "Yichen"
13
+ - family-names: "Bera"
14
+ given-names: "Aniket"
15
+ date-released: 2026-05-05
16
+ references:
17
+ - type: article
18
+ title: "I-Scene: 3D Instance Models are Implicit Generalizable Spatial Learners"
19
+ authors:
20
+ - family-names: "Ling"
21
+ given-names: "Lu"
22
+ - family-names: "Ge"
23
+ given-names: "Yunhao"
24
+ - family-names: "Sheng"
25
+ given-names: "Yichen"
26
+ - family-names: "Bera"
27
+ given-names: "Aniket"
28
+ journal: "arXiv preprint arXiv:2512.13683"
29
+ year: 2025
30
+ - type: article
31
+ title: "Structured 3D Latents for Scalable and Versatile 3D Generation"
32
+ url: "https://trellis3d.github.io/"
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Lu Ling
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
NOTICE ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ I-Scene
2
+
3
+ This repository contains the I-Scene inference code and the IScene-v1 model
4
+ package for segmentation-conditioned 3D scene generation.
5
+
6
+ IScene-v1:
7
+ Model package: IScene-v1
8
+ Project: https://luling06.github.io/I-Scene-web-page/
9
+ Code: https://github.com/LuLing06/I-Scene-project
10
+ Hugging Face repository: https://huggingface.co/LuLing/IScene
11
+ Contents: IScene-specific checkpoint files and inference configuration
12
+ Base model: microsoft/TRELLIS-image-large
13
+
14
+ I-Scene builds on TRELLIS, the image-to-3D generation framework released by
15
+ Microsoft under the MIT License.
16
+
17
+ TRELLIS:
18
+ Repository: https://github.com/microsoft/TRELLIS
19
+ Model: https://huggingface.co/microsoft/TRELLIS-image-large
20
+ Paper: Structured 3D Latents for Scalable and Versatile 3D Generation
21
+
22
+ The IScene-v1 model package provides I-Scene-specific checkpoint files and loads
23
+ TRELLIS base components from `microsoft/TRELLIS-image-large`. The TRELLIS
24
+ copyright notice and license terms should be preserved when redistributing code
25
+ or model packages derived from TRELLIS.
README.md CHANGED
@@ -1,15 +1,21 @@
1
  ---
2
- title: IScene Demo
3
- emoji: 📉
4
- colorFrom: gray
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 6.14.0
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
 
11
  license: mit
12
- short_description: I-Scene online demo
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
1
  ---
2
+ title: I-Scene Demo
3
+ emoji: 🏠
4
+ colorFrom: yellow
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 4.44.1
 
8
  app_file: app.py
9
  pinned: false
10
+ suggested_hardware: zero-a10g
11
  license: mit
12
+ short_description: Interactive I-Scene 3D scene generation demo
13
  ---
14
 
15
+ # I-Scene Demo
16
+
17
+ This Space runs the I-Scene interactive demo with the public checkpoint:
18
+
19
+ https://huggingface.co/LuLing/IScene
20
+
21
+ The first run may be slow because model checkpoints need to be downloaded and cached.
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
6
+ os.environ.setdefault("HF_HOME", "/data/.cache/huggingface")
7
+ os.environ.setdefault("TRANSFORMERS_CACHE", "/data/.cache/huggingface/transformers")
8
+
9
+ import spaces
10
+ import torch
11
+
12
+ import interactive_demo
13
+
14
+
15
+ def _configure_runtime_device() -> None:
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
18
+ if interactive_demo.DEVICE != device or interactive_demo.DTYPE != dtype:
19
+ interactive_demo._sam_cache.clear()
20
+ interactive_demo.DEVICE = device
21
+ interactive_demo.DTYPE = dtype
22
+
23
+
24
+ _run_segmentation = interactive_demo.run_segmentation
25
+ _run_gaussian_preview = interactive_demo.run_gaussian_preview
26
+ _run_glb_export = interactive_demo.run_glb_export
27
+
28
+
29
+ @spaces.GPU(duration=120)
30
+ def run_segmentation(*args, **kwargs):
31
+ _configure_runtime_device()
32
+ return _run_segmentation(*args, **kwargs)
33
+
34
+
35
+ @spaces.GPU(duration=180)
36
+ def run_gaussian_preview(*args, **kwargs):
37
+ _configure_runtime_device()
38
+ return _run_gaussian_preview(*args, **kwargs)
39
+
40
+
41
+ @spaces.GPU(duration=240)
42
+ def run_glb_export(*args, **kwargs):
43
+ _configure_runtime_device()
44
+ yield from _run_glb_export(*args, **kwargs)
45
+
46
+
47
+ interactive_demo.run_segmentation = run_segmentation
48
+ interactive_demo.run_gaussian_preview = run_gaussian_preview
49
+ interactive_demo.run_glb_export = run_glb_export
50
+ interactive_demo.MODEL_ID = os.environ.get("ISCENE_MODEL", interactive_demo.DEFAULT_MODEL)
51
+ interactive_demo.BASE_MODEL_ID = os.environ.get("ISCENE_BASE_MODEL") or None
52
+ interactive_demo.DEFAULT_OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
53
+ interactive_demo.UPLOAD_ROOT.mkdir(parents=True, exist_ok=True)
54
+
55
+ demo = interactive_demo.build_demo()
56
+ demo.queue()
57
+
58
+ if __name__ == "__main__":
59
+ demo.launch(server_name="0.0.0.0", server_port=7860)
examples/DL3DV/DL3DV-garden-rgb.png ADDED

Git LFS Details

  • SHA256: e26cc42b8ed2312941ab632f8846d21d9cecd1f5dea18ba34798d7e80d2a22fe
  • Pointer size: 132 Bytes
  • Size of remote file: 4.37 MB
examples/DL3DV/DL3DV-garden-seg.png ADDED
examples/DL3DV/DL3DV-table-chair-set-rgb.png ADDED

Git LFS Details

  • SHA256: ae76f97c7f09e4932d6a73810891e41ebbaa6cc6c61ce093ae17b175a6a7dd48
  • Pointer size: 131 Bytes
  • Size of remote file: 877 kB
examples/DL3DV/DL3DV-table-chair-set-seg.png ADDED
examples/DL3DV/DL3DV-tables-rgb.png ADDED

Git LFS Details

  • SHA256: f30330e4d6661733f3bc551a02c1567a69942806822f5a4de1b758cfa51a6cf4
  • Pointer size: 132 Bytes
  • Size of remote file: 3.35 MB
examples/DL3DV/DL3DV-tables-seg.png ADDED
examples/Gen3DSR/Gen3DSR_scene1_rgb.png ADDED

Git LFS Details

  • SHA256: f1fe5da5fc2b15a427ce59833d3e17488ccedc2237ffc52f003badfb7d08b833
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
examples/Gen3DSR/Gen3DSR_scene1_seg.png ADDED
examples/MIDI-example/cartoon_style_07_rgb.png ADDED

Git LFS Details

  • SHA256: 9142b580956a91ee0120df93b2698fd6347293f7d079e54bb71846b94d088cb3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
examples/MIDI-example/cartoon_style_07_seg.png ADDED
examples/Scenethesis/SAM-3D-testing-case_rgb.png ADDED

Git LFS Details

  • SHA256: 6aaa5403abdbd8d1034f6d817b72eb8306829e1e40174ed3f8211206c12e618a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.24 MB
examples/Scenethesis/SAM-3D-testing-case_seg.png ADDED
examples/Scenethesis/children_playroom2_rgb.png ADDED

Git LFS Details

  • SHA256: b0eec676a580885e65e8ba59a6e325729a35f2ba27a079c45ce6e2b990958a05
  • Pointer size: 131 Bytes
  • Size of remote file: 453 kB
examples/Scenethesis/children_playroom2_seg.png ADDED
examples/Scenethesis/scenethesis-reading-corner-rgb.png ADDED
examples/Scenethesis/scenethesis-reading-corner-seg.png ADDED
examples/outdoor/scene_beach2_rgb.png ADDED

Git LFS Details

  • SHA256: efa3144f87a622310adacc3cc2b212a52a80e75b9486cf864628606b4c42a009
  • Pointer size: 131 Bytes
  • Size of remote file: 525 kB
examples/outdoor/scene_beach2_seg.png ADDED
interactive_demo.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Interactive I-Scene demo.
2
+
3
+ Run from the repository root:
4
+
5
+ python interactive_demo.py
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import os
12
+ import uuid
13
+ from dataclasses import dataclass
14
+ from datetime import datetime
15
+ from pathlib import Path
16
+ from typing import Any
17
+
18
+ os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
19
+
20
+ import gradio as gr
21
+ import numpy as np
22
+ import torch
23
+ from gradio_image_prompter import ImagePrompter
24
+ from gradio_litmodel3d import LitModel3D
25
+ from PIL import Image
26
+ from transformers import AutoModelForMaskGeneration, AutoProcessor
27
+
28
+ from iscene.inference.inferencer import ISceneInferencer
29
+
30
+
31
+ REPO_ROOT = Path(__file__).resolve().parent
32
+ DEFAULT_MODEL = "LuLing/IScene"
33
+ MODEL_ID = DEFAULT_MODEL
34
+ BASE_MODEL_ID: str | None = None
35
+ DEFAULT_SEED = 43
36
+ DEFAULT_SIMPLIFY = 0.95
37
+ DEFAULT_OUTPUT_ROOT = REPO_ROOT / "outputs" / "demo"
38
+ UPLOAD_ROOT = DEFAULT_OUTPUT_ROOT / "_uploads"
39
+ TARGET_SIZE = (512, 512)
40
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
42
+
43
+ SAM_MODELS = {
44
+ "sam-vit-huge (best quality, 636M)": "facebook/sam-vit-huge",
45
+ "sam-vit-large (balanced, 308M)": "facebook/sam-vit-large",
46
+ "sam-vit-base (fastest, 91M)": "facebook/sam-vit-base",
47
+ }
48
+
49
+ MARKDOWN = """
50
+ # I-Scene Interactive Demo
51
+
52
+ Generate a 3D scene from one image.
53
+
54
+ Workflow:
55
+ 1. Pick an example, or upload an image and draw boxes around objects.
56
+ 2. Use the example mask, or click **Run SAM Segmentation** to create a mask.
57
+ 3. Click **Generate Gaussian Splatting Preview** to create and preview `scene_pred.ply`.
58
+ 4. Click **Generate GLB** only when you need mesh assets.
59
+ 5. To save each instance in the scene, run the inference code with the same RGB/mask; `run_inference.py` writes per-instance assets alongside the scene output.
60
+
61
+ Note: The first run may be slow because the model checkpoint needs to be downloaded and cached.
62
+ """
63
+
64
+ EXAMPLE_ORDER = [
65
+ "Scenethesis/SAM-3D-testing-case_rgb.png",
66
+ "Gen3DSR/Gen3DSR_scene1_rgb.png",
67
+ "MIDI-example/cartoon_style_07_rgb.png",
68
+ "Scenethesis/children_playroom2_rgb.png",
69
+ "Scenethesis/scenethesis-reading-corner-rgb.png",
70
+ "DL3DV/DL3DV-garden-rgb.png",
71
+ "DL3DV/DL3DV-table-chair-set-rgb.png",
72
+ "DL3DV/DL3DV-tables-rgb.png",
73
+ "outdoor/scene_beach2_rgb.png",
74
+ ]
75
+
76
+
77
+ def _discover_examples() -> list[tuple[str, Path, Path]]:
78
+ examples_root = REPO_ROOT / "examples"
79
+ pairs: list[tuple[str, Path, Path]] = []
80
+ for rel_name in EXAMPLE_ORDER:
81
+ rgb_path = examples_root / rel_name
82
+ if not rgb_path.exists():
83
+ continue
84
+
85
+ seg_path = None
86
+ if "_rgb" in rgb_path.name:
87
+ seg_path = rgb_path.with_name(rgb_path.name.replace("_rgb", "_seg"))
88
+ elif "-rgb" in rgb_path.name:
89
+ seg_path = rgb_path.with_name(rgb_path.name.replace("-rgb", "-seg"))
90
+ if seg_path is None or not seg_path.exists():
91
+ continue
92
+
93
+ rel = rgb_path.relative_to(examples_root)
94
+ case_name = rgb_path.stem.replace("_rgb", "").replace("-rgb", "")
95
+ label = f"{rel.parent.as_posix()} / {case_name}"
96
+ pairs.append((label, rgb_path, seg_path))
97
+ return pairs
98
+
99
+
100
+ EXAMPLES = _discover_examples()
101
+ EXAMPLE_ROWS = [[{"image": str(rgb)}, str(mask)] for _, rgb, mask in EXAMPLES]
102
+
103
+
104
+ @dataclass
105
+ class DemoRunState:
106
+ rgb_path: str
107
+ mask_path: str
108
+ output_dir: str
109
+ seed: int
110
+ simplify: float
111
+
112
+
113
+ _sam_cache: dict[str, tuple[AutoProcessor, AutoModelForMaskGeneration]] = {}
114
+ _inferencer_cache: dict[tuple[str, str], ISceneInferencer] = {}
115
+
116
+
117
+ def _make_session_dir(request: gr.Request | None, root: Path = UPLOAD_ROOT) -> Path:
118
+ session_hash = getattr(request, "session_hash", None) or uuid.uuid4().hex[:10]
119
+ path = root / session_hash
120
+ path.mkdir(parents=True, exist_ok=True)
121
+ return path
122
+
123
+
124
+ def _timestamped_output_dir(request: gr.Request | None) -> Path:
125
+ session_hash = getattr(request, "session_hash", None) or uuid.uuid4().hex[:10]
126
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
127
+ return DEFAULT_OUTPUT_ROOT / f"{timestamp}_{session_hash}"
128
+
129
+
130
+ def _get_prompt_image(image_prompts: Any) -> Image.Image | None:
131
+ if image_prompts is None:
132
+ return None
133
+ if isinstance(image_prompts, dict):
134
+ image = image_prompts.get("image")
135
+ else:
136
+ image = image_prompts
137
+ if image is None:
138
+ return None
139
+ if isinstance(image, Image.Image):
140
+ return image.convert("RGB")
141
+ return Image.open(image).convert("RGB")
142
+
143
+
144
+ def _save_prompt_rgb(image_prompts: Any, request: gr.Request | None) -> Path:
145
+ image = _get_prompt_image(image_prompts)
146
+ if image is None:
147
+ raise gr.Error("Please upload an RGB image.")
148
+ session_dir = _make_session_dir(request)
149
+ path = session_dir / "input_rgb.png"
150
+ image.save(path)
151
+ return path
152
+
153
+
154
+ def _resolve_mask_path(mask_path: str | None) -> Path:
155
+ if not mask_path:
156
+ raise gr.Error("Please choose an example or run SAM segmentation first.")
157
+ path = Path(mask_path)
158
+ if not path.exists():
159
+ raise gr.Error(f"Mask file does not exist: {path}")
160
+ return path
161
+
162
+
163
+ def _get_inferencer() -> ISceneInferencer:
164
+ key = (MODEL_ID, BASE_MODEL_ID or "")
165
+ if key not in _inferencer_cache:
166
+ _inferencer_cache[key] = ISceneInferencer.from_pretrained(MODEL_ID, base_model_id=BASE_MODEL_ID)
167
+ return _inferencer_cache[key]
168
+
169
+
170
+ def _get_sam_model(model_choice: str) -> tuple[AutoProcessor, AutoModelForMaskGeneration]:
171
+ model_id = SAM_MODELS[model_choice]
172
+ if model_id in _sam_cache:
173
+ return _sam_cache[model_id]
174
+ processor = AutoProcessor.from_pretrained(model_id)
175
+ segmentator = AutoModelForMaskGeneration.from_pretrained(model_id).to(DEVICE, DTYPE)
176
+ segmentator.eval()
177
+ _sam_cache[model_id] = (processor, segmentator)
178
+ return processor, segmentator
179
+
180
+
181
+ def _boxes_from_prompts(image_prompts: Any) -> list[list[list[int]]]:
182
+ points = image_prompts.get("points", []) if isinstance(image_prompts, dict) else []
183
+ if not points:
184
+ raise gr.Error("Please draw at least one box before running SAM segmentation.")
185
+ boxes = []
186
+ for box in points:
187
+ x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[3]), int(box[4])
188
+ x_min, x_max = sorted((x1, x2))
189
+ y_min, y_max = sorted((y1, y2))
190
+ if x_max <= x_min or y_max <= y_min:
191
+ continue
192
+ boxes.append([x_min, y_min, x_max, y_max])
193
+ if not boxes:
194
+ raise gr.Error("No valid boxes were drawn.")
195
+ return [boxes]
196
+
197
+
198
+ def _mask_to_polygon(mask: np.ndarray) -> list[list[int]] | None:
199
+ import cv2
200
+
201
+ contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
202
+ if not contours:
203
+ return None
204
+ contour = max(contours, key=cv2.contourArea)
205
+ return contour.reshape(-1, 2).tolist()
206
+
207
+
208
+ def _polygon_to_mask(polygon: list[list[int]], image_shape: tuple[int, int]) -> np.ndarray:
209
+ import cv2
210
+
211
+ mask = np.zeros(image_shape, dtype=np.uint8)
212
+ cv2.fillPoly(mask, [np.array(polygon, dtype=np.int32)], color=(1,))
213
+ return mask
214
+
215
+
216
+ def _refine_masks(
217
+ masks: torch.Tensor,
218
+ *,
219
+ polygon_refinement: bool,
220
+ mask_threshold: float,
221
+ ) -> list[np.ndarray]:
222
+ masks = masks.detach().cpu().float()
223
+ if masks.ndim == 5:
224
+ masks = masks[:, :, 0]
225
+ if masks.ndim == 4:
226
+ masks = masks.mean(dim=1)
227
+ masks = (masks > mask_threshold).numpy().astype(np.uint8)
228
+ refined = [mask for mask in masks]
229
+ if polygon_refinement:
230
+ for idx, mask in enumerate(refined):
231
+ polygon = _mask_to_polygon(mask)
232
+ if polygon is not None:
233
+ refined[idx] = _polygon_to_mask(polygon, mask.shape)
234
+ return refined
235
+
236
+
237
+ def _palette() -> list[int]:
238
+ colors = [0, 0, 0]
239
+ hue = 0.0
240
+ golden_ratio = 0.618033988749895
241
+ for _ in range(1, 256):
242
+ hue = (hue + golden_ratio) % 1.0
243
+ h = hue * 6.0
244
+ c = 0.81
245
+ x = c * (1 - abs(h % 2 - 1))
246
+ m = 0.09
247
+ if h < 1:
248
+ r, g, b = c, x, 0
249
+ elif h < 2:
250
+ r, g, b = x, c, 0
251
+ elif h < 3:
252
+ r, g, b = 0, c, x
253
+ elif h < 4:
254
+ r, g, b = 0, x, c
255
+ elif h < 5:
256
+ r, g, b = x, 0, c
257
+ else:
258
+ r, g, b = c, 0, x
259
+ colors.extend([int((r + m) * 255), int((g + m) * 255), int((b + m) * 255)])
260
+ return colors
261
+
262
+
263
+ def _label_mask_to_pil(label_map: np.ndarray) -> Image.Image:
264
+ if label_map.max(initial=0) < 256:
265
+ image = Image.fromarray(label_map.astype(np.uint8), mode="P")
266
+ image.putpalette(_palette())
267
+ return image
268
+ encoded = np.zeros((*label_map.shape, 3), dtype=np.uint8)
269
+ encoded[..., 0] = label_map & 255
270
+ encoded[..., 1] = (label_map >> 8) & 255
271
+ return Image.fromarray(encoded, mode="RGB")
272
+
273
+
274
+ def resize_prompt_image(image_prompts: Any) -> Any:
275
+ image = _get_prompt_image(image_prompts)
276
+ if image is None:
277
+ return image_prompts
278
+ resized = image.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
279
+ UPLOAD_ROOT.mkdir(parents=True, exist_ok=True)
280
+ path = UPLOAD_ROOT / f"prompt_{uuid.uuid4().hex[:10]}.png"
281
+ resized.save(path)
282
+ return {"image": str(path), "points": []}
283
+
284
+
285
+ def reset_uploaded_image(image_prompts: Any) -> tuple[Any, None, str]:
286
+ return resize_prompt_image(image_prompts), None, ""
287
+
288
+
289
+ def remember_example_mask_path(_image_prompts: Any, mask_path: str) -> str:
290
+ return str(mask_path)
291
+
292
+
293
+ @torch.no_grad()
294
+ def run_segmentation(
295
+ image_prompts: Any,
296
+ model_choice: str,
297
+ polygon_refinement: bool,
298
+ mask_threshold: float,
299
+ request: gr.Request,
300
+ ) -> tuple[str, str]:
301
+ image = _get_prompt_image(image_prompts)
302
+ if image is None:
303
+ raise gr.Error("Please upload an RGB image before running segmentation.")
304
+ boxes = _boxes_from_prompts(image_prompts)
305
+ processor, segmentator = _get_sam_model(model_choice)
306
+ inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(segmentator.device, segmentator.dtype)
307
+ outputs = segmentator(**inputs)
308
+ masks = processor.post_process_masks(
309
+ masks=outputs.pred_masks,
310
+ original_sizes=inputs.original_sizes,
311
+ reshaped_input_sizes=inputs.reshaped_input_sizes,
312
+ )[0]
313
+ masks = _refine_masks(masks, polygon_refinement=polygon_refinement, mask_threshold=mask_threshold)
314
+
315
+ label_map = np.zeros(image.size[::-1], dtype=np.uint32)
316
+ for idx, mask in enumerate(masks, start=1):
317
+ label_map[mask > 0] = idx
318
+
319
+ mask_image = _label_mask_to_pil(label_map)
320
+ session_dir = _make_session_dir(request)
321
+ raw_path = session_dir / "sam_mask.png"
322
+ mask_image.save(raw_path)
323
+
324
+ torch.cuda.empty_cache()
325
+ return str(raw_path), str(raw_path)
326
+
327
+
328
+ def run_gaussian_preview(
329
+ image_prompts: Any,
330
+ mask_path: str | None,
331
+ seed: int,
332
+ simplify: float,
333
+ output_dir_text: str,
334
+ request: gr.Request,
335
+ ) -> tuple[str, dict[str, Any], dict[str, Any], str, DemoRunState]:
336
+ rgb_path = _save_prompt_rgb(image_prompts, request)
337
+ mask_path = _resolve_mask_path(mask_path)
338
+ output_dir = Path(output_dir_text).expanduser() if output_dir_text.strip() else _timestamped_output_dir(request)
339
+ output_dir.mkdir(parents=True, exist_ok=True)
340
+
341
+ inferencer = _get_inferencer()
342
+ inferencer.infer_and_save_scene(
343
+ scene_rgb_path=rgb_path,
344
+ instance_seg_path=mask_path,
345
+ output_dir=output_dir,
346
+ overwrite=True,
347
+ save_dbg=False,
348
+ simplify=float(simplify),
349
+ only_3dgs=True,
350
+ seed=int(seed),
351
+ )
352
+
353
+ scene_ply = output_dir / "scene_pred.ply"
354
+ if not scene_ply.exists():
355
+ raise gr.Error(f"Generation finished but scene_pred.ply was not found in {output_dir}")
356
+
357
+ state = DemoRunState(
358
+ rgb_path=str(rgb_path),
359
+ mask_path=str(mask_path),
360
+ output_dir=str(output_dir),
361
+ seed=int(seed),
362
+ simplify=float(simplify),
363
+ )
364
+ torch.cuda.empty_cache()
365
+ return (
366
+ str(scene_ply),
367
+ gr.update(value=str(scene_ply), interactive=True),
368
+ gr.update(value=None, interactive=False),
369
+ "",
370
+ state,
371
+ )
372
+
373
+
374
+ def _progress_bar(percent: int) -> str:
375
+ percent = max(0, min(100, int(percent)))
376
+ return f"""
377
+ <div style="height: 14px; width: 100%; background: #ece7dc; border-radius: 999px; overflow: hidden; border: 1px solid #d8cbb7;">
378
+ <div style="height: 100%; width: {percent}%; background: linear-gradient(90deg, #b77a2f, #e0b15a); transition: width 0.4s ease;"></div>
379
+ </div>
380
+ """
381
+
382
+
383
+ def run_glb_export(
384
+ state: DemoRunState | dict[str, Any] | None,
385
+ simplify: float,
386
+ ) -> Any:
387
+ if state is None:
388
+ raise gr.Error("Please run GS preview first so the demo knows which RGB/mask/output directory to use.")
389
+ if isinstance(state, dict):
390
+ state = DemoRunState(**state)
391
+
392
+ output_dir = Path(state.output_dir)
393
+ yield gr.update(value=None, interactive=False), _progress_bar(5), gr.update(value=None)
394
+ inferencer = _get_inferencer()
395
+ yield gr.update(value=None, interactive=False), _progress_bar(15), gr.update(value=None)
396
+ inferencer.infer_and_save_scene(
397
+ scene_rgb_path=state.rgb_path,
398
+ instance_seg_path=state.mask_path,
399
+ output_dir=output_dir,
400
+ overwrite=True,
401
+ save_dbg=False,
402
+ simplify=float(simplify),
403
+ only_3dgs=False,
404
+ seed=int(state.seed),
405
+ )
406
+
407
+ scene_glb = output_dir / "scene_pred.glb"
408
+ if not scene_glb.exists():
409
+ raise gr.Error(f"GLB export finished but scene_pred.glb was not found in {output_dir}")
410
+
411
+ torch.cuda.empty_cache()
412
+ yield gr.update(value=str(scene_glb), interactive=True), _progress_bar(100), str(scene_glb)
413
+
414
+
415
+ def clear_glb_outputs() -> tuple[dict[str, Any], str, None, dict[str, Any]]:
416
+ return gr.update(value=None, interactive=False), "", None, gr.update(value=None)
417
+
418
+
419
+ def build_demo() -> gr.Blocks:
420
+ with gr.Blocks(title="I-Scene Interactive Demo", delete_cache=(3600, 3600)) as demo:
421
+ gr.Markdown(MARKDOWN)
422
+
423
+ run_state = gr.State(None)
424
+
425
+ with gr.Row():
426
+ with gr.Column(scale=1):
427
+ image_prompts = ImagePrompter(
428
+ label="RGB image (upload, then optionally draw boxes for SAM)",
429
+ type="pil",
430
+ height=520,
431
+ )
432
+
433
+ with gr.Row():
434
+ segment_button = gr.Button("Run SAM Segmentation", variant="secondary")
435
+
436
+ with gr.Accordion("Segmentation settings", open=False):
437
+ sam_model = gr.Dropdown(
438
+ choices=list(SAM_MODELS.keys()),
439
+ value="sam-vit-huge (best quality, 636M)",
440
+ label="SAM model",
441
+ )
442
+ mask_threshold = gr.Slider(
443
+ minimum=-1.0,
444
+ maximum=1.0,
445
+ value=0.0,
446
+ step=0.05,
447
+ label="Mask threshold",
448
+ )
449
+ polygon_refinement = gr.Checkbox(
450
+ label="Polygon refinement",
451
+ value=False,
452
+ )
453
+
454
+ sam_mask_preview = gr.Image(
455
+ label="Instance mask",
456
+ type="filepath",
457
+ format="png",
458
+ height=260,
459
+ )
460
+ mask_path_value = gr.Textbox(visible=False)
461
+
462
+ with gr.Accordion("Generation settings", open=False):
463
+ seed = gr.Number(label="Seed", value=DEFAULT_SEED, precision=0)
464
+ simplify = gr.Slider(
465
+ minimum=0.5,
466
+ maximum=1.0,
467
+ value=DEFAULT_SIMPLIFY,
468
+ step=0.01,
469
+ label="GLB mesh simplify ratio",
470
+ )
471
+ output_dir = gr.Textbox(
472
+ label="Output directory (optional)",
473
+ placeholder="Leave empty to use outputs/demo/<timestamp>_<session>",
474
+ )
475
+
476
+ generate_gs_button = gr.Button("Generate Gaussian Splatting Preview", variant="primary", size="lg")
477
+
478
+ with gr.Column(scale=1):
479
+ preview = LitModel3D(
480
+ label="3D preview",
481
+ exposure=10.0,
482
+ height=520,
483
+ )
484
+ download_gs = gr.DownloadButton(
485
+ label="Download Gaussian Splatting PLY",
486
+ interactive=False,
487
+ )
488
+
489
+ with gr.Row():
490
+ generate_glb_button = gr.Button("Generate GLB", variant="secondary")
491
+ glb_progress = gr.HTML(value="")
492
+ glb_preview = gr.Model3D(
493
+ label="GLB mesh preview",
494
+ clear_color=(0.98, 0.96, 0.91, 1.0),
495
+ display_mode="solid",
496
+ height=360,
497
+ )
498
+ download_glb = gr.DownloadButton(
499
+ label="Download Mesh GLB",
500
+ interactive=False,
501
+ )
502
+
503
+ image_prompts.upload(
504
+ reset_uploaded_image,
505
+ inputs=[image_prompts],
506
+ outputs=[image_prompts, sam_mask_preview, mask_path_value],
507
+ )
508
+
509
+ segment_button.click(
510
+ run_segmentation,
511
+ inputs=[image_prompts, sam_model, polygon_refinement, mask_threshold],
512
+ outputs=[sam_mask_preview, mask_path_value],
513
+ )
514
+
515
+ generate_gs_button.click(
516
+ clear_glb_outputs,
517
+ outputs=[download_glb, glb_progress, run_state, glb_preview],
518
+ show_progress="hidden",
519
+ ).then(
520
+ run_gaussian_preview,
521
+ inputs=[
522
+ image_prompts,
523
+ mask_path_value,
524
+ seed,
525
+ simplify,
526
+ output_dir,
527
+ ],
528
+ outputs=[preview, download_gs, download_glb, glb_progress, run_state],
529
+ show_progress="full",
530
+ )
531
+
532
+ generate_glb_button.click(
533
+ run_glb_export,
534
+ inputs=[run_state, simplify],
535
+ outputs=[download_glb, glb_progress, glb_preview],
536
+ show_progress="hidden",
537
+ )
538
+
539
+ with gr.Row():
540
+ gr.Examples(
541
+ examples=EXAMPLE_ROWS,
542
+ inputs=[image_prompts, sam_mask_preview],
543
+ outputs=[mask_path_value],
544
+ fn=remember_example_mask_path,
545
+ cache_examples=False,
546
+ label="Examples",
547
+ run_on_click=True,
548
+ )
549
+
550
+ return demo
551
+
552
+
553
+ def parse_args() -> argparse.Namespace:
554
+ parser = argparse.ArgumentParser(description=__doc__)
555
+ parser.add_argument("--server_name", default="0.0.0.0")
556
+ parser.add_argument("--server_port", type=int, default=7860)
557
+ parser.add_argument("--share", action="store_true")
558
+ parser.add_argument("--model", default=DEFAULT_MODEL, help="I-Scene model id or local model package path.")
559
+ parser.add_argument(
560
+ "--base_model",
561
+ default=None,
562
+ help="Optional TRELLIS base model id or local mirror path. Defaults to the model package metadata.",
563
+ )
564
+ return parser.parse_args()
565
+
566
+
567
+ def main() -> None:
568
+ global MODEL_ID, BASE_MODEL_ID
569
+
570
+ args = parse_args()
571
+ MODEL_ID = args.model
572
+ BASE_MODEL_ID = args.base_model
573
+ DEFAULT_OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
574
+ UPLOAD_ROOT.mkdir(parents=True, exist_ok=True)
575
+ demo = build_demo()
576
+ demo.queue()
577
+ demo.launch(
578
+ server_name=args.server_name,
579
+ server_port=args.server_port,
580
+ share=args.share,
581
+ )
582
+
583
+
584
+ if __name__ == "__main__":
585
+ main()
iscene/inference/__init__.py ADDED
File without changes
iscene/inference/inferencer.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from plyfile import PlyData, PlyElement
10
+ import torch
11
+ import trimesh
12
+ from tqdm import tqdm
13
+
14
+ from ..trellis.pipelines import TrellisImageTo3DSceneContextPipeline
15
+ from ..trellis.modules import sparse as sp
16
+
17
+ from .segmentation_utils import load_scene_and_instance_masks, segmentation_to_id_map
18
+
19
+
20
+ DEFAULT_BASE_MODEL_ID = "microsoft/TRELLIS-image-large"
21
+ SPARSE_STRUCTURE_SAMPLER_PARAMS = {"steps": 25, "cfg_strength": 3.0}
22
+ SLAT_SAMPLER_PARAMS = {"steps": 25, "cfg_strength": 3.0}
23
+
24
+
25
+ def _resolve_package_file(model_id_or_path: str | Path, filename: str, revision: str | None = None) -> Path:
26
+ root = Path(model_id_or_path).expanduser()
27
+ local_path = root / filename
28
+ if local_path.exists():
29
+ return local_path
30
+
31
+ from huggingface_hub import hf_hub_download
32
+
33
+ return Path(hf_hub_download(str(model_id_or_path), filename, revision=revision))
34
+
35
+
36
+ class ISceneInferencer:
37
+ def __init__(
38
+ self,
39
+ model_id_or_path: str | Path,
40
+ *,
41
+ base_model_id: str | Path | None = None,
42
+ revision: str | None = None,
43
+ base_revision: str | None = None,
44
+ ):
45
+ self.model_id_or_path = str(model_id_or_path)
46
+ self.base_model_id = str(base_model_id) if base_model_id is not None else None
47
+ self.revision = revision
48
+ self.base_revision = base_revision
49
+ self.pipeline = None
50
+
51
+ @classmethod
52
+ def from_pretrained(
53
+ cls,
54
+ model_id_or_path: str | Path,
55
+ *,
56
+ base_model_id: str | Path | None = None,
57
+ revision: str | None = None,
58
+ base_revision: str | None = None,
59
+ ) -> "ISceneInferencer":
60
+ return cls(
61
+ model_id_or_path,
62
+ base_model_id=base_model_id,
63
+ revision=revision,
64
+ base_revision=base_revision,
65
+ )
66
+
67
+ def _load_release_metadata(self) -> dict:
68
+ metadata_path = _resolve_package_file(self.model_id_or_path, "iscene_config.json", revision=self.revision)
69
+ with open(metadata_path, "r") as f:
70
+ metadata = json.load(f)
71
+
72
+ required_keys = {
73
+ "base_model_id",
74
+ "config_file",
75
+ "denoiser_checkpoint",
76
+ "image_conditioner_checkpoint",
77
+ }
78
+ missing = sorted(required_keys - set(metadata))
79
+ if missing:
80
+ raise ValueError(f"IScene model package is missing required metadata keys: {missing}")
81
+ return metadata
82
+
83
+ def setup_pipeline(self):
84
+ metadata = self._load_release_metadata()
85
+ config_file = _resolve_package_file(self.model_id_or_path, metadata["config_file"], revision=self.revision)
86
+ denoiser_checkpoint = _resolve_package_file(
87
+ self.model_id_or_path,
88
+ metadata["denoiser_checkpoint"],
89
+ revision=self.revision,
90
+ )
91
+ image_conditioner_checkpoint = _resolve_package_file(
92
+ self.model_id_or_path,
93
+ metadata["image_conditioner_checkpoint"],
94
+ revision=self.revision,
95
+ )
96
+ base_model_id = self.base_model_id or metadata.get("base_model_id", DEFAULT_BASE_MODEL_ID)
97
+
98
+ pipeline, cfg = TrellisImageTo3DSceneContextPipeline.from_pretrained(
99
+ str(base_model_id),
100
+ config_file=config_file,
101
+ denoiser_checkpoint=denoiser_checkpoint,
102
+ image_conditioner_checkpoint=image_conditioner_checkpoint,
103
+ revision=self.base_revision,
104
+ )
105
+ pipeline.cuda()
106
+ pipeline.set_exp_cfg(cfg)
107
+ return pipeline
108
+
109
+ def infer_and_save_scene(
110
+ self,
111
+ scene_rgb_path: str | Path,
112
+ instance_seg_path: str | Path,
113
+ output_dir: str | Path,
114
+ overwrite: bool = True,
115
+ save_dbg: bool = False,
116
+ simplify: float = 0.95,
117
+ only_3dgs: bool = False,
118
+ seed: int = 42,
119
+ verbose: bool = False,
120
+ ) -> None:
121
+ scene_results = self.infer_scene_instances(
122
+ scene_rgb_path,
123
+ instance_seg_path,
124
+ seed=seed,
125
+ only_3dgs=only_3dgs,
126
+ save_dbg=save_dbg,
127
+ verbose=verbose,
128
+ )
129
+ self.save_scene_outputs(
130
+ scene_results,
131
+ output_dir,
132
+ overwrite=overwrite,
133
+ save_dbg=save_dbg,
134
+ simplify=simplify,
135
+ only_3dgs=only_3dgs,
136
+ verbose=verbose,
137
+ )
138
+
139
+ @staticmethod
140
+ def _prepare_instance_inputs(
141
+ scene_rgb_path: str | Path,
142
+ instance_seg_path: str | Path,
143
+ input_loader=load_scene_and_instance_masks,
144
+ ):
145
+ scene_rgb, instance_masks, label_ids = input_loader(
146
+ scene_rgb_path,
147
+ instance_seg_path,
148
+ )
149
+ scene_mask = (segmentation_to_id_map(Image.open(instance_seg_path)) > 0).astype("uint8") * 255
150
+ scene_mask_pil = Image.fromarray(scene_mask)
151
+ return scene_rgb, instance_masks, scene_mask_pil, label_ids
152
+
153
+ @staticmethod
154
+ @torch.no_grad()
155
+ def _sample_sparse_structure(
156
+ pipeline,
157
+ *,
158
+ scene_rgb,
159
+ scene_mask,
160
+ instance_masks,
161
+ seed: int,
162
+ sparse_structure_sampler_params: dict,
163
+ collect_debug: bool,
164
+ verbose: bool,
165
+ ) -> dict | None:
166
+ if scene_rgb is None or not instance_masks:
167
+ logging.warning("Empty input lists for sparse-structure inference.")
168
+ return None
169
+
170
+ preprocessed_list = []
171
+ dbg_rets = [] if collect_debug else None
172
+ for instance_mask in instance_masks:
173
+ preprocessed, dbg_ret = pipeline.preprocess_image(
174
+ scene_rgb,
175
+ scene_mask,
176
+ instance_mask,
177
+ return_debug=collect_debug,
178
+ )
179
+ preprocessed_list.append(preprocessed)
180
+ if collect_debug and dbg_rets is not None:
181
+ dbg_rets.append(dbg_ret)
182
+
183
+ exp_setting = getattr(pipeline.exp_cfg.dataset.args, "exp_setting", "")
184
+ slot_names = ["scene_space_instance"]
185
+ if "global" in exp_setting:
186
+ slot_names.append("scene_space_scene")
187
+ if "local" in exp_setting:
188
+ slot_names.append("canonical_space_instance")
189
+
190
+ ss_cond, slat_cond, resolved_batch_size, num_slots = pipeline.get_cond_batch(preprocessed_list)
191
+ if len(slot_names) != num_slots:
192
+ slot_names = [f"slot_{i}" for i in range(num_slots)]
193
+
194
+ torch.manual_seed(seed)
195
+ coords = pipeline.sample_sparse_structure(
196
+ ss_cond,
197
+ num_samples=resolved_batch_size,
198
+ sampler_params=sparse_structure_sampler_params,
199
+ verbose=verbose,
200
+ )
201
+
202
+ results = {
203
+ "coords": coords,
204
+ "num_instances": resolved_batch_size,
205
+ "num_slots": num_slots,
206
+ "slot_names": slot_names,
207
+ "slat_cond": slat_cond,
208
+ }
209
+ if collect_debug:
210
+ results["dbg_ret_list"] = dbg_rets
211
+ return results
212
+
213
+ @torch.no_grad()
214
+ def infer_scene_instances(
215
+ self,
216
+ scene_rgb_path: str | Path,
217
+ instance_seg_path: str | Path,
218
+ seed: int = 42,
219
+ only_3dgs: bool = False,
220
+ save_dbg: bool = False,
221
+ verbose: bool = False,
222
+ ):
223
+ scene_rgb, instance_masks, scene_mask_pil, label_ids = self._prepare_instance_inputs(
224
+ scene_rgb_path,
225
+ instance_seg_path,
226
+ )
227
+
228
+ if not instance_masks:
229
+ logging.warning("No foreground instances found in segmentation.")
230
+ return None
231
+
232
+ if self.pipeline is None:
233
+ self.pipeline = self.setup_pipeline()
234
+
235
+ stage1_results = self._sample_sparse_structure(
236
+ self.pipeline,
237
+ scene_rgb=scene_rgb,
238
+ scene_mask=scene_mask_pil,
239
+ instance_masks=instance_masks,
240
+ seed=seed,
241
+ sparse_structure_sampler_params=SPARSE_STRUCTURE_SAMPLER_PARAMS,
242
+ collect_debug=save_dbg,
243
+ verbose=verbose,
244
+ )
245
+ if stage1_results is None:
246
+ return None
247
+
248
+ coords = stage1_results["coords"]
249
+ slat = self.pipeline.sample_slat(
250
+ stage1_results["slat_cond"],
251
+ coords,
252
+ sampler_params=SLAT_SAMPLER_PARAMS,
253
+ verbose=verbose,
254
+ )
255
+
256
+ num_instances = stage1_results["num_instances"]
257
+ num_slots = stage1_results["num_slots"]
258
+ slot_names = stage1_results["slot_names"]
259
+ total_slots = num_instances * num_slots
260
+ scene_slot_idx = slot_names.index("scene_space_scene") if "scene_space_scene" in slot_names else -1
261
+ skipped_slot_ids = {
262
+ instance_idx * num_slots + scene_slot_idx
263
+ for instance_idx in range(num_instances)
264
+ } if scene_slot_idx >= 0 else set()
265
+
266
+ unique_batch_ids = torch.unique(slat.coords[:, 0]).sort()[0]
267
+ decode_formats = ["gaussian"] if only_3dgs else ["mesh", "gaussian"]
268
+ decoded_results = {fmt: [None] * total_slots for fmt in decode_formats}
269
+ for bid in tqdm(unique_batch_ids, desc="Decoding assets", disable=not verbose):
270
+ bid_int = int(bid.item())
271
+ if bid_int in skipped_slot_ids:
272
+ continue
273
+ mask = slat.coords[:, 0] == bid
274
+ sample_coords = slat.coords[mask].clone()
275
+ sample_coords[:, 0] = 0
276
+ sample_slat = sp.SparseTensor(
277
+ feats=slat.feats[mask],
278
+ coords=sample_coords,
279
+ )
280
+ sample_decoded = self.pipeline.decode_slat(sample_slat, decode_formats)
281
+ for fmt, values in decoded_results.items():
282
+ if fmt in sample_decoded:
283
+ values[bid_int] = sample_decoded[fmt]
284
+
285
+ scene_results = {
286
+ **decoded_results,
287
+ "coords": coords,
288
+ "num_instances": num_instances,
289
+ "num_slots": num_slots,
290
+ "slot_names": slot_names,
291
+ }
292
+ if save_dbg:
293
+ scene_results["dbg_ret_list"] = stage1_results.get("dbg_ret_list", [])
294
+
295
+ scene_results["label_ids"] = label_ids
296
+ if save_dbg:
297
+ scene_results["scene_rgb"] = scene_rgb
298
+ scene_results["instance_masks"] = instance_masks
299
+ return scene_results
300
+
301
+ def save_scene_outputs(
302
+ self,
303
+ scene_results,
304
+ output_dir: str | Path,
305
+ overwrite: bool = True,
306
+ save_dbg: bool = False,
307
+ simplify: float = 0.95,
308
+ only_3dgs: bool = False,
309
+ verbose: bool = False,
310
+ ) -> None:
311
+ if scene_results is None:
312
+ return
313
+
314
+ out_dir = Path(output_dir)
315
+ out_dir.mkdir(parents=True, exist_ok=True)
316
+ if overwrite:
317
+ for stale_scene_slot in out_dir.glob("instance_*_scene_space_scene.*"):
318
+ stale_scene_slot.unlink()
319
+ for stale_instance_slot in out_dir.glob("instance_*_scene_space_instance.*"):
320
+ stale_instance_slot.unlink()
321
+ for stale_scene_slot in out_dir.glob("scene_space_scene.*"):
322
+ stale_scene_slot.unlink()
323
+
324
+ label_ids = scene_results.get("label_ids", [])
325
+ slot_names = scene_results.get("slot_names", [])
326
+ num_instances = int(scene_results.get("num_instances", len(label_ids)))
327
+ num_slots = int(scene_results.get("num_slots", len(slot_names) if slot_names else 0))
328
+ meshes = scene_results.get("mesh")
329
+ gaussians = scene_results.get("gaussian")
330
+ coords = scene_results.get("coords")
331
+
332
+ if gaussians is None:
333
+ raise ValueError("scene_results must contain gaussian outputs.")
334
+ if not only_3dgs and meshes is None:
335
+ raise ValueError("scene_results must contain mesh outputs when only_3dgs=False.")
336
+
337
+ if num_slots <= 0:
338
+ num_slots = max(1, len(gaussians) // max(num_instances, 1))
339
+ if not slot_names or len(slot_names) != num_slots:
340
+ slot_names = [f"slot_{i}" for i in range(num_slots)]
341
+
342
+ scene_slot_idx = slot_names.index("scene_space_scene") if "scene_space_scene" in slot_names else -1
343
+ instance_slot_idx = slot_names.index("scene_space_instance") if "scene_space_instance" in slot_names else 0
344
+
345
+ if only_3dgs:
346
+ instance_ply_paths: list[str] = []
347
+ for instance_idx in tqdm(range(num_instances), desc="Saving Gaussian assets", disable=not verbose):
348
+ label_id = label_ids[instance_idx] if instance_idx < len(label_ids) else instance_idx
349
+ for slot_idx in range(num_slots):
350
+ if slot_idx != instance_slot_idx or slot_idx == scene_slot_idx:
351
+ continue
352
+
353
+ flat_idx = instance_idx * num_slots + slot_idx
354
+ ply_path = out_dir / f"instance_{int(label_id):02d}.ply"
355
+ if ply_path.exists() and not overwrite:
356
+ instance_ply_paths.append(str(ply_path))
357
+ continue
358
+
359
+ gaussian = gaussians[flat_idx]
360
+ if gaussian is None:
361
+ continue
362
+
363
+ gaussian[0].save_ply(str(ply_path))
364
+ instance_ply_paths.append(str(ply_path))
365
+
366
+ if instance_ply_paths:
367
+ scene_ply_path = out_dir / "scene_pred.ply"
368
+ if overwrite or not scene_ply_path.exists():
369
+ merge_gaussian_ply_files(instance_ply_paths, str(scene_ply_path))
370
+ else:
371
+ from ..trellis.utils import postprocessing_utils
372
+
373
+ instance_glbs: list[Path] = []
374
+
375
+ for instance_idx in tqdm(range(num_instances), desc="Exporting GLB assets", disable=not verbose):
376
+ label_id = label_ids[instance_idx] if instance_idx < len(label_ids) else instance_idx
377
+ for slot_idx in range(num_slots):
378
+ if slot_idx != instance_slot_idx or slot_idx == scene_slot_idx:
379
+ continue
380
+
381
+ flat_idx = instance_idx * num_slots + slot_idx
382
+ out_path = out_dir / f"instance_{int(label_id):02d}.glb"
383
+ if out_path.exists() and not overwrite:
384
+ instance_glbs.append(out_path)
385
+ continue
386
+
387
+ gaussian = gaussians[flat_idx]
388
+ mesh = meshes[flat_idx]
389
+ if gaussian is None or mesh is None:
390
+ continue
391
+
392
+ glb = postprocessing_utils.to_glb(
393
+ gaussian[0],
394
+ mesh[0],
395
+ simplify=simplify,
396
+ texture_size=1024,
397
+ verbose=False,
398
+ )
399
+ out_path.parent.mkdir(parents=True, exist_ok=True)
400
+ glb.export(str(out_path))
401
+ instance_glbs.append(out_path)
402
+
403
+ if instance_glbs:
404
+ scene_mesh = self._merge_instance_glbs_to_scene(sorted(instance_glbs))
405
+ scene_mesh.export(str(out_dir / "scene_pred.glb"))
406
+
407
+ if save_dbg:
408
+ self._save_debug_outputs(
409
+ scene_results,
410
+ out_dir,
411
+ label_ids=label_ids,
412
+ slot_names=slot_names,
413
+ num_slots=num_slots,
414
+ coords=coords,
415
+ )
416
+
417
+ def _save_debug_outputs(
418
+ self,
419
+ scene_results,
420
+ out_dir: Path,
421
+ *,
422
+ label_ids: list[int],
423
+ slot_names: list[str],
424
+ num_slots: int,
425
+ coords,
426
+ ) -> None:
427
+ scene_rgb = scene_results.get("scene_rgb")
428
+ instance_masks = scene_results.get("instance_masks")
429
+ dbg_ret_list = scene_results.get("dbg_ret_list", [])
430
+ num_instances = int(scene_results.get("num_instances", len(label_ids)))
431
+
432
+ for instance_idx in range(num_instances):
433
+ label_id = label_ids[instance_idx] if instance_idx < len(label_ids) else instance_idx
434
+
435
+ if scene_rgb is not None:
436
+ scene_rgb.save(str(out_dir / f"instance_{int(label_id):02d}_scene_rgb.png"))
437
+ if instance_masks is not None and instance_idx < len(instance_masks):
438
+ instance_masks[instance_idx].save(str(out_dir / f"instance_{int(label_id):02d}_instance_mask.png"))
439
+
440
+ if dbg_ret_list and instance_idx < len(dbg_ret_list):
441
+ dbg_ret = dbg_ret_list[instance_idx]
442
+ if "instance_rgb_canonical_tensor" in dbg_ret:
443
+ canonical_np = dbg_ret["instance_rgb_canonical_tensor"].cpu().numpy().transpose(1, 2, 0)
444
+ canonical_np = np.clip(canonical_np * 255.0, 0, 255).astype(np.uint8)
445
+ Image.fromarray(canonical_np).save(
446
+ str(out_dir / f"instance_{int(label_id):02d}_canonical_space_instance_rgb.png")
447
+ )
448
+
449
+ if coords is not None:
450
+ for slot_idx in range(num_slots):
451
+ flat_idx = instance_idx * num_slots + slot_idx
452
+ coord_path = out_dir / f"instance_{int(label_id):02d}_{slot_names[slot_idx]}_coords.ply"
453
+ save_sparse_coords_as_ply(coords[coords[:, 0] == flat_idx], str(coord_path))
454
+
455
+ @staticmethod
456
+ def _merge_instance_glbs_to_scene(instance_mesh_paths):
457
+ aggregated = trimesh.Scene()
458
+
459
+ for idx, mesh_path in enumerate(sorted(Path(p) for p in instance_mesh_paths)):
460
+ try:
461
+ loaded = trimesh.load(str(mesh_path))
462
+ except Exception as exc:
463
+ logging.warning("Failed to load %s for scene aggregation: %s", mesh_path, exc)
464
+ continue
465
+
466
+ stem = mesh_path.stem
467
+ if hasattr(loaded, "geometry"):
468
+ for sub_idx, (sub_name, geometry) in enumerate(loaded.geometry.items()):
469
+ base_name = f"{stem}_{sub_name}" if sub_name else stem
470
+ node_name = base_name if base_name not in aggregated.geometry else f"{base_name}_{sub_idx}"
471
+ aggregated.add_geometry(geometry, node_name=node_name)
472
+ else:
473
+ node_name = stem if stem not in aggregated.geometry else f"{stem}_{idx}"
474
+ aggregated.add_geometry(loaded, node_name=node_name)
475
+
476
+ return aggregated
477
+
478
+
479
+ def save_sparse_coords_as_ply(coords, output_path: str, resolution: int = 64) -> None:
480
+ spatial_coords = coords[:, 1:].float().cpu().numpy()
481
+ points = (spatial_coords + 0.5) / resolution * 2.0 - 1.0
482
+ points = points[:, [0, 2, 1]]
483
+ points[:, 2] = -points[:, 2]
484
+ trimesh.points.PointCloud(points).export(output_path)
485
+
486
+
487
+ def merge_gaussian_ply_files(ply_paths: list[str], output_path: str) -> None:
488
+ all_vertices = []
489
+ for ply_path in ply_paths:
490
+ if not Path(ply_path).exists():
491
+ continue
492
+ try:
493
+ plydata = PlyData.read(str(ply_path))
494
+ except Exception as exc:
495
+ logging.warning("Failed to read %s: %s", ply_path, exc)
496
+ continue
497
+ all_vertices.append(plydata["vertex"].data)
498
+
499
+ if not all_vertices:
500
+ return
501
+
502
+ merged = np.concatenate(all_vertices)
503
+ PlyData([PlyElement.describe(merged, "vertex")]).write(output_path)
iscene/inference/segmentation_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+
10
+ def load_rgb_image(image_path: Union[str, Path]) -> Image.Image:
11
+ """Load an RGB image and handle alpha / transparency consistently."""
12
+ img = Image.open(image_path)
13
+ if img.mode in ("RGBA", "LA") or ("transparency" in img.info):
14
+ rgba = img.convert("RGBA")
15
+ background = Image.new("RGBA", rgba.size, (0, 0, 0, 0))
16
+ return Image.alpha_composite(background, rgba).convert("RGB")
17
+ return img.convert("RGB")
18
+
19
+
20
+ def segmentation_to_id_map(segmentation: Image.Image) -> np.ndarray:
21
+ """Decode an instance segmentation image into one integer label per pixel."""
22
+ seg_array = np.array(segmentation)
23
+ if seg_array.ndim == 2:
24
+ return seg_array.astype(np.uint32)
25
+
26
+ if seg_array.ndim == 3 and seg_array.shape[2] >= 1:
27
+ channels = seg_array[..., :3].astype(np.uint32)
28
+ if channels.shape[2] == 1:
29
+ return channels[..., 0]
30
+
31
+ r = channels[..., 0]
32
+ g = channels[..., 1]
33
+ b = channels[..., 2] if channels.shape[2] >= 3 else np.zeros_like(r)
34
+
35
+ if np.array_equal(r, g) and np.array_equal(r, b):
36
+ return r
37
+
38
+ packed_rg = r + (g << 8)
39
+ packed_rgb = packed_rg + (b << 16)
40
+ rg_ids = np.unique(packed_rg)
41
+ rgb_ids = np.unique(packed_rgb)
42
+
43
+ # Preserve the legacy 16-bit R/G packed format when B carries no label
44
+ # information. Use full RGB packing for color-coded masks so blue-only
45
+ # labels are not dropped and distinct colors are not merged.
46
+ if np.any(b != 0) or len(rgb_ids) != len(rg_ids):
47
+ return packed_rgb
48
+ return packed_rg
49
+
50
+ return np.zeros(seg_array.shape[:2], dtype=np.uint32)
51
+
52
+
53
+ def load_scene_and_instance_masks(
54
+ rgb_image_path: Union[str, Path],
55
+ segmentation_path: Union[str, Path],
56
+ ) -> tuple[Image.Image, list[Image.Image], list[int]]:
57
+ """
58
+ Load one scene RGB image and split a multi-label segmentation into per-instance masks.
59
+
60
+ The segmentation can be single-channel label IDs, palette IDs, packed 16-bit
61
+ R/G IDs, or RGB color-coded instance IDs.
62
+ """
63
+ segmentation = Image.open(segmentation_path)
64
+ scene_rgb = load_rgb_image(rgb_image_path).resize(segmentation.size)
65
+
66
+ id_map = segmentation_to_id_map(segmentation)
67
+
68
+ label_ids = np.unique(id_map)
69
+ label_ids = sorted(int(label_id) for label_id in label_ids[label_ids > 0].tolist())
70
+
71
+ instance_masks: list[Image.Image] = []
72
+ for label_id in label_ids:
73
+ mask = np.zeros_like(id_map, dtype=np.uint8)
74
+ mask[id_map == label_id] = 255
75
+ instance_masks.append(Image.fromarray(mask))
76
+
77
+ return scene_rgb, instance_masks, label_ids
iscene/trellis/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Lightweight TRELLIS components used by IScene inference.
2
+
3
+ Subpackages are imported by their direct users. Keeping this package init small
4
+ avoids importing optional rendering dependencies for Gaussian-only inference.
5
+ """
6
+
7
+ __all__ = ["models", "modules", "pipelines", "renderers", "representations", "utils"]
iscene/trellis/models/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ __attributes = {
4
+ "SparseStructureDecoder": "sparse_structure_vae",
5
+ "SparseStructureSceneContextFlowModel": "sparse_structure_sc_flow",
6
+ "SLatGaussianDecoder": "structured_latent_vae.decoder_gs",
7
+ "SLatMeshDecoder": "structured_latent_vae.decoder_mesh",
8
+ "SLatFlowModel": "structured_latent_flow",
9
+ "ImageConditioner": "image_conditioner",
10
+ }
11
+
12
+ __all__ = list(__attributes.keys())
13
+
14
+ def __getattr__(name):
15
+ if name not in __attributes:
16
+ raise AttributeError(f"module {__name__} has no attribute {name}")
17
+
18
+ module_name = __attributes[name]
19
+ module = importlib.import_module(f".{module_name}", __name__)
20
+ value = getattr(module, name)
21
+ globals()[name] = value
22
+ return value
23
+
24
+
25
+ def from_pretrained(path: str, revision: str | None = None, **kwargs):
26
+ """
27
+ Load a model from a pretrained checkpoint.
28
+
29
+ Args:
30
+ path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
31
+ NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
32
+ **kwargs: Additional arguments for the model constructor.
33
+ """
34
+ import os
35
+ import json
36
+ from safetensors.torch import load_file
37
+ is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
38
+
39
+ if is_local:
40
+ config_file = f"{path}.json"
41
+ model_file = f"{path}.safetensors"
42
+ else:
43
+ from huggingface_hub import hf_hub_download
44
+ path_parts = path.split('/')
45
+ repo_id = f'{path_parts[0]}/{path_parts[1]}'
46
+ model_name = '/'.join(path_parts[2:])
47
+ config_file = hf_hub_download(repo_id, f"{model_name}.json", revision=revision)
48
+ model_file = hf_hub_download(repo_id, f"{model_name}.safetensors", revision=revision)
49
+
50
+ with open(config_file, 'r') as f:
51
+ config = json.load(f)
52
+ model = __getattr__(config['name'])(**config['args'], **kwargs)
53
+ model.load_state_dict(load_file(model_file))
54
+
55
+ return model
iscene/trellis/models/image_conditioner.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ import torch.nn.functional as F
5
+ import logging
6
+
7
+ from ..modules.utils import convert_module_to_f32
8
+ from ..utils import dist_utils
9
+
10
+ class ImageConditioner(nn.Module):
11
+ def __init__(self, image_cond_model: str = 'dinov2_vitl14_reg', cond_in_channels: int = 10, use_fp16: bool = True):
12
+ super().__init__()
13
+
14
+ self.image_cond_model_name = image_cond_model
15
+ self.cond_in_channels = cond_in_channels
16
+ self._init_image_cond_model()
17
+
18
+ if use_fp16:
19
+ self.convert_to_fp16()
20
+ self.dtype = torch.float16 if use_fp16 else torch.float32
21
+
22
+
23
+ def convert_to_fp16(self):
24
+ logging.info('Image conditioner does not support fp16, skip this.')
25
+
26
+
27
+ def convert_to_fp32(self):
28
+ logging.info('Image conditioner does not support fp32, skip this.')
29
+ self.base_img_conditioner.apply(convert_module_to_f32)
30
+
31
+
32
+ def forward(self, image: torch.Tensor):
33
+ if isinstance(image, torch.Tensor):
34
+ assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
35
+ elif isinstance(image, list):
36
+ raise ValueError(f"Unsupported type of image: {type(image)}")
37
+ else:
38
+ raise ValueError(f"Unsupported type of image: {type(image)}")
39
+
40
+ image = image.to(self.dtype).cuda()
41
+
42
+ if image.shape[1] == 3:
43
+ base_img = self.base_transform(image)
44
+ else:
45
+ # Handle multi-channel input (e.g. 7 channels: RGB + RGB + Mask)
46
+ # We normalize every 3-channel block using ImageNet stats, and leave the rest as is.
47
+ mean = torch.tensor([0.485, 0.456, 0.406], device=image.device, dtype=image.dtype).view(1, 3, 1, 1)
48
+ std = torch.tensor([0.229, 0.224, 0.225], device=image.device, dtype=image.dtype).view(1, 3, 1, 1)
49
+
50
+ chunks = []
51
+ for i in range(0, image.shape[1], 3):
52
+ chunk = image[:, i:min(i+3, image.shape[1])]
53
+ if chunk.shape[1] == 3:
54
+ chunk = (chunk - mean) / std
55
+ chunks.append(chunk)
56
+ base_img = torch.cat(chunks, dim=1)
57
+
58
+ B, C, H, W = base_img.shape
59
+ patchtokens = []
60
+
61
+ features = self.base_img_conditioner(base_img, is_training=True)['x_prenorm']
62
+ patchtokens = F.layer_norm(features, features.shape[-1:])
63
+ return patchtokens
64
+
65
+
66
+ def _init_image_cond_model(self):
67
+ """
68
+ Initialize the image conditioning model.
69
+ """
70
+ with dist_utils.local_master_first():
71
+ dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
72
+ dinov2_model.eval().cuda()
73
+ transform = transforms.Compose([
74
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
75
+ ])
76
+
77
+ self.base_img_conditioner = dinov2_model
78
+ self.base_transform = transform
79
+
80
+ if self.cond_in_channels > 3:
81
+ self.base_img_conditioner = self.expand_dinov2_model(self.base_img_conditioner, self.cond_in_channels)
82
+
83
+ self.set_param_requires_grad(self.base_img_conditioner, False)
84
+
85
+
86
+ def set_param_requires_grad(self, model, requires_grad: bool):
87
+ for param in model.parameters():
88
+ param.requires_grad = requires_grad
89
+
90
+
91
+ def expand_dinov2_model(self, dinov2_model, cond_in_channels: int):
92
+ """
93
+ Expand the DINOv2 patch embedding to accept additional input channels.
94
+ """
95
+
96
+ # locate the patch-embedding projection conv for both hf Dinov2Model and torch.hub model
97
+ if hasattr(dinov2_model, 'embeddings'):
98
+ proj = dinov2_model.embeddings.patch_embeddings.projection
99
+ elif hasattr(dinov2_model, 'patch_embed'):
100
+ proj = dinov2_model.patch_embed.proj
101
+ else:
102
+ raise RuntimeError('Cannot locate patch-embedding projection in DINOv2 model.')
103
+
104
+ if proj.weight.shape[1] < cond_in_channels:
105
+ weight = proj.weight # (out_channels, 3, k, k)
106
+
107
+ extra = []
108
+ channels_left = cond_in_channels - 3
109
+ while channels_left > 0:
110
+ take = min(3, channels_left)
111
+ extra.append(weight[:, :take].clone())
112
+ channels_left -= take
113
+
114
+ new_weight = torch.cat([weight] + extra, dim=1)
115
+
116
+ new_proj = torch.nn.Conv2d(
117
+ in_channels=cond_in_channels,
118
+ out_channels=weight.shape[0],
119
+ kernel_size=proj.kernel_size,
120
+ stride=proj.stride,
121
+ padding=proj.padding,
122
+ bias=(proj.bias is not None),
123
+ )
124
+ new_proj.weight.data = new_weight
125
+ if proj.bias is not None:
126
+ new_proj.bias.data = proj.bias.data.clone()
127
+
128
+ # replace inside the model
129
+ if hasattr(dinov2_model, 'embeddings'):
130
+ dinov2_model.embeddings.patch_embeddings.projection = new_proj
131
+ else:
132
+ dinov2_model.patch_embed.proj = new_proj
133
+
134
+ return dinov2_model
iscene/trellis/models/sparse_structure_flow.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ..modules.utils import convert_module_to_f16, convert_module_to_f32
7
+ from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
8
+ from ..modules.spatial import patchify, unpatchify
9
+ import copy
10
+ from pathlib import Path
11
+
12
+ class TimestepEmbedder(nn.Module):
13
+ """
14
+ Embeds scalar timesteps into vector representations.
15
+ """
16
+ def __init__(self, hidden_size, frequency_embedding_size=256):
17
+ super().__init__()
18
+ self.mlp = nn.Sequential(
19
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
20
+ nn.SiLU(),
21
+ nn.Linear(hidden_size, hidden_size, bias=True),
22
+ )
23
+ self.frequency_embedding_size = frequency_embedding_size
24
+
25
+ @staticmethod
26
+ def timestep_embedding(t, dim, max_period=10000):
27
+ """
28
+ Create sinusoidal timestep embeddings.
29
+
30
+ Args:
31
+ t: a 1-D Tensor of N indices, one per batch element.
32
+ These may be fractional.
33
+ dim: the dimension of the output.
34
+ max_period: controls the minimum frequency of the embeddings.
35
+
36
+ Returns:
37
+ an (N, D) Tensor of positional embeddings.
38
+ """
39
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
40
+ half = dim // 2
41
+ freqs = torch.exp(
42
+ -np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
43
+ ).to(device=t.device)
44
+ args = t[:, None].float() * freqs[None]
45
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
46
+ if dim % 2:
47
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
48
+ return embedding
49
+
50
+ def forward(self, t):
51
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
52
+ t_emb = self.mlp(t_freq)
53
+ return t_emb
54
+
55
+
56
+ class SparseStructureFlowModel(nn.Module):
57
+ def __init__(
58
+ self,
59
+ resolution: int,
60
+ in_channels: int,
61
+ model_channels: int,
62
+ cond_channels: int,
63
+ out_channels: int,
64
+ num_blocks: int,
65
+ num_heads: Optional[int] = None,
66
+ num_head_channels: Optional[int] = 64,
67
+ mlp_ratio: float = 4,
68
+ patch_size: int = 2,
69
+ pe_mode: Literal["ape", "rope"] = "ape",
70
+ use_fp16: bool = False,
71
+ use_checkpoint: bool = False,
72
+ share_mod: bool = False,
73
+ qk_rms_norm: bool = False,
74
+ qk_rms_norm_cross: bool = False,
75
+ ):
76
+ super().__init__()
77
+ self.resolution = resolution
78
+ self.in_channels = in_channels
79
+ self.model_channels = model_channels
80
+ self.cond_channels = cond_channels
81
+ self.out_channels = out_channels
82
+ self.num_blocks = num_blocks
83
+ self.num_heads = num_heads or model_channels // num_head_channels
84
+ self.mlp_ratio = mlp_ratio
85
+ self.patch_size = patch_size
86
+ self.pe_mode = pe_mode
87
+ self.use_fp16 = use_fp16
88
+ self.use_checkpoint = use_checkpoint
89
+ self.share_mod = share_mod
90
+ self.qk_rms_norm = qk_rms_norm
91
+ self.qk_rms_norm_cross = qk_rms_norm_cross
92
+ self.dtype = torch.float16 if use_fp16 else torch.float32
93
+
94
+ self.t_embedder = TimestepEmbedder(model_channels)
95
+ if share_mod:
96
+ self.adaLN_modulation = nn.Sequential(
97
+ nn.SiLU(),
98
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
99
+ )
100
+
101
+ if pe_mode == "ape":
102
+ pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
103
+ coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
104
+ coords = torch.stack(coords, dim=-1).reshape(-1, 3)
105
+ pos_emb = pos_embedder(coords)
106
+ self.register_buffer("pos_emb", pos_emb)
107
+
108
+ self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
109
+
110
+ self.blocks = nn.ModuleList([
111
+ ModulatedTransformerCrossBlock(
112
+ model_channels,
113
+ cond_channels,
114
+ num_heads=self.num_heads,
115
+ mlp_ratio=self.mlp_ratio,
116
+ attn_mode='full',
117
+ use_checkpoint=self.use_checkpoint,
118
+ use_rope=(pe_mode == "rope"),
119
+ share_mod=share_mod,
120
+ qk_rms_norm=self.qk_rms_norm,
121
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
122
+ )
123
+ for _ in range(num_blocks)
124
+ ])
125
+
126
+ self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
127
+
128
+ self.initialize_weights()
129
+ if use_fp16:
130
+ self.convert_to_fp16()
131
+
132
+ @property
133
+ def device(self) -> torch.device:
134
+ """
135
+ Return the device of the model.
136
+ """
137
+ return next(self.parameters()).device
138
+
139
+ def convert_to_fp16(self) -> None:
140
+ """
141
+ Convert the torso of the model to float16.
142
+ """
143
+ self.blocks.apply(convert_module_to_f16)
144
+
145
+ def convert_to_fp32(self) -> None:
146
+ """
147
+ Convert the torso of the model to float32.
148
+ """
149
+ self.blocks.apply(convert_module_to_f32)
150
+
151
+ def initialize_weights(self) -> None:
152
+ # Initialize transformer layers:
153
+ def _basic_init(module):
154
+ if isinstance(module, nn.Linear):
155
+ torch.nn.init.xavier_uniform_(module.weight)
156
+ if module.bias is not None:
157
+ nn.init.constant_(module.bias, 0)
158
+ self.apply(_basic_init)
159
+
160
+ # Initialize timestep embedding MLP:
161
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
162
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
163
+
164
+ # Zero-out adaLN modulation layers in DiT blocks:
165
+ if self.share_mod:
166
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
167
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
168
+ else:
169
+ for block in self.blocks:
170
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
171
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
172
+
173
+ # Zero-out output layers:
174
+ nn.init.constant_(self.out_layer.weight, 0)
175
+ nn.init.constant_(self.out_layer.bias, 0)
176
+
177
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
178
+ assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
179
+ f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
180
+
181
+ h = patchify(x, self.patch_size)
182
+ h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
183
+
184
+ h = self.input_layer(h)
185
+ h = h + self.pos_emb[None]
186
+ t_emb = self.t_embedder(t)
187
+ if self.share_mod:
188
+ t_emb = self.adaLN_modulation(t_emb)
189
+ t_emb = t_emb.type(self.dtype)
190
+ h = h.type(self.dtype)
191
+ cond = cond.type(self.dtype)
192
+ for block in self.blocks:
193
+ h = block(h, t_emb, cond)
194
+ h = h.type(x.dtype)
195
+ h = F.layer_norm(h, h.shape[-1:])
196
+ h = self.out_layer(h)
197
+
198
+ h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
199
+ h = unpatchify(h, self.patch_size).contiguous()
200
+
201
+ return h
iscene/trellis/models/sparse_structure_sc_flow.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from ..modules.utils import convert_module_to_f16
5
+ from ..modules.spatial import patchify, unpatchify
6
+ from pathlib import Path
7
+ from .sparse_structure_flow import SparseStructureFlowModel
8
+
9
+ class SparseStructureSceneContextFlowModel(SparseStructureFlowModel):
10
+ def __init__(
11
+ self,
12
+ resolution: int,
13
+ in_channels: int,
14
+ model_channels: int,
15
+ cond_channels: int,
16
+ out_channels: int,
17
+ num_blocks: int,
18
+ num_heads: Optional[int] = None,
19
+ num_head_channels: Optional[int] = 64,
20
+ mlp_ratio: float = 4,
21
+ patch_size: int = 2,
22
+ pe_mode: Literal["ape", "rope"] = "ape",
23
+ use_fp16: bool = False,
24
+ use_checkpoint: bool = False,
25
+ share_mod: bool = False,
26
+ qk_rms_norm: bool = False,
27
+ qk_rms_norm_cross: bool = False,
28
+ pretrained_base: Optional[str] = None,
29
+ scene_context_attn_num: int = 5,
30
+ learning_pattern: Literal['full-finetune'] = 'full-finetune',
31
+ exp_setting: str = "global local",
32
+ type_embedding_type = None,
33
+ k_bias_scale = 0.2,
34
+ ):
35
+ super().__init__(resolution, in_channels, model_channels, cond_channels, out_channels, num_blocks, num_heads, num_head_channels, mlp_ratio, patch_size, pe_mode, use_fp16, use_checkpoint, share_mod, qk_rms_norm, qk_rms_norm_cross)
36
+
37
+ assert pretrained_base is not None, 'pretrained_base is required for SparseStructureSceneContextFlowModel'
38
+ assert Path(pretrained_base).exists(), f'Pretrained base model {pretrained_base} not found'
39
+ self.scene_context_attn_num = scene_context_attn_num
40
+
41
+ # load the base model
42
+ if Path(pretrained_base).suffix == '.pt':
43
+ self.load_state_dict(torch.load(pretrained_base, map_location='cpu'), strict=True)
44
+ elif Path(pretrained_base).suffix == '.safetensors':
45
+ from safetensors.torch import load_file
46
+ self.load_state_dict(load_file(pretrained_base), strict=True)
47
+ else:
48
+ raise ValueError(f'Invalid pretrained base model {pretrained_base}')
49
+
50
+ # hijack some blocks to use scene context attention
51
+ block_num = len(self.blocks)
52
+ start_idx = block_num // 2 - scene_context_attn_num // 2
53
+ for i in range(scene_context_attn_num):
54
+ self.blocks[start_idx + i].is_scene_context = True
55
+ self.blocks[start_idx + i].num_instances = len(exp_setting.split(' ')) + 1
56
+ if type_embedding_type is not None:
57
+ enable_gate = 'enable_gate' in type_embedding_type
58
+ enable_k_bias = 'enable_k_bias' in type_embedding_type
59
+ k_bias_scale = k_bias_scale
60
+ self.blocks[start_idx + i].self_attn.initialize_positional_encoding(self.blocks[start_idx + i].num_instances - 1,
61
+ enable_gate=enable_gate,
62
+ enable_k_bias=enable_k_bias,
63
+ k_bias_scale=k_bias_scale)
64
+
65
+ if use_fp16:
66
+ self.convert_to_fp16()
67
+
68
+ if learning_pattern != 'full-finetune':
69
+ raise ValueError(f'Unsupported learning pattern for release inference: {learning_pattern}')
70
+
71
+
72
+ def convert_to_fp16(self) -> None:
73
+ """
74
+ Convert the torso of the model to float16.
75
+ """
76
+ for block in self.blocks:
77
+ block.apply(convert_module_to_f16)
78
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor, *args, **kwargs) -> torch.Tensor:
79
+ """
80
+ x: B, N, C, [resolution, resolution, resolution]
81
+ cond: B, N, C, H, W
82
+ """
83
+ B, N, C, *rest = x.shape
84
+ x = x.view(B * N, C, *rest)
85
+
86
+ B, N, T, C = cond.shape
87
+ cond = cond.view(B * N, T, C)
88
+
89
+ t = t.repeat_interleave(N, dim=0)
90
+ h = patchify(x, self.patch_size)
91
+ h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
92
+
93
+ h = self.input_layer(h)
94
+ h = h + self.pos_emb[None]
95
+ t_emb = self.t_embedder(t)
96
+ if self.share_mod:
97
+ t_emb = self.adaLN_modulation(t_emb)
98
+ t_emb = t_emb.type(self.dtype)
99
+ h = h.type(self.dtype)
100
+ cond = cond.type(self.dtype)
101
+
102
+ for block in self.blocks:
103
+ h = block(x=h, mod=t_emb, context=cond)
104
+
105
+ h = h.type(x.dtype)
106
+ h = F.layer_norm(h, h.shape[-1:])
107
+ h = self.out_layer(h)
108
+ h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
109
+ h = unpatchify(h, self.patch_size).contiguous()
110
+ h = h.view(B, N, *h.shape[1:])
111
+ return h
iscene/trellis/models/sparse_structure_vae.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ..modules.norm import GroupNorm32, ChannelLayerNorm32
6
+ from ..modules.spatial import pixel_shuffle_3d
7
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
8
+
9
+
10
+ def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
11
+ """
12
+ Return a normalization layer.
13
+ """
14
+ if norm_type == "group":
15
+ return GroupNorm32(32, *args, **kwargs)
16
+ elif norm_type == "layer":
17
+ return ChannelLayerNorm32(*args, **kwargs)
18
+ else:
19
+ raise ValueError(f"Invalid norm type {norm_type}")
20
+
21
+
22
+ class ResBlock3d(nn.Module):
23
+ def __init__(
24
+ self,
25
+ channels: int,
26
+ out_channels: Optional[int] = None,
27
+ norm_type: Literal["group", "layer"] = "layer",
28
+ ):
29
+ super().__init__()
30
+ self.channels = channels
31
+ self.out_channels = out_channels or channels
32
+
33
+ self.norm1 = norm_layer(norm_type, channels)
34
+ self.norm2 = norm_layer(norm_type, self.out_channels)
35
+ self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
36
+ self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
37
+ self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ h = self.norm1(x)
41
+ h = F.silu(h)
42
+ h = self.conv1(h)
43
+ h = self.norm2(h)
44
+ h = F.silu(h)
45
+ h = self.conv2(h)
46
+ h = h + self.skip_connection(x)
47
+ return h
48
+
49
+
50
+ class DownsampleBlock3d(nn.Module):
51
+ def __init__(
52
+ self,
53
+ in_channels: int,
54
+ out_channels: int,
55
+ mode: Literal["conv", "avgpool"] = "conv",
56
+ ):
57
+ assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
58
+
59
+ super().__init__()
60
+ self.in_channels = in_channels
61
+ self.out_channels = out_channels
62
+
63
+ if mode == "conv":
64
+ self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
65
+ elif mode == "avgpool":
66
+ assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
67
+
68
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
69
+ if hasattr(self, "conv"):
70
+ return self.conv(x)
71
+ else:
72
+ return F.avg_pool3d(x, 2)
73
+
74
+
75
+ class UpsampleBlock3d(nn.Module):
76
+ def __init__(
77
+ self,
78
+ in_channels: int,
79
+ out_channels: int,
80
+ mode: Literal["conv", "nearest"] = "conv",
81
+ ):
82
+ assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
83
+
84
+ super().__init__()
85
+ self.in_channels = in_channels
86
+ self.out_channels = out_channels
87
+
88
+ if mode == "conv":
89
+ self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
90
+ elif mode == "nearest":
91
+ assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
92
+
93
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
94
+ if hasattr(self, "conv"):
95
+ x = self.conv(x)
96
+ return pixel_shuffle_3d(x, 2)
97
+ else:
98
+ return F.interpolate(x, scale_factor=2, mode="nearest")
99
+
100
+
101
+ class SparseStructureEncoder(nn.Module):
102
+ """
103
+ Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
104
+
105
+ Args:
106
+ in_channels (int): Channels of the input.
107
+ latent_channels (int): Channels of the latent representation.
108
+ num_res_blocks (int): Number of residual blocks at each resolution.
109
+ channels (List[int]): Channels of the encoder blocks.
110
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
111
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
112
+ use_fp16 (bool): Whether to use FP16.
113
+ """
114
+ def __init__(
115
+ self,
116
+ in_channels: int,
117
+ latent_channels: int,
118
+ num_res_blocks: int,
119
+ channels: List[int],
120
+ num_res_blocks_middle: int = 2,
121
+ norm_type: Literal["group", "layer"] = "layer",
122
+ use_fp16: bool = False,
123
+ ):
124
+ super().__init__()
125
+ self.in_channels = in_channels
126
+ self.latent_channels = latent_channels
127
+ self.num_res_blocks = num_res_blocks
128
+ self.channels = channels
129
+ self.num_res_blocks_middle = num_res_blocks_middle
130
+ self.norm_type = norm_type
131
+ self.use_fp16 = use_fp16
132
+ self.dtype = torch.float16 if use_fp16 else torch.float32
133
+
134
+ self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
135
+
136
+ self.blocks = nn.ModuleList([])
137
+ for i, ch in enumerate(channels):
138
+ self.blocks.extend([
139
+ ResBlock3d(ch, ch)
140
+ for _ in range(num_res_blocks)
141
+ ])
142
+ if i < len(channels) - 1:
143
+ self.blocks.append(
144
+ DownsampleBlock3d(ch, channels[i+1])
145
+ )
146
+
147
+ self.middle_block = nn.Sequential(*[
148
+ ResBlock3d(channels[-1], channels[-1])
149
+ for _ in range(num_res_blocks_middle)
150
+ ])
151
+
152
+ self.out_layer = nn.Sequential(
153
+ norm_layer(norm_type, channels[-1]),
154
+ nn.SiLU(),
155
+ nn.Conv3d(channels[-1], latent_channels*2, 3, padding=1)
156
+ )
157
+
158
+ if use_fp16:
159
+ self.convert_to_fp16()
160
+
161
+ @property
162
+ def device(self) -> torch.device:
163
+ """
164
+ Return the device of the model.
165
+ """
166
+ return next(self.parameters()).device
167
+
168
+ def convert_to_fp16(self) -> None:
169
+ """
170
+ Convert the torso of the model to float16.
171
+ """
172
+ self.use_fp16 = True
173
+ self.dtype = torch.float16
174
+ self.blocks.apply(convert_module_to_f16)
175
+ self.middle_block.apply(convert_module_to_f16)
176
+
177
+ def convert_to_fp32(self) -> None:
178
+ """
179
+ Convert the torso of the model to float32.
180
+ """
181
+ self.use_fp16 = False
182
+ self.dtype = torch.float32
183
+ self.blocks.apply(convert_module_to_f32)
184
+ self.middle_block.apply(convert_module_to_f32)
185
+
186
+ def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
187
+ h = self.input_layer(x)
188
+ h = h.type(self.dtype)
189
+
190
+ for block in self.blocks:
191
+ h = block(h)
192
+ h = self.middle_block(h)
193
+
194
+ h = h.type(x.dtype)
195
+ h = self.out_layer(h)
196
+
197
+ mean, logvar = h.chunk(2, dim=1)
198
+
199
+ if sample_posterior:
200
+ std = torch.exp(0.5 * logvar)
201
+ z = mean + std * torch.randn_like(std)
202
+ else:
203
+ z = mean
204
+
205
+ if return_raw:
206
+ return z, mean, logvar
207
+ return z
208
+
209
+
210
+ class SparseStructureDecoder(nn.Module):
211
+ """
212
+ Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
213
+
214
+ Args:
215
+ out_channels (int): Channels of the output.
216
+ latent_channels (int): Channels of the latent representation.
217
+ num_res_blocks (int): Number of residual blocks at each resolution.
218
+ channels (List[int]): Channels of the decoder blocks.
219
+ num_res_blocks_middle (int): Number of residual blocks in the middle.
220
+ norm_type (Literal["group", "layer"]): Type of normalization layer.
221
+ use_fp16 (bool): Whether to use FP16.
222
+ """
223
+ def __init__(
224
+ self,
225
+ out_channels: int,
226
+ latent_channels: int,
227
+ num_res_blocks: int,
228
+ channels: List[int],
229
+ num_res_blocks_middle: int = 2,
230
+ norm_type: Literal["group", "layer"] = "layer",
231
+ use_fp16: bool = False,
232
+ ):
233
+ super().__init__()
234
+ self.out_channels = out_channels
235
+ self.latent_channels = latent_channels
236
+ self.num_res_blocks = num_res_blocks
237
+ self.channels = channels
238
+ self.num_res_blocks_middle = num_res_blocks_middle
239
+ self.norm_type = norm_type
240
+ self.use_fp16 = use_fp16
241
+ self.dtype = torch.float16 if use_fp16 else torch.float32
242
+
243
+ self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
244
+
245
+ self.middle_block = nn.Sequential(*[
246
+ ResBlock3d(channels[0], channels[0])
247
+ for _ in range(num_res_blocks_middle)
248
+ ])
249
+
250
+ self.blocks = nn.ModuleList([])
251
+ for i, ch in enumerate(channels):
252
+ self.blocks.extend([
253
+ ResBlock3d(ch, ch)
254
+ for _ in range(num_res_blocks)
255
+ ])
256
+ if i < len(channels) - 1:
257
+ self.blocks.append(
258
+ UpsampleBlock3d(ch, channels[i+1])
259
+ )
260
+
261
+ self.out_layer = nn.Sequential(
262
+ norm_layer(norm_type, channels[-1]),
263
+ nn.SiLU(),
264
+ nn.Conv3d(channels[-1], out_channels, 3, padding=1)
265
+ )
266
+
267
+ if use_fp16:
268
+ self.convert_to_fp16()
269
+
270
+ @property
271
+ def device(self) -> torch.device:
272
+ """
273
+ Return the device of the model.
274
+ """
275
+ return next(self.parameters()).device
276
+
277
+ def convert_to_fp16(self) -> None:
278
+ """
279
+ Convert the torso of the model to float16.
280
+ """
281
+ self.use_fp16 = True
282
+ self.dtype = torch.float16
283
+ self.blocks.apply(convert_module_to_f16)
284
+ self.middle_block.apply(convert_module_to_f16)
285
+
286
+ def convert_to_fp32(self) -> None:
287
+ """
288
+ Convert the torso of the model to float32.
289
+ """
290
+ self.use_fp16 = False
291
+ self.dtype = torch.float32
292
+ self.blocks.apply(convert_module_to_f32)
293
+ self.middle_block.apply(convert_module_to_f32)
294
+
295
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
296
+ h = self.input_layer(x)
297
+
298
+ h = h.type(self.dtype)
299
+
300
+ h = self.middle_block(h)
301
+ for block in self.blocks:
302
+ h = block(h)
303
+
304
+ h = h.type(x.dtype)
305
+ h = self.out_layer(h)
306
+ return h
iscene/trellis/models/structured_latent_flow.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7
+ from ..modules.transformer import AbsolutePositionEmbedder
8
+ from ..modules.norm import LayerNorm32
9
+ from ..modules import sparse as sp
10
+ from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
11
+ from .sparse_structure_flow import TimestepEmbedder
12
+
13
+
14
+ class SparseResBlock3d(nn.Module):
15
+ def __init__(
16
+ self,
17
+ channels: int,
18
+ emb_channels: int,
19
+ out_channels: Optional[int] = None,
20
+ downsample: bool = False,
21
+ upsample: bool = False,
22
+ ):
23
+ super().__init__()
24
+ self.channels = channels
25
+ self.emb_channels = emb_channels
26
+ self.out_channels = out_channels or channels
27
+ self.downsample = downsample
28
+ self.upsample = upsample
29
+
30
+ assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
31
+
32
+ self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
33
+ self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
34
+ self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
35
+ self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
36
+ self.emb_layers = nn.Sequential(
37
+ nn.SiLU(),
38
+ nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
39
+ )
40
+ self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
41
+ self.updown = None
42
+ if self.downsample:
43
+ self.updown = sp.SparseDownsample(2)
44
+ elif self.upsample:
45
+ self.updown = sp.SparseUpsample(2)
46
+
47
+ def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
48
+ if self.updown is not None:
49
+ x = self.updown(x)
50
+ return x
51
+
52
+ def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
53
+ emb_out = self.emb_layers(emb).type(x.dtype)
54
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
55
+
56
+ x = self._updown(x)
57
+ h = x.replace(self.norm1(x.feats))
58
+ h = h.replace(F.silu(h.feats))
59
+ h = self.conv1(h)
60
+ h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
61
+ h = h.replace(F.silu(h.feats))
62
+ h = self.conv2(h)
63
+ h = h + self.skip_connection(x)
64
+
65
+ return h
66
+
67
+
68
+ class SLatFlowModel(nn.Module):
69
+ def __init__(
70
+ self,
71
+ resolution: int,
72
+ in_channels: int,
73
+ model_channels: int,
74
+ cond_channels: int,
75
+ out_channels: int,
76
+ num_blocks: int,
77
+ num_heads: Optional[int] = None,
78
+ num_head_channels: Optional[int] = 64,
79
+ mlp_ratio: float = 4,
80
+ patch_size: int = 2,
81
+ num_io_res_blocks: int = 2,
82
+ io_block_channels: List[int] = None,
83
+ pe_mode: Literal["ape", "rope"] = "ape",
84
+ use_fp16: bool = False,
85
+ use_checkpoint: bool = False,
86
+ use_skip_connection: bool = True,
87
+ share_mod: bool = False,
88
+ qk_rms_norm: bool = False,
89
+ qk_rms_norm_cross: bool = False,
90
+ ):
91
+ super().__init__()
92
+ self.resolution = resolution
93
+ self.in_channels = in_channels
94
+ self.model_channels = model_channels
95
+ self.cond_channels = cond_channels
96
+ self.out_channels = out_channels
97
+ self.num_blocks = num_blocks
98
+ self.num_heads = num_heads or model_channels // num_head_channels
99
+ self.mlp_ratio = mlp_ratio
100
+ self.patch_size = patch_size
101
+ self.num_io_res_blocks = num_io_res_blocks
102
+ self.io_block_channels = io_block_channels
103
+ self.pe_mode = pe_mode
104
+ self.use_fp16 = use_fp16
105
+ self.use_checkpoint = use_checkpoint
106
+ self.use_skip_connection = use_skip_connection
107
+ self.share_mod = share_mod
108
+ self.qk_rms_norm = qk_rms_norm
109
+ self.qk_rms_norm_cross = qk_rms_norm_cross
110
+ self.dtype = torch.float16 if use_fp16 else torch.float32
111
+
112
+ if self.io_block_channels is not None:
113
+ assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
114
+ assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
115
+
116
+ self.t_embedder = TimestepEmbedder(model_channels)
117
+ if share_mod:
118
+ self.adaLN_modulation = nn.Sequential(
119
+ nn.SiLU(),
120
+ nn.Linear(model_channels, 6 * model_channels, bias=True)
121
+ )
122
+
123
+ if pe_mode == "ape":
124
+ self.pos_embedder = AbsolutePositionEmbedder(model_channels)
125
+
126
+ self.input_layer = sp.SparseLinear(in_channels, model_channels if io_block_channels is None else io_block_channels[0])
127
+
128
+ self.input_blocks = nn.ModuleList([])
129
+ if io_block_channels is not None:
130
+ for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
131
+ self.input_blocks.extend([
132
+ SparseResBlock3d(
133
+ chs,
134
+ model_channels,
135
+ out_channels=chs,
136
+ )
137
+ for _ in range(num_io_res_blocks-1)
138
+ ])
139
+ self.input_blocks.append(
140
+ SparseResBlock3d(
141
+ chs,
142
+ model_channels,
143
+ out_channels=next_chs,
144
+ downsample=True,
145
+ )
146
+ )
147
+
148
+ self.blocks = nn.ModuleList([
149
+ ModulatedSparseTransformerCrossBlock(
150
+ model_channels,
151
+ cond_channels,
152
+ num_heads=self.num_heads,
153
+ mlp_ratio=self.mlp_ratio,
154
+ attn_mode='full',
155
+ use_checkpoint=self.use_checkpoint,
156
+ use_rope=(pe_mode == "rope"),
157
+ share_mod=self.share_mod,
158
+ qk_rms_norm=self.qk_rms_norm,
159
+ qk_rms_norm_cross=self.qk_rms_norm_cross,
160
+ )
161
+ for _ in range(num_blocks)
162
+ ])
163
+
164
+ self.out_blocks = nn.ModuleList([])
165
+ if io_block_channels is not None:
166
+ for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
167
+ self.out_blocks.append(
168
+ SparseResBlock3d(
169
+ prev_chs * 2 if self.use_skip_connection else prev_chs,
170
+ model_channels,
171
+ out_channels=chs,
172
+ upsample=True,
173
+ )
174
+ )
175
+ self.out_blocks.extend([
176
+ SparseResBlock3d(
177
+ chs * 2 if self.use_skip_connection else chs,
178
+ model_channels,
179
+ out_channels=chs,
180
+ )
181
+ for _ in range(num_io_res_blocks-1)
182
+ ])
183
+
184
+ self.out_layer = sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], out_channels)
185
+
186
+ self.initialize_weights()
187
+ if use_fp16:
188
+ self.convert_to_fp16()
189
+
190
+ @property
191
+ def device(self) -> torch.device:
192
+ """
193
+ Return the device of the model.
194
+ """
195
+ return next(self.parameters()).device
196
+
197
+ def convert_to_fp16(self) -> None:
198
+ """
199
+ Convert the torso of the model to float16.
200
+ """
201
+ self.input_blocks.apply(convert_module_to_f16)
202
+ self.blocks.apply(convert_module_to_f16)
203
+ self.out_blocks.apply(convert_module_to_f16)
204
+
205
+ def convert_to_fp32(self) -> None:
206
+ """
207
+ Convert the torso of the model to float32.
208
+ """
209
+ self.input_blocks.apply(convert_module_to_f32)
210
+ self.blocks.apply(convert_module_to_f32)
211
+ self.out_blocks.apply(convert_module_to_f32)
212
+
213
+ def initialize_weights(self) -> None:
214
+ # Initialize transformer layers:
215
+ def _basic_init(module):
216
+ if isinstance(module, nn.Linear):
217
+ torch.nn.init.xavier_uniform_(module.weight)
218
+ if module.bias is not None:
219
+ nn.init.constant_(module.bias, 0)
220
+ self.apply(_basic_init)
221
+
222
+ # Initialize timestep embedding MLP:
223
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
224
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
225
+
226
+ # Zero-out adaLN modulation layers in DiT blocks:
227
+ if self.share_mod:
228
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
229
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
230
+ else:
231
+ for block in self.blocks:
232
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
233
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
234
+
235
+ # Zero-out output layers:
236
+ nn.init.constant_(self.out_layer.weight, 0)
237
+ nn.init.constant_(self.out_layer.bias, 0)
238
+
239
+ def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
240
+ h = self.input_layer(x).type(self.dtype)
241
+ t_emb = self.t_embedder(t)
242
+ if self.share_mod:
243
+ t_emb = self.adaLN_modulation(t_emb)
244
+ t_emb = t_emb.type(self.dtype)
245
+ cond = cond.type(self.dtype)
246
+
247
+ skips = []
248
+ # pack with input blocks
249
+ for block in self.input_blocks:
250
+ h = block(h, t_emb)
251
+ skips.append(h.feats)
252
+
253
+ if self.pe_mode == "ape":
254
+ h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
255
+ for block in self.blocks:
256
+ h = block(h, t_emb, cond)
257
+
258
+ # unpack with output blocks
259
+ for block, skip in zip(self.out_blocks, reversed(skips)):
260
+ if self.use_skip_connection:
261
+ h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
262
+ else:
263
+ h = block(h, t_emb)
264
+
265
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
266
+ h = self.out_layer(h.type(x.dtype))
267
+ return h
iscene/trellis/models/structured_latent_vae/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .decoder_gs import SLatGaussianDecoder
2
+ from .decoder_mesh import SLatMeshDecoder
3
+
4
+ __all__ = ["SLatGaussianDecoder", "SLatMeshDecoder"]
iscene/trellis/models/structured_latent_vae/base.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ...modules.utils import convert_module_to_f16, convert_module_to_f32
5
+ from ...modules import sparse as sp
6
+ from ...modules.transformer import AbsolutePositionEmbedder
7
+ from ...modules.sparse.transformer import SparseTransformerBlock
8
+
9
+
10
+ def block_attn_config(self):
11
+ """
12
+ Return the attention configuration of the model.
13
+ """
14
+ for i in range(self.num_blocks):
15
+ if self.attn_mode == "shift_window":
16
+ yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
17
+ elif self.attn_mode == "shift_sequence":
18
+ yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
19
+ elif self.attn_mode == "shift_order":
20
+ yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
21
+ elif self.attn_mode == "full":
22
+ yield "full", None, None, None, None
23
+ elif self.attn_mode == "swin":
24
+ yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
25
+
26
+
27
+ class SparseTransformerBase(nn.Module):
28
+ """
29
+ Sparse Transformer without output layers.
30
+ Serve as the base class for encoder and decoder.
31
+ """
32
+ def __init__(
33
+ self,
34
+ in_channels: int,
35
+ model_channels: int,
36
+ num_blocks: int,
37
+ num_heads: Optional[int] = None,
38
+ num_head_channels: Optional[int] = 64,
39
+ mlp_ratio: float = 4.0,
40
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
41
+ window_size: Optional[int] = None,
42
+ pe_mode: Literal["ape", "rope"] = "ape",
43
+ use_fp16: bool = False,
44
+ use_checkpoint: bool = False,
45
+ qk_rms_norm: bool = False,
46
+ ):
47
+ super().__init__()
48
+ self.in_channels = in_channels
49
+ self.model_channels = model_channels
50
+ self.num_blocks = num_blocks
51
+ self.window_size = window_size
52
+ self.num_heads = num_heads or model_channels // num_head_channels
53
+ self.mlp_ratio = mlp_ratio
54
+ self.attn_mode = attn_mode
55
+ self.pe_mode = pe_mode
56
+ self.use_fp16 = use_fp16
57
+ self.use_checkpoint = use_checkpoint
58
+ self.qk_rms_norm = qk_rms_norm
59
+ self.dtype = torch.float16 if use_fp16 else torch.float32
60
+
61
+ if pe_mode == "ape":
62
+ self.pos_embedder = AbsolutePositionEmbedder(model_channels)
63
+
64
+ self.input_layer = sp.SparseLinear(in_channels, model_channels)
65
+ self.blocks = nn.ModuleList([
66
+ SparseTransformerBlock(
67
+ model_channels,
68
+ num_heads=self.num_heads,
69
+ mlp_ratio=self.mlp_ratio,
70
+ attn_mode=attn_mode,
71
+ window_size=window_size,
72
+ shift_sequence=shift_sequence,
73
+ shift_window=shift_window,
74
+ serialize_mode=serialize_mode,
75
+ use_checkpoint=self.use_checkpoint,
76
+ use_rope=(pe_mode == "rope"),
77
+ qk_rms_norm=self.qk_rms_norm,
78
+ )
79
+ for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
80
+ ])
81
+
82
+ @property
83
+ def device(self) -> torch.device:
84
+ """
85
+ Return the device of the model.
86
+ """
87
+ return next(self.parameters()).device
88
+
89
+ def convert_to_fp16(self) -> None:
90
+ """
91
+ Convert the torso of the model to float16.
92
+ """
93
+ self.blocks.apply(convert_module_to_f16)
94
+
95
+ def convert_to_fp32(self) -> None:
96
+ """
97
+ Convert the torso of the model to float32.
98
+ """
99
+ self.blocks.apply(convert_module_to_f32)
100
+
101
+ def initialize_weights(self) -> None:
102
+ # Initialize transformer layers:
103
+ def _basic_init(module):
104
+ if isinstance(module, nn.Linear):
105
+ torch.nn.init.xavier_uniform_(module.weight)
106
+ if module.bias is not None:
107
+ nn.init.constant_(module.bias, 0)
108
+ self.apply(_basic_init)
109
+
110
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
111
+ h = self.input_layer(x)
112
+ if self.pe_mode == "ape":
113
+ h = h + self.pos_embedder(x.coords[:, 1:])
114
+ h = h.type(self.dtype)
115
+ for block in self.blocks:
116
+ h = block(h)
117
+ return h
iscene/trellis/models/structured_latent_vae/decoder_gs.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ...modules import sparse as sp
6
+ from ...utils.random_utils import hammersley_sequence
7
+ from .base import SparseTransformerBase
8
+ from ...representations import Gaussian
9
+
10
+
11
+ class SLatGaussianDecoder(SparseTransformerBase):
12
+ def __init__(
13
+ self,
14
+ resolution: int,
15
+ model_channels: int,
16
+ latent_channels: int,
17
+ num_blocks: int,
18
+ num_heads: Optional[int] = None,
19
+ num_head_channels: Optional[int] = 64,
20
+ mlp_ratio: float = 4,
21
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
22
+ window_size: int = 8,
23
+ pe_mode: Literal["ape", "rope"] = "ape",
24
+ use_fp16: bool = False,
25
+ use_checkpoint: bool = False,
26
+ qk_rms_norm: bool = False,
27
+ representation_config: dict = None,
28
+ ):
29
+ super().__init__(
30
+ in_channels=latent_channels,
31
+ model_channels=model_channels,
32
+ num_blocks=num_blocks,
33
+ num_heads=num_heads,
34
+ num_head_channels=num_head_channels,
35
+ mlp_ratio=mlp_ratio,
36
+ attn_mode=attn_mode,
37
+ window_size=window_size,
38
+ pe_mode=pe_mode,
39
+ use_fp16=use_fp16,
40
+ use_checkpoint=use_checkpoint,
41
+ qk_rms_norm=qk_rms_norm,
42
+ )
43
+ self.resolution = resolution
44
+ self.rep_config = representation_config
45
+ self._calc_layout()
46
+ self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
47
+ self._build_perturbation()
48
+
49
+ self.initialize_weights()
50
+ if use_fp16:
51
+ self.convert_to_fp16()
52
+
53
+ def initialize_weights(self) -> None:
54
+ super().initialize_weights()
55
+ # Zero-out output layers:
56
+ nn.init.constant_(self.out_layer.weight, 0)
57
+ nn.init.constant_(self.out_layer.bias, 0)
58
+
59
+ def _build_perturbation(self) -> None:
60
+ perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])]
61
+ perturbation = torch.tensor(perturbation).float() * 2 - 1
62
+ perturbation = perturbation / self.rep_config['voxel_size']
63
+ perturbation = torch.atanh(perturbation).to(self.device)
64
+ self.register_buffer('offset_perturbation', perturbation)
65
+
66
+ def _calc_layout(self) -> None:
67
+ self.layout = {
68
+ '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
69
+ '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3},
70
+ '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
71
+ '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4},
72
+ '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']},
73
+ }
74
+ start = 0
75
+ for k, v in self.layout.items():
76
+ v['range'] = (start, start + v['size'])
77
+ start += v['size']
78
+ self.out_channels = start
79
+
80
+ def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
81
+ """
82
+ Convert a batch of network outputs to 3D representations.
83
+
84
+ Args:
85
+ x: The [N x * x C] sparse tensor output by the network.
86
+
87
+ Returns:
88
+ list of representations
89
+ """
90
+ ret = []
91
+ for i in range(x.shape[0]):
92
+ representation = Gaussian(
93
+ sh_degree=0,
94
+ aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
95
+ mininum_kernel_size = self.rep_config['3d_filter_kernel_size'],
96
+ scaling_bias = self.rep_config['scaling_bias'],
97
+ opacity_bias = self.rep_config['opacity_bias'],
98
+ scaling_activation = self.rep_config['scaling_activation']
99
+ )
100
+ xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
101
+ for k, v in self.layout.items():
102
+ if k == '_xyz':
103
+ offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])
104
+ offset = offset * self.rep_config['lr'][k]
105
+ if self.rep_config['perturb_offset']:
106
+ offset = offset + self.offset_perturbation
107
+ offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
108
+ _xyz = xyz.unsqueeze(1) + offset
109
+ setattr(representation, k, _xyz.flatten(0, 1))
110
+ else:
111
+ feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
112
+ feats = feats * self.rep_config['lr'][k]
113
+ setattr(representation, k, feats)
114
+ ret.append(representation)
115
+ return ret
116
+
117
+ def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
118
+ h = super().forward(x)
119
+ h = h.type(x.dtype)
120
+ h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
121
+ h = self.out_layer(h)
122
+ return self.to_representation(h)
iscene/trellis/models/structured_latent_vae/decoder_mesh.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7
+ from ...modules import sparse as sp
8
+ from .base import SparseTransformerBase
9
+ from ...representations import MeshExtractResult
10
+ from ...representations.mesh import SparseFeatures2Mesh
11
+
12
+
13
+ class SparseSubdivideBlock3d(nn.Module):
14
+ """
15
+ A 3D subdivide block that can subdivide the sparse tensor.
16
+
17
+ Args:
18
+ channels: channels in the inputs and outputs.
19
+ out_channels: if specified, the number of output channels.
20
+ num_groups: the number of groups for the group norm.
21
+ """
22
+ def __init__(
23
+ self,
24
+ channels: int,
25
+ resolution: int,
26
+ out_channels: Optional[int] = None,
27
+ num_groups: int = 32
28
+ ):
29
+ super().__init__()
30
+ self.channels = channels
31
+ self.resolution = resolution
32
+ self.out_resolution = resolution * 2
33
+ self.out_channels = out_channels or channels
34
+
35
+ self.act_layers = nn.Sequential(
36
+ sp.SparseGroupNorm32(num_groups, channels),
37
+ sp.SparseSiLU()
38
+ )
39
+
40
+ self.sub = sp.SparseSubdivide()
41
+
42
+ self.out_layers = nn.Sequential(
43
+ sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
44
+ sp.SparseGroupNorm32(num_groups, self.out_channels),
45
+ sp.SparseSiLU(),
46
+ zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
47
+ )
48
+
49
+ if self.out_channels == channels:
50
+ self.skip_connection = nn.Identity()
51
+ else:
52
+ self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
53
+
54
+ def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
55
+ """
56
+ Apply the block to a Tensor, conditioned on a timestep embedding.
57
+
58
+ Args:
59
+ x: an [N x C x ...] Tensor of features.
60
+ Returns:
61
+ an [N x C x ...] Tensor of outputs.
62
+ """
63
+ h = self.act_layers(x)
64
+ h = self.sub(h)
65
+ x = self.sub(x)
66
+ h = self.out_layers(h)
67
+ h = h + self.skip_connection(x)
68
+ return h
69
+
70
+
71
+ class SLatMeshDecoder(SparseTransformerBase):
72
+ def __init__(
73
+ self,
74
+ resolution: int,
75
+ model_channels: int,
76
+ latent_channels: int,
77
+ num_blocks: int,
78
+ num_heads: Optional[int] = None,
79
+ num_head_channels: Optional[int] = 64,
80
+ mlp_ratio: float = 4,
81
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
82
+ window_size: int = 8,
83
+ pe_mode: Literal["ape", "rope"] = "ape",
84
+ use_fp16: bool = False,
85
+ use_checkpoint: bool = False,
86
+ qk_rms_norm: bool = False,
87
+ representation_config: dict = None,
88
+ ):
89
+ super().__init__(
90
+ in_channels=latent_channels,
91
+ model_channels=model_channels,
92
+ num_blocks=num_blocks,
93
+ num_heads=num_heads,
94
+ num_head_channels=num_head_channels,
95
+ mlp_ratio=mlp_ratio,
96
+ attn_mode=attn_mode,
97
+ window_size=window_size,
98
+ pe_mode=pe_mode,
99
+ use_fp16=use_fp16,
100
+ use_checkpoint=use_checkpoint,
101
+ qk_rms_norm=qk_rms_norm,
102
+ )
103
+ self.resolution = resolution
104
+ self.rep_config = representation_config
105
+ self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*5, use_color=self.rep_config.get('use_color', False))
106
+ self.out_channels = self.mesh_extractor.feats_channels
107
+ self.upsample = nn.ModuleList([
108
+ SparseSubdivideBlock3d(
109
+ channels=model_channels,
110
+ resolution=resolution,
111
+ out_channels=model_channels // 4
112
+ ),
113
+ SparseSubdivideBlock3d(
114
+ channels=model_channels // 4,
115
+ resolution=resolution * 2,
116
+ out_channels=model_channels // 8
117
+ )
118
+ ])
119
+ self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
120
+
121
+ self.initialize_weights()
122
+ if use_fp16:
123
+ self.convert_to_fp16()
124
+
125
+ def initialize_weights(self) -> None:
126
+ super().initialize_weights()
127
+ # Zero-out output layers:
128
+ nn.init.constant_(self.out_layer.weight, 0)
129
+ nn.init.constant_(self.out_layer.bias, 0)
130
+
131
+ def convert_to_fp16(self) -> None:
132
+ """
133
+ Convert the torso of the model to float16.
134
+ """
135
+ super().convert_to_fp16()
136
+ self.upsample.apply(convert_module_to_f16)
137
+
138
+ def convert_to_fp32(self) -> None:
139
+ """
140
+ Convert the torso of the model to float32.
141
+ """
142
+ super().convert_to_fp32()
143
+ self.upsample.apply(convert_module_to_f32)
144
+
145
+ def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
146
+ """
147
+ Convert a batch of network outputs to 3D representations.
148
+
149
+ Args:
150
+ x: The [N x * x C] sparse tensor output by the network.
151
+
152
+ Returns:
153
+ list of representations
154
+ """
155
+ ret = []
156
+ for i in range(x.shape[0]):
157
+ mesh = self.mesh_extractor(x[i], training=self.training)
158
+ ret.append(mesh)
159
+ return ret
160
+
161
+ def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
162
+ h = super().forward(x)
163
+ for block in self.upsample:
164
+ h = block(h)
165
+ h = h.type(x.dtype)
166
+ h = self.out_layer(h)
167
+ return self.to_representation(h)
iscene/trellis/modules/attention/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ BACKEND = 'flash_attn'
4
+ DEBUG = False
5
+
6
+ def __from_env():
7
+ import os
8
+
9
+ global BACKEND
10
+ global DEBUG
11
+
12
+ env_attn_backend = os.environ.get('ATTN_BACKEND')
13
+ env_sttn_debug = os.environ.get('ATTN_DEBUG')
14
+
15
+ if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
16
+ BACKEND = env_attn_backend
17
+ if env_sttn_debug is not None:
18
+ DEBUG = env_sttn_debug == '1'
19
+
20
+ print(f"[ATTENTION] Using backend: {BACKEND}")
21
+
22
+
23
+ __from_env()
24
+
25
+
26
+ def set_backend(backend: Literal['xformers', 'flash_attn']):
27
+ global BACKEND
28
+ BACKEND = backend
29
+
30
+ def set_debug(debug: bool):
31
+ global DEBUG
32
+ DEBUG = debug
33
+
34
+
35
+ from .full_attn import *
36
+ from .modules import *
iscene/trellis/modules/attention/full_attn.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from . import DEBUG, BACKEND
5
+
6
+ if BACKEND == 'xformers':
7
+ import xformers.ops as xops
8
+ elif BACKEND == 'flash_attn':
9
+ import flash_attn
10
+ elif BACKEND == 'sdpa':
11
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
12
+ elif BACKEND == 'naive':
13
+ pass
14
+ else:
15
+ raise ValueError(f"Unknown attention backend: {BACKEND}")
16
+
17
+
18
+ __all__ = [
19
+ 'scaled_dot_product_attention',
20
+ ]
21
+
22
+
23
+ def _naive_sdpa(q, k, v):
24
+ """
25
+ Naive implementation of scaled dot product attention.
26
+ """
27
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
28
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
29
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
30
+ scale_factor = 1 / math.sqrt(q.size(-1))
31
+ attn_weight = q @ k.transpose(-2, -1) * scale_factor
32
+ attn_weight = torch.softmax(attn_weight, dim=-1)
33
+ out = attn_weight @ v
34
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
35
+ return out
36
+
37
+
38
+ @overload
39
+ def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
40
+ """
41
+ Apply scaled dot product attention.
42
+
43
+ Args:
44
+ qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
45
+ """
46
+ ...
47
+
48
+ @overload
49
+ def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
50
+ """
51
+ Apply scaled dot product attention.
52
+
53
+ Args:
54
+ q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
55
+ kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
56
+ """
57
+ ...
58
+
59
+ @overload
60
+ def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
61
+ """
62
+ Apply scaled dot product attention.
63
+
64
+ Args:
65
+ q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
66
+ k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
67
+ v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
68
+
69
+ Note:
70
+ k and v are assumed to have the same coordinate map.
71
+ """
72
+ ...
73
+
74
+ def scaled_dot_product_attention(*args, **kwargs):
75
+ arg_names_dict = {
76
+ 1: ['qkv'],
77
+ 2: ['q', 'kv'],
78
+ 3: ['q', 'k', 'v']
79
+ }
80
+ num_all_args = len(args) + len(kwargs)
81
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
82
+ for key in arg_names_dict[num_all_args][len(args):]:
83
+ assert key in kwargs, f"Missing argument {key}"
84
+
85
+ if num_all_args == 1:
86
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
87
+ assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
88
+ device = qkv.device
89
+
90
+ elif num_all_args == 2:
91
+ q = args[0] if len(args) > 0 else kwargs['q']
92
+ kv = args[1] if len(args) > 1 else kwargs['kv']
93
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
94
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
95
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
96
+ device = q.device
97
+
98
+ elif num_all_args == 3:
99
+ q = args[0] if len(args) > 0 else kwargs['q']
100
+ k = args[1] if len(args) > 1 else kwargs['k']
101
+ v = args[2] if len(args) > 2 else kwargs['v']
102
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
103
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
104
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
105
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
106
+ device = q.device
107
+
108
+ if BACKEND == 'xformers':
109
+ if num_all_args == 1:
110
+ q, k, v = qkv.unbind(dim=2)
111
+ elif num_all_args == 2:
112
+ k, v = kv.unbind(dim=2)
113
+ out = xops.memory_efficient_attention(q, k, v)
114
+ elif BACKEND == 'flash_attn':
115
+ if num_all_args == 1:
116
+ out = flash_attn.flash_attn_qkvpacked_func(qkv)
117
+ elif num_all_args == 2:
118
+ out = flash_attn.flash_attn_kvpacked_func(q, kv)
119
+ elif num_all_args == 3:
120
+ out = flash_attn.flash_attn_func(q, k, v)
121
+ elif BACKEND == 'sdpa':
122
+ if num_all_args == 1:
123
+ q, k, v = qkv.unbind(dim=2)
124
+ elif num_all_args == 2:
125
+ k, v = kv.unbind(dim=2)
126
+ q = q.permute(0, 2, 1, 3) # [N, H, L, C]
127
+ k = k.permute(0, 2, 1, 3) # [N, H, L, C]
128
+ v = v.permute(0, 2, 1, 3) # [N, H, L, C]
129
+ out = sdpa(q, k, v) # [N, H, L, C]
130
+ out = out.permute(0, 2, 1, 3) # [N, L, H, C]
131
+ elif BACKEND == 'naive':
132
+ if num_all_args == 1:
133
+ q, k, v = qkv.unbind(dim=2)
134
+ elif num_all_args == 2:
135
+ k, v = kv.unbind(dim=2)
136
+ out = _naive_sdpa(q, k, v)
137
+ else:
138
+ raise ValueError(f"Unknown attention module: {BACKEND}")
139
+
140
+ return out
iscene/trellis/modules/attention/modules.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .full_attn import scaled_dot_product_attention
6
+ from einops import rearrange
7
+
8
+ class MultiHeadRMSNorm(nn.Module):
9
+ def __init__(self, dim: int, heads: int):
10
+ super().__init__()
11
+ self.scale = dim ** 0.5
12
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
13
+
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
16
+
17
+
18
+ class RotaryPositionEmbedder(nn.Module):
19
+ def __init__(self, hidden_size: int, in_channels: int = 3):
20
+ super().__init__()
21
+ assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
22
+ self.hidden_size = hidden_size
23
+ self.in_channels = in_channels
24
+ self.freq_dim = hidden_size // in_channels // 2
25
+ self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
26
+ self.freqs = 1.0 / (10000 ** self.freqs)
27
+
28
+ def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
29
+ self.freqs = self.freqs.to(indices.device)
30
+ phases = torch.outer(indices, self.freqs)
31
+ phases = torch.polar(torch.ones_like(phases), phases)
32
+ return phases
33
+
34
+ def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
35
+ x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
36
+ x_rotated = x_complex * phases
37
+ x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
38
+ return x_embed
39
+
40
+ def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
41
+ """
42
+ Args:
43
+ q (sp.SparseTensor): [..., N, D] tensor of queries
44
+ k (sp.SparseTensor): [..., N, D] tensor of keys
45
+ indices (torch.Tensor): [..., N, C] tensor of spatial positions
46
+ """
47
+ if indices is None:
48
+ indices = torch.arange(q.shape[-2], device=q.device)
49
+ if len(q.shape) > 2:
50
+ indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
51
+
52
+ phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
53
+ if phases.shape[1] < self.hidden_size // 2:
54
+ phases = torch.cat([phases, torch.polar(
55
+ torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
56
+ torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
57
+ )], dim=-1)
58
+ q_embed = self._rotary_embedding(q, phases)
59
+ k_embed = self._rotary_embedding(k, phases)
60
+ return q_embed, k_embed
61
+
62
+
63
+ class MultiHeadAttention(nn.Module):
64
+ def __init__(
65
+ self,
66
+ channels: int,
67
+ num_heads: int,
68
+ ctx_channels: Optional[int]=None,
69
+ type: Literal["self", "cross"] = "self",
70
+ attn_mode: Literal["full", "windowed"] = "full",
71
+ window_size: Optional[int] = None,
72
+ shift_window: Optional[Tuple[int, int, int]] = None,
73
+ qkv_bias: bool = True,
74
+ use_rope: bool = False,
75
+ qk_rms_norm: bool = False,
76
+ ):
77
+ super().__init__()
78
+ assert channels % num_heads == 0
79
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
80
+ assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
81
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
82
+
83
+ if attn_mode == "windowed":
84
+ raise NotImplementedError("Windowed attention is not yet implemented")
85
+
86
+ self.channels = channels
87
+ self.head_dim = channels // num_heads
88
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
89
+ self.num_heads = num_heads
90
+ self._type = type
91
+ self.attn_mode = attn_mode
92
+ self.window_size = window_size
93
+ self.shift_window = shift_window
94
+ self.use_rope = use_rope
95
+ self.qk_rms_norm = qk_rms_norm
96
+
97
+ if self._type == "self":
98
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
99
+ else:
100
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
101
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
102
+
103
+ if self.qk_rms_norm:
104
+ self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
105
+ self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
106
+
107
+ self.to_out = nn.Linear(channels, channels)
108
+
109
+ if use_rope:
110
+ self.rope = RotaryPositionEmbedder(channels)
111
+ self.use_positional_encoding = False
112
+
113
+ def initialize_positional_encoding(self, num_external_sources: int = 2, enable_gate: bool = True, enable_k_bias: bool = False, k_bias_scale: float = 0.1):
114
+ self.use_positional_encoding = True
115
+ # Controls for optional mechanisms
116
+ self.enable_ext_gate = bool(enable_gate)
117
+ self.enable_ext_k_bias = bool(enable_k_bias)
118
+ self.ext_k_bias_scale = float(k_bias_scale)
119
+
120
+ # K-gate for external keys only (values unchanged)
121
+ if self.enable_ext_gate:
122
+ self.ext_gate = nn.Parameter(torch.full((num_external_sources, self.num_heads,), 0.0))
123
+
124
+ # Per-source, per-head K additive bias vector (bounded via tanh during application)
125
+ if self.enable_ext_k_bias:
126
+ self.k_type_bias = nn.Parameter(torch.zeros(num_external_sources, self.num_heads, self.head_dim))
127
+
128
+
129
+ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
130
+ B, L, C = x.shape
131
+ if self._type == "self":
132
+ qkv = self.to_qkv(x)
133
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
134
+ if self.use_rope:
135
+ q, k, v = qkv.unbind(dim=2)
136
+ q, k = self.rope(q, k, indices)
137
+ qkv = torch.stack([q, k, v], dim=2)
138
+ if self.attn_mode == "full":
139
+ if self.qk_rms_norm:
140
+ q, k, v = qkv.unbind(dim=2)
141
+ q = self.q_rms_norm(q)
142
+ k = self.k_rms_norm(k)
143
+ h = scaled_dot_product_attention(q, k, v)
144
+ else:
145
+ h = scaled_dot_product_attention(qkv)
146
+ elif self.attn_mode == "windowed":
147
+ raise NotImplementedError("Windowed attention is not yet implemented")
148
+ else:
149
+ Lkv = context.shape[1]
150
+ q = self.to_q(x)
151
+ kv = self.to_kv(context)
152
+ q = q.reshape(B, L, self.num_heads, -1)
153
+ kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
154
+ if self.qk_rms_norm:
155
+ q = self.q_rms_norm(q)
156
+ k, v = kv.unbind(dim=2)
157
+ k = self.k_rms_norm(k)
158
+ h = scaled_dot_product_attention(q, k, v)
159
+ else:
160
+ h = scaled_dot_product_attention(q, kv)
161
+ h = h.reshape(B, L, -1)
162
+ h = self.to_out(h)
163
+ return h
164
+
165
+ def mi_attention(self, x: torch.Tensor, num_instances: int, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
166
+ """
167
+ Multi-instance self-attention.
168
+ q stays (B_total, L, ...).
169
+ k, v are concatenated across instances (N) -> (B, N*L, ...), then expanded to (B*N, N*L, ...).
170
+ """
171
+ B_total, L, C = x.shape
172
+
173
+ # 1. QKV projection
174
+ qkv = self.to_qkv(x).reshape(B_total, L, 3, self.num_heads, -1)
175
+ q, k, v = qkv.unbind(dim=2)
176
+
177
+ # 2. RoPE
178
+ if self.use_rope:
179
+ q, k = self.rope(q, k, indices)
180
+
181
+ if self.qk_rms_norm:
182
+ q = self.q_rms_norm(q)
183
+ k = self.k_rms_norm(k)
184
+
185
+ # q: (B*N, L, H, D)
186
+
187
+ # 3. Prepare K, V: merge instances in scene, then broadcast to each instance
188
+ # (B*N, L, H, D) -> (B, N*L, H, D)
189
+ k_scene = rearrange(k, '(b n) l h d -> b (n l) h d', n=num_instances)
190
+ v_scene = rearrange(v, '(b n) l h d -> b (n l) h d', n=num_instances)
191
+
192
+ # Expand to (B*N, N*L, H, D)
193
+ # We want each of the N instances in batch b to see the same k_scene[b]
194
+ # k_scene: (B, 1, NL, H, D) -> expand -> (B, N, NL, H, D) -> reshape -> (BN, NL, H, D)
195
+ k_all = k_scene.unsqueeze(1).expand(-1, num_instances, -1, -1, -1)
196
+ k_all = rearrange(k_all, 'b n nl h d -> (b n) nl h d')
197
+
198
+ v_all = v_scene.unsqueeze(1).expand(-1, num_instances, -1, -1, -1)
199
+ v_all = rearrange(v_all, 'b n nl h d -> (b n) nl h d')
200
+
201
+ # 4. Attention
202
+ # q: (BN, L, H, D)
203
+ # k_all: (BN, NL, H, D)
204
+ # out: (BN, L, H, D)
205
+ h = scaled_dot_product_attention(q, k_all, v_all)
206
+
207
+ # 6. Output projection
208
+ h = h.reshape(B_total, L, -1)
209
+ h = self.to_out(h)
210
+ return h
211
+
212
+ def scene_context_attn(self, x: torch.Tensor, context: torch.Tensor, num_instances=3, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
213
+ B, L, C = x.shape
214
+
215
+ # Project to QKV and apply rotary/QK RMS-norm as configured
216
+ qkv = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1)
217
+ q, k, v = qkv.unbind(dim=2)
218
+ if self.use_rope:
219
+ q, k = self.rope(q, k, indices)
220
+ if self.qk_rms_norm:
221
+ q = self.q_rms_norm(q)
222
+ k = self.k_rms_norm(k)
223
+
224
+ # Reshape into pairs: (bp, num_instances, L, H, C)
225
+ qp = rearrange(q, '(bp ni) L h c -> bp ni L h c', ni=num_instances)
226
+ kp = rearrange(k, '(bp ni) L h c -> bp ni L h c', ni=num_instances)
227
+ vp = rearrange(v, '(bp ni) L h c -> bp ni L h c', ni=num_instances)
228
+
229
+ output_list =[]
230
+ ext_k_list = []
231
+ for ins_idx in range(1, num_instances):
232
+ k_j = kp[:, ins_idx] # (bp, L, H, C)
233
+
234
+ if self.use_positional_encoding:
235
+ # pick a source id for this external (share or per-instance)
236
+ # share: src_id = 0 # if you only defined one external source
237
+ src_id = ins_idx - 1
238
+
239
+ if getattr(self, 'enable_ext_k_bias', False):
240
+ bias = torch.tanh(self.k_type_bias[src_id])[None, None, :, :].to(dtype=k_j.dtype, device=k_j.device)
241
+ k_j = k_j + self.ext_k_bias_scale * bias
242
+
243
+ if getattr(self, 'enable_ext_gate', False):
244
+ alpha = torch.sigmoid(self.ext_gate[src_id])[None, None, :, None].to(dtype=k_j.dtype, device=k_j.device)
245
+ k_j = k_j * alpha
246
+
247
+ ext_k_list.append(k_j)
248
+
249
+ k_full = torch.cat([kp[:, 0]] + ext_k_list, dim=1) # (bp, num_instances * L, H, C)
250
+ v_full = torch.cat([vp[:, i] for i in range(num_instances)], dim=1)
251
+ out_inst = scaled_dot_product_attention(qp[:, 0], k_full, v_full)
252
+ output_list.append(out_inst)
253
+
254
+ # num_instance > 1 are separated for scene and instance
255
+ # Scene/canonical attends only to scene KV
256
+ for i in range(1, num_instances):
257
+ self_attn_instance = scaled_dot_product_attention(qp[:, i], kp[:, i], vp[:, i])
258
+ output_list.append(self_attn_instance)
259
+
260
+ # Stitch back to (B, L, H, C) → (B, L, C_all) → linear proj
261
+ h = torch.stack(output_list, dim=1) # (bp, num_instances, L, H, C)
262
+ h = rearrange(h, 'bp ni L h c -> (bp ni) L h c')
263
+ h = h.reshape(B, L, -1)
264
+ h = self.to_out(h)
265
+ return h
266
+
267
+ def self_attn_join_external(self, x: torch.Tensor, external_tokens: Union[torch.Tensor, List[torch.Tensor]], indices: Optional[torch.Tensor] = None) -> torch.Tensor:
268
+ """
269
+ Self-attention where queries come from x, and keys/values are augmented
270
+ with one or more external token sequences. All projections (Q/K/V) use
271
+ this module's own projection weights to keep them in the same space.
272
+
273
+ Args:
274
+ x: (B, Lq, C) queries from the current stream
275
+ external_tokens: either a tensor (B, Lext, C) or a list of tensors
276
+ each of shape (B, Lext_i, C)
277
+ indices: optional rotary indices
278
+ Returns:
279
+ (B, Lq, C) attended output
280
+ """
281
+ assert self._type == "self", "self_attn_join_external is only valid for self-attention"
282
+
283
+ if isinstance(external_tokens, torch.Tensor):
284
+ external_list: List[torch.Tensor] = [external_tokens]
285
+ else:
286
+ external_list = list(external_tokens)
287
+
288
+ B, Lq, C = x.shape
289
+
290
+ # Project Q/K/V for x
291
+ qkv = self.to_qkv(x).reshape(B, Lq, 3, self.num_heads, -1)
292
+ q, k, v = qkv.unbind(dim=2)
293
+
294
+ if self.use_rope:
295
+ q, k = self.rope(q, k, indices)
296
+
297
+ # Optional Q/K RMSNorm
298
+ if self.qk_rms_norm:
299
+ q = self.q_rms_norm(q)
300
+ k = self.k_rms_norm(k)
301
+
302
+ # Project only K/V for external tokens using the SAME to_qkv weights
303
+ k_ext_list: List[torch.Tensor] = []
304
+ v_ext_list: List[torch.Tensor] = []
305
+ for i, ext in enumerate(external_list):
306
+ assert ext.dim() == 3, f"external token must be 3D (B, L, C), got {ext.shape}"
307
+ assert ext.shape[0] == B, f"Batch size mismatch: ext B={ext.shape[0]} vs x B={B}"
308
+ # Do not alter raw external token content; avoid adding source/type embedding to ext tokens
309
+ ext_qkv = self.to_qkv(ext).reshape(ext.shape[0], ext.shape[1], 3, self.num_heads, -1)
310
+ _, k_ext, v_ext = ext_qkv.unbind(dim=2)
311
+ if self.use_rope:
312
+ # apply RoPE to external K; use K as both inputs to get rotated K
313
+ _, k_ext = self.rope(k_ext, k_ext, indices)
314
+ if self.qk_rms_norm:
315
+ k_ext = self.k_rms_norm(k_ext)
316
+
317
+ if self.use_positional_encoding:
318
+ # Optional per-head K type bias (vector) applied after RoPE/RMSNorm
319
+ if getattr(self, 'enable_ext_k_bias', False):
320
+ bias_vec = torch.tanh(self.k_type_bias[i])[None, None, :, :].to(k_ext.dtype)
321
+ k_ext = k_ext + self.ext_k_bias_scale * bias_vec
322
+
323
+ # Optional per-head gate to modulate influence of external keys only (values unchanged)
324
+ if getattr(self, 'enable_ext_gate', False):
325
+ alpha = torch.sigmoid(self.ext_gate[i])[None, None, :, None].to(k_ext.dtype)
326
+ k_ext = k_ext * alpha
327
+
328
+ k_ext_list.append(k_ext)
329
+ v_ext_list.append(v_ext)
330
+
331
+ # Concatenate K/V along sequence dimension
332
+ if len(k_ext_list) > 0:
333
+ k_cat = torch.cat([k] + k_ext_list, dim=1)
334
+ v_cat = torch.cat([v] + v_ext_list, dim=1)
335
+ else:
336
+ k_cat, v_cat = k, v
337
+
338
+ # Attention and output
339
+ h = scaled_dot_product_attention(q, k_cat, v_cat)
340
+ h = h.reshape(B, Lq, -1)
341
+ h = self.to_out(h)
342
+ return h
iscene/trellis/modules/attention_resample.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ try:
9
+ import flash_attn
10
+ except ImportError: # pragma: no cover - flash-attn is optional
11
+ flash_attn = None
12
+
13
+ __all__ = ["AttentionResample"]
14
+
15
+
16
+ class AttentionResample(nn.Module):
17
+ """Resample a variable-length token sequence to a fixed target length."""
18
+
19
+ def __init__(
20
+ self,
21
+ d_model: int = 1024,
22
+ n_target: int = 4096,
23
+ *,
24
+ n_heads: int = 16,
25
+ use_flash: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+
29
+ assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
30
+ self.d_model = d_model
31
+ self.n_target = n_target
32
+ self.n_heads = n_heads
33
+ self.head_dim = d_model // n_heads
34
+ self.scale = self.head_dim ** -0.5
35
+
36
+ self.latent = nn.Parameter(torch.randn(n_target, d_model))
37
+ self.to_kv = nn.Linear(d_model, 2 * d_model, bias=False)
38
+ self._flash_available = use_flash and flash_attn is not None
39
+
40
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
41
+ """Return a tensor with shape (B, n_target, d_model)."""
42
+ batch_size, _, dim = x.shape
43
+ assert dim == self.d_model, f"Expected input dim {self.d_model}, got {dim}"
44
+
45
+ q = self.latent.unsqueeze(0).expand(batch_size, -1, -1)
46
+ k, v = self.to_kv(x).chunk(2, dim=-1)
47
+
48
+ if self._flash_available:
49
+ return self._forward_flash(q, k, v)
50
+ return self._forward_torch(q, k, v)
51
+
52
+ def _forward_torch(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
53
+ batch_size = q.size(0)
54
+ q = q.view(batch_size, self.n_target, self.n_heads, self.head_dim).transpose(1, 2)
55
+ k = k.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
56
+ v = v.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
57
+
58
+ attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
59
+ weights = torch.softmax(attn, dim=-1, dtype=attn.dtype)
60
+ out = torch.matmul(weights, v)
61
+ return out.transpose(1, 2).contiguous().view(batch_size, self.n_target, self.d_model)
62
+
63
+ def _forward_flash(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
64
+ batch_size = q.size(0)
65
+ q = q.view(batch_size, self.n_target, self.n_heads, self.head_dim).contiguous()
66
+ k = k.view(batch_size, -1, self.n_heads, self.head_dim).contiguous()
67
+ v = v.view(batch_size, -1, self.n_heads, self.head_dim).contiguous()
68
+
69
+ assert flash_attn is not None
70
+ out = flash_attn.flash_attn_func(
71
+ q,
72
+ k,
73
+ v, # type: ignore[arg-type]
74
+ causal=False,
75
+ softmax_scale=self.scale,
76
+ )
77
+ return out.reshape(batch_size, self.n_target, self.d_model)
iscene/trellis/modules/norm.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class LayerNorm32(nn.LayerNorm):
6
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
7
+ return super().forward(x.float()).type(x.dtype)
8
+
9
+
10
+ class GroupNorm32(nn.GroupNorm):
11
+ """
12
+ A GroupNorm layer that converts to float32 before the forward pass.
13
+ """
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ return super().forward(x.float()).type(x.dtype)
16
+
17
+
18
+ class ChannelLayerNorm32(LayerNorm32):
19
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
20
+ DIM = x.dim()
21
+ x = x.permute(0, *range(2, DIM), 1).contiguous()
22
+ x = super().forward(x)
23
+ x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
24
+ return x
iscene/trellis/modules/sparse/__init__.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
+ BACKEND = 'spconv'
4
+ DEBUG = False
5
+ ATTN = 'flash_attn'
6
+
7
+ def __from_env():
8
+ import os
9
+
10
+ global BACKEND
11
+ global DEBUG
12
+ global ATTN
13
+
14
+ env_sparse_backend = os.environ.get('SPARSE_BACKEND')
15
+ env_sparse_debug = os.environ.get('SPARSE_DEBUG')
16
+ env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
17
+ if env_sparse_attn is None:
18
+ env_sparse_attn = os.environ.get('ATTN_BACKEND')
19
+
20
+ if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
21
+ BACKEND = env_sparse_backend
22
+ if env_sparse_debug is not None:
23
+ DEBUG = env_sparse_debug == '1'
24
+ if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
25
+ ATTN = env_sparse_attn
26
+
27
+ print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
28
+
29
+
30
+ __from_env()
31
+
32
+
33
+ def set_backend(backend: Literal['spconv', 'torchsparse']):
34
+ global BACKEND
35
+ BACKEND = backend
36
+
37
+ def set_debug(debug: bool):
38
+ global DEBUG
39
+ DEBUG = debug
40
+
41
+ def set_attn(attn: Literal['xformers', 'flash_attn']):
42
+ global ATTN
43
+ ATTN = attn
44
+
45
+
46
+ import importlib
47
+
48
+ __attributes = {
49
+ 'SparseTensor': 'basic',
50
+ 'sparse_batch_broadcast': 'basic',
51
+ 'sparse_batch_op': 'basic',
52
+ 'sparse_cat': 'basic',
53
+ 'sparse_unbind': 'basic',
54
+ 'SparseGroupNorm': 'norm',
55
+ 'SparseLayerNorm': 'norm',
56
+ 'SparseGroupNorm32': 'norm',
57
+ 'SparseLayerNorm32': 'norm',
58
+ 'SparseReLU': 'nonlinearity',
59
+ 'SparseSiLU': 'nonlinearity',
60
+ 'SparseGELU': 'nonlinearity',
61
+ 'SparseActivation': 'nonlinearity',
62
+ 'SparseLinear': 'linear',
63
+ 'sparse_scaled_dot_product_attention': 'attention',
64
+ 'SerializeMode': 'attention',
65
+ 'sparse_serialized_scaled_dot_product_self_attention': 'attention',
66
+ 'sparse_windowed_scaled_dot_product_self_attention': 'attention',
67
+ 'SparseMultiHeadAttention': 'attention',
68
+ 'SparseConv3d': 'conv',
69
+ 'SparseInverseConv3d': 'conv',
70
+ 'SparseDownsample': 'spatial',
71
+ 'SparseUpsample': 'spatial',
72
+ 'SparseSubdivide' : 'spatial'
73
+ }
74
+
75
+ __submodules = ['transformer']
76
+
77
+ __all__ = list(__attributes.keys()) + __submodules
78
+
79
+ def __getattr__(name):
80
+ if name not in globals():
81
+ if name in __attributes:
82
+ module_name = __attributes[name]
83
+ module = importlib.import_module(f".{module_name}", __name__)
84
+ globals()[name] = getattr(module, name)
85
+ elif name in __submodules:
86
+ module = importlib.import_module(f".{name}", __name__)
87
+ globals()[name] = module
88
+ else:
89
+ raise AttributeError(f"module {__name__} has no attribute {name}")
90
+ return globals()[name]
91
+
92
+
93
+ # For Pylance
94
+ if __name__ == '__main__':
95
+ from .basic import *
96
+ from .norm import *
97
+ from .nonlinearity import *
98
+ from .linear import *
99
+ from .attention import *
100
+ from .conv import *
101
+ from .spatial import *
102
+ import transformer
iscene/trellis/modules/sparse/attention/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .full_attn import *
2
+ from .serialized_attn import *
3
+ from .windowed_attn import *
4
+ from .modules import *
iscene/trellis/modules/sparse/attention/full_attn.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ from .. import SparseTensor
4
+ from .. import DEBUG, ATTN
5
+
6
+ if ATTN == 'xformers':
7
+ import xformers.ops as xops
8
+ elif ATTN == 'flash_attn':
9
+ import flash_attn
10
+ else:
11
+ raise ValueError(f"Unknown attention module: {ATTN}")
12
+
13
+
14
+ __all__ = [
15
+ 'sparse_scaled_dot_product_attention',
16
+ ]
17
+
18
+
19
+ @overload
20
+ def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
21
+ """
22
+ Apply scaled dot product attention to a sparse tensor.
23
+
24
+ Args:
25
+ qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
26
+ """
27
+ ...
28
+
29
+ @overload
30
+ def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
31
+ """
32
+ Apply scaled dot product attention to a sparse tensor.
33
+
34
+ Args:
35
+ q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
36
+ kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
37
+ """
38
+ ...
39
+
40
+ @overload
41
+ def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
42
+ """
43
+ Apply scaled dot product attention to a sparse tensor.
44
+
45
+ Args:
46
+ q (SparseTensor): A [N, L, H, C] dense tensor containing Qs.
47
+ kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
48
+ """
49
+ ...
50
+
51
+ @overload
52
+ def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
53
+ """
54
+ Apply scaled dot product attention to a sparse tensor.
55
+
56
+ Args:
57
+ q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
58
+ k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
59
+ v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
60
+
61
+ Note:
62
+ k and v are assumed to have the same coordinate map.
63
+ """
64
+ ...
65
+
66
+ @overload
67
+ def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
68
+ """
69
+ Apply scaled dot product attention to a sparse tensor.
70
+
71
+ Args:
72
+ q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
73
+ k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
74
+ v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
75
+ """
76
+ ...
77
+
78
+ @overload
79
+ def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
80
+ """
81
+ Apply scaled dot product attention to a sparse tensor.
82
+
83
+ Args:
84
+ q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
85
+ k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
86
+ v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
87
+ """
88
+ ...
89
+
90
+ def sparse_scaled_dot_product_attention(*args, **kwargs):
91
+ arg_names_dict = {
92
+ 1: ['qkv'],
93
+ 2: ['q', 'kv'],
94
+ 3: ['q', 'k', 'v']
95
+ }
96
+ num_all_args = len(args) + len(kwargs)
97
+ assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
98
+ for key in arg_names_dict[num_all_args][len(args):]:
99
+ assert key in kwargs, f"Missing argument {key}"
100
+
101
+ if num_all_args == 1:
102
+ qkv = args[0] if len(args) > 0 else kwargs['qkv']
103
+ assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
104
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
105
+ device = qkv.device
106
+
107
+ s = qkv
108
+ q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
109
+ kv_seqlen = q_seqlen
110
+ qkv = qkv.feats # [T, 3, H, C]
111
+
112
+ elif num_all_args == 2:
113
+ q = args[0] if len(args) > 0 else kwargs['q']
114
+ kv = args[1] if len(args) > 1 else kwargs['kv']
115
+ assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
116
+ isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
117
+ f"Invalid types, got {type(q)} and {type(kv)}"
118
+ assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
119
+ device = q.device
120
+
121
+ if isinstance(q, SparseTensor):
122
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
123
+ s = q
124
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
125
+ q = q.feats # [T_Q, H, C]
126
+ else:
127
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
128
+ s = None
129
+ N, L, H, C = q.shape
130
+ q_seqlen = [L] * N
131
+ q = q.reshape(N * L, H, C) # [T_Q, H, C]
132
+
133
+ if isinstance(kv, SparseTensor):
134
+ assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
135
+ kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
136
+ kv = kv.feats # [T_KV, 2, H, C]
137
+ else:
138
+ assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
139
+ N, L, _, H, C = kv.shape
140
+ kv_seqlen = [L] * N
141
+ kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
142
+
143
+ elif num_all_args == 3:
144
+ q = args[0] if len(args) > 0 else kwargs['q']
145
+ k = args[1] if len(args) > 1 else kwargs['k']
146
+ v = args[2] if len(args) > 2 else kwargs['v']
147
+ assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
148
+ isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
149
+ f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
150
+ assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
151
+ device = q.device
152
+
153
+ if isinstance(q, SparseTensor):
154
+ assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
155
+ s = q
156
+ q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
157
+ q = q.feats # [T_Q, H, Ci]
158
+ else:
159
+ assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
160
+ s = None
161
+ N, L, H, CI = q.shape
162
+ q_seqlen = [L] * N
163
+ q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
164
+
165
+ if isinstance(k, SparseTensor):
166
+ assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
167
+ assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
168
+ kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
169
+ k = k.feats # [T_KV, H, Ci]
170
+ v = v.feats # [T_KV, H, Co]
171
+ else:
172
+ assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
173
+ assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
174
+ N, L, H, CI, CO = *k.shape, v.shape[-1]
175
+ kv_seqlen = [L] * N
176
+ k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
177
+ v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
178
+
179
+ if DEBUG:
180
+ if s is not None:
181
+ for i in range(s.shape[0]):
182
+ assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
183
+ if num_all_args in [2, 3]:
184
+ assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
185
+ if num_all_args == 3:
186
+ assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
187
+ assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"
188
+
189
+ if ATTN == 'xformers':
190
+ if num_all_args == 1:
191
+ q, k, v = qkv.unbind(dim=1)
192
+ elif num_all_args == 2:
193
+ k, v = kv.unbind(dim=1)
194
+ q = q.unsqueeze(0)
195
+ k = k.unsqueeze(0)
196
+ v = v.unsqueeze(0)
197
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
198
+ out = xops.memory_efficient_attention(q, k, v, mask)[0]
199
+ elif ATTN == 'flash_attn':
200
+ cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
201
+ if num_all_args in [2, 3]:
202
+ cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
203
+ if num_all_args == 1:
204
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
205
+ elif num_all_args == 2:
206
+ out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
207
+ elif num_all_args == 3:
208
+ out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
209
+ else:
210
+ raise ValueError(f"Unknown attention module: {ATTN}")
211
+
212
+ if s is not None:
213
+ return s.replace(out)
214
+ else:
215
+ return out.reshape(N, L, H, -1)
iscene/trellis/modules/sparse/attention/modules.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .. import SparseTensor
6
+ from .full_attn import sparse_scaled_dot_product_attention
7
+ from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
8
+ from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
9
+ from ...attention import RotaryPositionEmbedder
10
+
11
+
12
+ class SparseMultiHeadRMSNorm(nn.Module):
13
+ def __init__(self, dim: int, heads: int):
14
+ super().__init__()
15
+ self.scale = dim ** 0.5
16
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
17
+
18
+ def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
19
+ x_type = x.dtype
20
+ x = x.float()
21
+ if isinstance(x, SparseTensor):
22
+ x = x.replace(F.normalize(x.feats, dim=-1))
23
+ else:
24
+ x = F.normalize(x, dim=-1)
25
+ return (x * self.gamma * self.scale).to(x_type)
26
+
27
+
28
+ class SparseMultiHeadAttention(nn.Module):
29
+ def __init__(
30
+ self,
31
+ channels: int,
32
+ num_heads: int,
33
+ ctx_channels: Optional[int] = None,
34
+ type: Literal["self", "cross"] = "self",
35
+ attn_mode: Literal["full", "serialized", "windowed"] = "full",
36
+ window_size: Optional[int] = None,
37
+ shift_sequence: Optional[int] = None,
38
+ shift_window: Optional[Tuple[int, int, int]] = None,
39
+ serialize_mode: Optional[SerializeMode] = None,
40
+ qkv_bias: bool = True,
41
+ use_rope: bool = False,
42
+ qk_rms_norm: bool = False,
43
+ ):
44
+ super().__init__()
45
+ assert channels % num_heads == 0
46
+ assert type in ["self", "cross"], f"Invalid attention type: {type}"
47
+ assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
48
+ assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
49
+ assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
50
+ self.channels = channels
51
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
52
+ self.num_heads = num_heads
53
+ self._type = type
54
+ self.attn_mode = attn_mode
55
+ self.window_size = window_size
56
+ self.shift_sequence = shift_sequence
57
+ self.shift_window = shift_window
58
+ self.serialize_mode = serialize_mode
59
+ self.use_rope = use_rope
60
+ self.qk_rms_norm = qk_rms_norm
61
+
62
+ if self._type == "self":
63
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
64
+ else:
65
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
66
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
67
+
68
+ if self.qk_rms_norm:
69
+ self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
70
+ self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
71
+
72
+ self.to_out = nn.Linear(channels, channels)
73
+
74
+ if use_rope:
75
+ self.rope = RotaryPositionEmbedder(channels)
76
+
77
+ @staticmethod
78
+ def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
79
+ if isinstance(x, SparseTensor):
80
+ return x.replace(module(x.feats))
81
+ else:
82
+ return module(x)
83
+
84
+ @staticmethod
85
+ def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
86
+ if isinstance(x, SparseTensor):
87
+ return x.reshape(*shape)
88
+ else:
89
+ return x.reshape(*x.shape[:2], *shape)
90
+
91
+ def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
92
+ if isinstance(x, SparseTensor):
93
+ x_feats = x.feats.unsqueeze(0)
94
+ else:
95
+ x_feats = x
96
+ x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
97
+ return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
98
+
99
+ def _rope(self, qkv: SparseTensor) -> SparseTensor:
100
+ q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
101
+ q, k = self.rope(q, k, qkv.coords[:, 1:])
102
+ qkv = qkv.replace(torch.stack([q, k, v], dim=1))
103
+ return qkv
104
+
105
+ def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
106
+ if self._type == "self":
107
+ qkv = self._linear(self.to_qkv, x)
108
+ qkv = self._fused_pre(qkv, num_fused=3)
109
+ if self.use_rope:
110
+ qkv = self._rope(qkv)
111
+ if self.qk_rms_norm:
112
+ q, k, v = qkv.unbind(dim=1)
113
+ q = self.q_rms_norm(q)
114
+ k = self.k_rms_norm(k)
115
+ qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
116
+ if self.attn_mode == "full":
117
+ h = sparse_scaled_dot_product_attention(qkv)
118
+ elif self.attn_mode == "serialized":
119
+ h = sparse_serialized_scaled_dot_product_self_attention(
120
+ qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
121
+ )
122
+ elif self.attn_mode == "windowed":
123
+ h = sparse_windowed_scaled_dot_product_self_attention(
124
+ qkv, self.window_size, shift_window=self.shift_window
125
+ )
126
+ else:
127
+ q = self._linear(self.to_q, x)
128
+ q = self._reshape_chs(q, (self.num_heads, -1))
129
+ kv = self._linear(self.to_kv, context)
130
+ kv = self._fused_pre(kv, num_fused=2)
131
+ if self.qk_rms_norm:
132
+ q = self.q_rms_norm(q)
133
+ k, v = kv.unbind(dim=1)
134
+ k = self.k_rms_norm(k)
135
+ kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
136
+ h = sparse_scaled_dot_product_attention(q, kv)
137
+ h = self._reshape_chs(h, (-1,))
138
+ h = self._linear(self.to_out, h)
139
+ return h
iscene/trellis/modules/sparse/attention/serialized_attn.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from enum import Enum
3
+ import torch
4
+ import math
5
+ from .. import SparseTensor
6
+ from .. import DEBUG, ATTN
7
+
8
+ if ATTN == 'xformers':
9
+ import xformers.ops as xops
10
+ elif ATTN == 'flash_attn':
11
+ import flash_attn
12
+ else:
13
+ raise ValueError(f"Unknown attention module: {ATTN}")
14
+
15
+
16
+ __all__ = [
17
+ 'sparse_serialized_scaled_dot_product_self_attention',
18
+ ]
19
+
20
+
21
+ class SerializeMode(Enum):
22
+ Z_ORDER = 0
23
+ Z_ORDER_TRANSPOSED = 1
24
+ HILBERT = 2
25
+ HILBERT_TRANSPOSED = 3
26
+
27
+
28
+ SerializeModes = [
29
+ SerializeMode.Z_ORDER,
30
+ SerializeMode.Z_ORDER_TRANSPOSED,
31
+ SerializeMode.HILBERT,
32
+ SerializeMode.HILBERT_TRANSPOSED
33
+ ]
34
+
35
+
36
+ def calc_serialization(
37
+ tensor: SparseTensor,
38
+ window_size: int,
39
+ serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
40
+ shift_sequence: int = 0,
41
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
42
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
43
+ """
44
+ Calculate serialization and partitioning for a set of coordinates.
45
+
46
+ Args:
47
+ tensor (SparseTensor): The input tensor.
48
+ window_size (int): The window size to use.
49
+ serialize_mode (SerializeMode): The serialization mode to use.
50
+ shift_sequence (int): The shift of serialized sequence.
51
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
52
+
53
+ Returns:
54
+ (torch.Tensor, torch.Tensor): Forwards and backwards indices.
55
+ """
56
+ fwd_indices = []
57
+ bwd_indices = []
58
+ seq_lens = []
59
+ seq_batch_indices = []
60
+ offsets = [0]
61
+
62
+ if 'vox2seq' not in globals():
63
+ import vox2seq
64
+
65
+ # Serialize the input
66
+ serialize_coords = tensor.coords[:, 1:].clone()
67
+ serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
68
+ if serialize_mode == SerializeMode.Z_ORDER:
69
+ code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
70
+ elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
71
+ code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
72
+ elif serialize_mode == SerializeMode.HILBERT:
73
+ code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
74
+ elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
75
+ code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
76
+ else:
77
+ raise ValueError(f"Unknown serialize mode: {serialize_mode}")
78
+
79
+ for bi, s in enumerate(tensor.layout):
80
+ num_points = s.stop - s.start
81
+ num_windows = (num_points + window_size - 1) // window_size
82
+ valid_window_size = num_points / num_windows
83
+ to_ordered = torch.argsort(code[s.start:s.stop])
84
+ if num_windows == 1:
85
+ fwd_indices.append(to_ordered)
86
+ bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
87
+ fwd_indices[-1] += s.start
88
+ bwd_indices[-1] += offsets[-1]
89
+ seq_lens.append(num_points)
90
+ seq_batch_indices.append(bi)
91
+ offsets.append(offsets[-1] + seq_lens[-1])
92
+ else:
93
+ # Partition the input
94
+ offset = 0
95
+ mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
96
+ split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
97
+ bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
98
+ for i in range(num_windows):
99
+ mid = mids[i]
100
+ valid_start = split[i]
101
+ valid_end = split[i + 1]
102
+ padded_start = math.floor(mid - 0.5 * window_size)
103
+ padded_end = padded_start + window_size
104
+ fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
105
+ offset += valid_start - padded_start
106
+ bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
107
+ offset += padded_end - valid_start
108
+ fwd_indices[-1] += s.start
109
+ seq_lens.extend([window_size] * num_windows)
110
+ seq_batch_indices.extend([bi] * num_windows)
111
+ bwd_indices.append(bwd_index + offsets[-1])
112
+ offsets.append(offsets[-1] + num_windows * window_size)
113
+
114
+ fwd_indices = torch.cat(fwd_indices)
115
+ bwd_indices = torch.cat(bwd_indices)
116
+
117
+ return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
118
+
119
+
120
+ def sparse_serialized_scaled_dot_product_self_attention(
121
+ qkv: SparseTensor,
122
+ window_size: int,
123
+ serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
124
+ shift_sequence: int = 0,
125
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
126
+ ) -> SparseTensor:
127
+ """
128
+ Apply serialized scaled dot product self attention to a sparse tensor.
129
+
130
+ Args:
131
+ qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
132
+ window_size (int): The window size to use.
133
+ serialize_mode (SerializeMode): The serialization mode to use.
134
+ shift_sequence (int): The shift of serialized sequence.
135
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
136
+ shift (int): The shift to use.
137
+ """
138
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
139
+
140
+ serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
141
+ serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
142
+ if serialization_spatial_cache is None:
143
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
144
+ qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
145
+ else:
146
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
147
+
148
+ M = fwd_indices.shape[0]
149
+ T = qkv.feats.shape[0]
150
+ H = qkv.feats.shape[2]
151
+ C = qkv.feats.shape[3]
152
+
153
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
154
+
155
+ if DEBUG:
156
+ start = 0
157
+ qkv_coords = qkv.coords[fwd_indices]
158
+ for i in range(len(seq_lens)):
159
+ assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
160
+ start += seq_lens[i]
161
+
162
+ if all([seq_len == window_size for seq_len in seq_lens]):
163
+ B = len(seq_lens)
164
+ N = window_size
165
+ qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
166
+ if ATTN == 'xformers':
167
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
168
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
169
+ elif ATTN == 'flash_attn':
170
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
171
+ else:
172
+ raise ValueError(f"Unknown attention module: {ATTN}")
173
+ out = out.reshape(B * N, H, C) # [M, H, C]
174
+ else:
175
+ if ATTN == 'xformers':
176
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
177
+ q = q.unsqueeze(0) # [1, M, H, C]
178
+ k = k.unsqueeze(0) # [1, M, H, C]
179
+ v = v.unsqueeze(0) # [1, M, H, C]
180
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
181
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
182
+ elif ATTN == 'flash_attn':
183
+ cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
184
+ .to(qkv.device).int()
185
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
186
+
187
+ out = out[bwd_indices] # [T, H, C]
188
+
189
+ if DEBUG:
190
+ qkv_coords = qkv_coords[bwd_indices]
191
+ assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
192
+
193
+ return qkv.replace(out)
iscene/trellis/modules/sparse/attention/windowed_attn.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from .. import SparseTensor
5
+ from .. import DEBUG, ATTN
6
+
7
+ if ATTN == 'xformers':
8
+ import xformers.ops as xops
9
+ elif ATTN == 'flash_attn':
10
+ import flash_attn
11
+ else:
12
+ raise ValueError(f"Unknown attention module: {ATTN}")
13
+
14
+
15
+ __all__ = [
16
+ 'sparse_windowed_scaled_dot_product_self_attention',
17
+ ]
18
+
19
+
20
+ def _lexsort_columns(columns: List[torch.Tensor]) -> torch.Tensor:
21
+ if not columns:
22
+ raise ValueError("columns must be non-empty")
23
+ if columns[0].numel() == 0:
24
+ return torch.empty(0, dtype=torch.long, device=columns[0].device)
25
+
26
+ cols64 = [col.to(torch.int64) for col in columns]
27
+ max_vals = [int(col.max().item()) + 1 for col in cols64]
28
+ key = cols64[0]
29
+ for col, max_val in zip(cols64[1:], max_vals[1:]):
30
+ key = key * max_val + col
31
+ return torch.argsort(key)
32
+
33
+
34
+ def calc_window_partition(
35
+ tensor: SparseTensor,
36
+ window_size: Union[int, Tuple[int, ...]],
37
+ shift_window: Union[int, Tuple[int, ...]] = 0
38
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
39
+ """
40
+ Calculate serialization and partitioning for a set of coordinates.
41
+
42
+ Args:
43
+ tensor (SparseTensor): The input tensor.
44
+ window_size (int): The window size to use.
45
+ shift_window (Tuple[int, ...]): The shift of serialized coordinates.
46
+
47
+ Returns:
48
+ (torch.Tensor): Forwards indices.
49
+ (torch.Tensor): Backwards indices.
50
+ (List[int]): Sequence lengths.
51
+ (List[int]): Sequence batch indices.
52
+ """
53
+ DIM = tensor.coords.shape[1] - 1
54
+ shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
55
+ window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
56
+ shifted_coords = tensor.coords.clone().detach()
57
+ shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
58
+ fine_coords = shifted_coords[:, 1:].clone()
59
+
60
+ MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
61
+ NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
62
+ OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
63
+
64
+ shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
65
+ shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
66
+ fwd_indices = _lexsort_columns([shifted_indices, fine_coords[:, 0], fine_coords[:, 1], fine_coords[:, 2]])
67
+ bwd_indices = torch.empty_like(fwd_indices)
68
+ bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
69
+ seq_lens = torch.bincount(shifted_indices)
70
+ seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
71
+ mask = seq_lens != 0
72
+ seq_lens = seq_lens[mask].tolist()
73
+ seq_batch_indices = seq_batch_indices[mask].tolist()
74
+
75
+ return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
76
+
77
+
78
+ def sparse_windowed_scaled_dot_product_self_attention(
79
+ qkv: SparseTensor,
80
+ window_size: int,
81
+ shift_window: Tuple[int, int, int] = (0, 0, 0)
82
+ ) -> SparseTensor:
83
+ """
84
+ Apply windowed scaled dot product self attention to a sparse tensor.
85
+
86
+ Args:
87
+ qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
88
+ window_size (int): The window size to use.
89
+ shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
90
+ shift (int): The shift to use.
91
+ """
92
+ assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
93
+
94
+ serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
95
+ serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
96
+ if serialization_spatial_cache is None:
97
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
98
+ qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
99
+ else:
100
+ fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
101
+
102
+ M = fwd_indices.shape[0]
103
+ T = qkv.feats.shape[0]
104
+ H = qkv.feats.shape[2]
105
+ C = qkv.feats.shape[3]
106
+
107
+ qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
108
+
109
+ if DEBUG:
110
+ start = 0
111
+ qkv_coords = qkv.coords[fwd_indices]
112
+ for i in range(len(seq_lens)):
113
+ seq_coords = qkv_coords[start:start+seq_lens[i]]
114
+ assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
115
+ assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
116
+ f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
117
+ start += seq_lens[i]
118
+
119
+ if all([seq_len == window_size for seq_len in seq_lens]):
120
+ B = len(seq_lens)
121
+ N = window_size
122
+ qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
123
+ if ATTN == 'xformers':
124
+ q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
125
+ out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
126
+ elif ATTN == 'flash_attn':
127
+ out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
128
+ else:
129
+ raise ValueError(f"Unknown attention module: {ATTN}")
130
+ out = out.reshape(B * N, H, C) # [M, H, C]
131
+ else:
132
+ if ATTN == 'xformers':
133
+ q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
134
+ q = q.unsqueeze(0) # [1, M, H, C]
135
+ k = k.unsqueeze(0) # [1, M, H, C]
136
+ v = v.unsqueeze(0) # [1, M, H, C]
137
+ mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
138
+ out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
139
+ elif ATTN == 'flash_attn':
140
+ cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
141
+ .to(qkv.device).int()
142
+ out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
143
+
144
+ out = out[bwd_indices] # [T, H, C]
145
+
146
+ if DEBUG:
147
+ qkv_coords = qkv_coords[bwd_indices]
148
+ assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
149
+
150
+ return qkv.replace(out)