charbel-malo committed on
Commit 58c9024 · verified · 1 Parent(s): 25f2611

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .editorconfig +12 -0
  2. .gitattributes +3 -0
  3. .gitignore +217 -0
  4. Installation.md +170 -0
  5. LICENSE +21 -0
  6. README.md +132 -7
  7. README_jp.md +126 -0
  8. README_zh.md +62 -0
  9. app/__init__.py +0 -0
  10. app/all_models.py +22 -0
  11. app/custom_models/image2mvimage.yaml +63 -0
  12. app/custom_models/image2normal.yaml +61 -0
  13. app/custom_models/mvimg_prediction.py +57 -0
  14. app/custom_models/normal_prediction.py +26 -0
  15. app/custom_models/utils.py +75 -0
  16. app/examples/Groot.png +0 -0
  17. app/examples/aaa.png +0 -0
  18. app/examples/abma.png +0 -0
  19. app/examples/akun.png +0 -0
  20. app/examples/anya.png +0 -0
  21. app/examples/bag.png +3 -0
  22. app/examples/ex1.png +3 -0
  23. app/examples/ex2.png +0 -0
  24. app/examples/ex3.jpg +0 -0
  25. app/examples/ex4.png +0 -0
  26. app/examples/generated_1715761545_frame0.png +0 -0
  27. app/examples/generated_1715762357_frame0.png +0 -0
  28. app/examples/generated_1715763329_frame0.png +0 -0
  29. app/examples/hatsune_miku.png +0 -0
  30. app/examples/princess-large.png +0 -0
  31. app/gradio_3dgen.py +71 -0
  32. app/gradio_3dgen_steps.py +87 -0
  33. app/gradio_local.py +76 -0
  34. app/utils.py +112 -0
  35. assets/teaser.jpg +0 -0
  36. assets/teaser_safe.jpg +3 -0
  37. custum_3d_diffusion/custum_modules/attention_processors.py +385 -0
  38. custum_3d_diffusion/custum_modules/unifield_processor.py +460 -0
  39. custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py +298 -0
  40. custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py +296 -0
  41. custum_3d_diffusion/modules.py +14 -0
  42. custum_3d_diffusion/trainings/__init__.py +0 -0
  43. custum_3d_diffusion/trainings/base.py +208 -0
  44. custum_3d_diffusion/trainings/config_classes.py +35 -0
  45. custum_3d_diffusion/trainings/image2image_trainer.py +86 -0
  46. custum_3d_diffusion/trainings/image2mvimage_trainer.py +139 -0
  47. custum_3d_diffusion/trainings/utils.py +25 -0
  48. docker/Dockerfile +54 -0
  49. docker/README.md +35 -0
  50. gradio_app.py +41 -0
.editorconfig ADDED
@@ -0,0 +1,12 @@
+ root = true
+
+ [*.py]
+ charset = utf-8
+ trim_trailing_whitespace = true
+ end_of_line = lf
+ insert_final_newline = true
+ indent_style = space
+ indent_size = 4
+
+ [*.md]
+ trim_trailing_whitespace = false
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ app/examples/bag.png filter=lfs diff=lfs merge=lfs -text
+ app/examples/ex1.png filter=lfs diff=lfs merge=lfs -text
+ assets/teaser_safe.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,217 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+ # ruff
+ .ruff_cache/
+
+ # LSP config files
+ pyrightconfig.json
+
+ # End of https://www.toptal.com/developers/gitignore/api/python
+
+ .vscode/
+ .threestudio_cache/
+ outputs
+ outputs/
+ outputs-gradio
+ outputs-gradio/
+ lightning_logs/
+
+ # pretrained model weights
+ *.ckpt
+ *.pt
+ *.pth
+ *.bin
+ *.param
+
+ # wandb
+ wandb/
+
+ # obj results
+ *.obj
+ *.glb
+ *.ply
+
+ # ckpts
+ ckpt/*
+ *.pth
+ *.pt
+
+ # tensorrt
+ *.engine
+ *.profile
+
+ # zipfiles
+ *.zip
+ *.tar
+ *.tar.gz
+
+ # others
+ run_30.sh
+ ckpt
Installation.md ADDED
@@ -0,0 +1,170 @@
+ # 官方安装指南
+
+ * 在 requirements-detail.txt 里,我们提供了详细的各个库的版本,这个对应的环境是 `python3.10 + cuda12.2`。
+ * 本项目依赖于几个重要的 pypi 包,这几个包安装起来会有一些困难。
+
+ ### nvdiffrast 安装
+
+ * nvdiffrast 会在第一次运行时,编译对应的 torch 插件,这一步需要 ninja 及 cudatoolkit 的支持。
+ * 因此需要先确保正确安装了 ninja 以及 cudatoolkit,并正确配置了 CUDA_HOME 环境变量。
+ * cudatoolkit 安装可以参考 [linux-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html), [windows-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html)
+ * ninja 则可以直接 `pip install ninja`
+ * 然后设置 CUDA_HOME 变量为 cudatoolkit 的安装目录,如 `/usr/local/cuda`。
+ * 最后 `pip install nvdiffrast` 即可。
+ * 如果无法在目标服务器上安装 cudatoolkit(如权限不够),可以使用我修改的[预编译版本 nvdiffrast](https://github.com/wukailu/nvdiffrast-torch),在另一台拥有 cudatoolkit 且环境相似(python, torch, cuda 版本相同)的服务器上预编译后安装。
+
+ ### onnxruntime-gpu 安装
+
+ * 注意,同时安装 `onnxruntime` 与 `onnxruntime-gpu` 可能导致最终程序无法运行在 GPU 上而运行在 CPU 上,导致极慢的推理速度。
+ * [onnxruntime 官方安装指南](https://onnxruntime.ai/docs/install/#python-installs)
+ * TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`.
+ * 进一步地,可以安装基于 tensorrt 的 onnxruntime,进一步加快推理速度。
+ * 注意:如果没有安装基于 tensorrt 的 onnxruntime,建议将 `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` 中的 `TensorrtExecutionProvider` 删除。
+ * 对于 cuda12.x,可以使用如下命令快速安装带有 tensorrt 的 onnxruntime(注意将 `/root/miniconda3/lib/python3.10/site-packages` 修改为你的 python 对应路径,将 `/root/.bashrc` 改为你的用户目录下的 `.bashrc` 路径)
+ ```bash
+ pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
+ pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
+ pip install tensorrt==8.6.0
+ echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
+ ```
+
+ ### pytorch3d 安装
+
+ * 根据 [pytorch3d 官方的安装建议](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux),建议使用预编译版本
+ ```python
+ import sys
+ import torch
+ pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
+ version_str="".join([
+     f"py3{sys.version_info.minor}_cu",
+     torch.version.cuda.replace(".",""),
+     f"_pyt{pyt_version_str}"
+ ])
+ !pip install fvcore iopath
+ !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
+ ```
+
+ ### torch_scatter 安装
+
+ * 按照 [torch_scatter 官方安装指南](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation),使用预编译的安装包快速安装。
+ * 或者直接编译安装:`pip install git+https://github.com/rusty1s/pytorch_scatter.git`
+
+ ### 其他安装
+
+ * 其他依赖 `pip install -r requirements.txt` 即可。
+
+ -----
+
+ # Detailed Installation Guide
+
+ * In `requirements-detail.txt`, we pin the exact version of every package; these versions correspond to a `python3.10 + cuda12.2` environment.
+ * This project relies on several PyPI packages that can be difficult to install.
+
+ ### Installation of nvdiffrast
+
+ * nvdiffrast compiles its torch plugin the first time it runs; this step requires ninja and cudatoolkit.
+ * Therefore, make sure that ninja and cudatoolkit are correctly installed and that the CUDA_HOME environment variable is properly configured.
+ * For installing cudatoolkit, refer to the [Linux CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) and the [Windows CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html).
+ * ninja can be installed directly with `pip install ninja`.
+ * Then set the CUDA_HOME variable to the installation directory of cudatoolkit, such as `/usr/local/cuda`.
+ * Finally, run `pip install nvdiffrast` (a quick prerequisite check is sketched below).
+ * If you cannot install cudatoolkit on the target machine (e.g., insufficient permissions), you can use my modified [pre-compiled version of nvdiffrast](https://github.com/wukailu/nvdiffrast-torch): pre-compile it on another machine that has cudatoolkit and a similar environment (same python, torch, and cuda versions), then install the resulting `.whl`.
+
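As a quick sanity check before installing, the minimal sketch below (not part of the official guide) verifies the prerequisites mentioned above: a CUDA-enabled torch build, a configured `CUDA_HOME`, and `ninja` on the PATH.

```python
# Minimal prerequisite check for nvdiffrast's JIT build (illustrative only).
import os
import shutil

import torch

print("torch CUDA available:", torch.cuda.is_available())
print("CUDA_HOME:", os.environ.get("CUDA_HOME", "<not set>"))
print("ninja on PATH:", shutil.which("ninja") is not None)
```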
+ ### Installation of onnxruntime-gpu
+
+ * Note that installing both `onnxruntime` and `onnxruntime-gpu` may cause the program to run on the CPU instead of the GPU, leading to extremely slow inference.
+ * [Official ONNX Runtime Installation Guide](https://onnxruntime.ai/docs/install/#python-installs)
+ * TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`.
+ * Furthermore, you can install a TensorRT-based onnxruntime to speed up inference even more.
+ * Note: If you have not correctly installed the TensorRT-based onnxruntime, it is recommended to remove `TensorrtExecutionProvider` from `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` (see the provider check sketched below).
+ * For cuda12.x, you can quickly install onnxruntime with TensorRT using the following commands (change `/root/miniconda3/lib/python3.10/site-packages` to the site-packages path of your python, and `/root/.bashrc` to the `.bashrc` in your user directory):
+ ```bash
+ pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
+ pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
+ pip install tensorrt==8.6.0
+ echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
88
+
89
+ ### Installation of pytorch3d
90
+
91
+ * According to the [official installation recommendations of pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux), it is recommended to use the pre-compiled version:
92
+ ```
93
+ import sys
94
+ import torch
95
+ pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
96
+ version_str="".join([
97
+ f"py3{sys.version_info.minor}_cu",
98
+ torch.version.cuda.replace(".",""),
99
+ f"_pyt{pyt_version_str}"
100
+ ])
101
+ !pip install fvcore iopath
102
+ !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
103
+ ```
104
+
105
+ ### Installation of torch_scatter
106
+
107
+ * Use the pre-compiled installation package according to the [official installation guide of torch_scatter](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation) for a quick installation.
108
+ * Alternatively, you can directly compile and install with `pip install git+https://github.com/rusty1s/pytorch_scatter.git`.
109
+
110
+ ### Other Installations
111
+
112
+ * For other packages, simply `pip install -r requirements.txt`.
113
+
114
+ -----
115
+
116
+ # 官方インストールガイド
117
+
118
+ * `requirements-detail.txt` には、各ライブラリのバージョンが詳細に提供されており、これは Python 3.10 + CUDA 12.2 に対応する環境です。
119
+ * このプロジェクトは、いくつかの重要な PyPI パッケージに依存しており、これらのパッケージのインストールにはいくつかの困難が伴います。
120
+
121
+ ### nvdiffrast のインストール
122
+
123
+ * nvdiffrast は、最初に実行するときに、torch プラグインの対応バージョンをコンパイルします。このステップには、ninja および cudatoolkit のサポートが必要です。
124
+ * したがって、ninja および cudatoolkit の正確なインストールと、CUDA_HOME 環境変数の正確な設定を確保する必要があります。
125
+ * cudatoolkit のインストールについては、[Linux CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)、[Windows CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) を参照してください。
126
+ * ninja は、直接 `pip install ninja` でインストールできます。
127
+ * 次に、CUDA_HOME 変数を cudatoolkit のインストールディレクトリに設定します。例えば、`/usr/local/cuda` のように。
128
+ * 最後に、`pip install nvdiffrast` を実行します。
129
+ * 目標サーバーで cudatoolkit をインストールできない場合(例えば、権限が不足している場合)、私の修正した[事前コンパイル済みバージョンの nvdiffrast](https://github.com/wukailu/nvdiffrast-torch)を使用できます。これは、cudatoolkit があり、環境が似ている(Python、torch、cudaのバージョンが同じ)別のサーバーで事前コンパイルしてからインストールすることができます。
130
+
131
+ ### onnxruntime-gpu のインストール
132
+
133
+ * 注意:`onnxruntime` と `onnxruntime-gpu` を同時にインストールすると、最終的なプログラムが GPU 上で実行されず、CPU 上で実行される可能性があり、推論速度���非常に遅くなることがあります。
134
+ * [onnxruntime 公式インストールガイド](https://onnxruntime.ai/docs/install/#python-installs)
135
+ * TLDR: cuda11.x 用には、`pip install onnxruntime-gpu` を使用します。cuda12.x 用には、`pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/` を使用します。
136
+ * さらに、TensorRT ベースの onnxruntime をインストールして、推論速度をさらに向上させることができます。
137
+ * 注意:TensorRT ベースの onnxruntime がインストールされていない場合は、`https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` の `TensorrtExecutionProvider` を削除することをお勧めします。
138
+ * cuda12.x の場合、次のコマンドを使用して迅速に TensorRT を備えた onnxruntime をインストールできます(`/root/miniconda3/lib/python3.10/site-packages` をあなたの Python に対応するパスに、`/root/.bashrc` をあなたのユーザーのパスの下の `.bashrc` に変更してください)。
139
+ ```bash
140
+ pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
141
+ pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
142
+ pip install tensorrt==8.6.0
143
+ echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
144
+ ```
145
+
146
+ ### pytorch3d のインストール
147
+
148
+ * [pytorch3d 公式のインストール提案](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux)に従い、事前コンパイル済みバージョンを使用することをお勧めします。
149
+ ```python
150
+ import sys
151
+ import torch
152
+ pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
153
+ version_str="".join([
154
+ f"py3{sys.version_info.minor}_cu",
155
+ torch.version.cuda.replace(".",""),
156
+ f"_pyt{pyt_version_str}"
157
+ ])
158
+ !pip install fvcore iopath
159
+ !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
160
+ ```
161
+
162
+ ### torch_scatter のインストール
163
+
164
+ * [torch_scatter 公式インストールガイド](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation)に従い、事前コンパイル済みのインストールパッケージを使用して迅速インストールします。
165
+ * または、直接コンパイルしてインストールする `pip install git+https://github.com/rusty1s/pytorch_scatter.git` も可能です。
166
+
167
+ ### その他のインストール
168
+
169
+ * その他のファイルについては、`pip install -r requirements.txt` を実行するだけです。
170
+
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 AiuniAI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,137 @@
  ---
- title: 3D Genesis
- emoji: 🏆
- colorFrom: indigo
- colorTo: yellow
+ title: 3D-Genesis
+ app_file: gradio_app.py
  sdk: gradio
  sdk_version: 5.5.0
- app_file: app.py
- pinned: false
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ **[中文版本](README_zh.md)**
+
+ **[日本語版](README_jp.md)**
+
+ # Unique3D
+ Official implementation of Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image.
+
+ [Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/)
+
+ ## [Paper](https://arxiv.org/abs/2405.20343) | [Project page](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradio Demo](http://unique3d.demo.avar.cn/) | [Online Demo](https://www.aiuni.ai/)
+
+ * Demo inference speed: Gradio Demo > Huggingface Demo > Huggingface Demo2 > Online Demo
+
+ **If the Gradio Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (to get a registration invitation code, join Discord: https://discord.gg/aiuni). Note that the Online Demo differs slightly from the Gradio Demo: inference is slower, but generation is much more stable.**
+
+ <p align="center">
+ <img src="assets/teaser_safe.jpg">
+ </p>
+
+ High-fidelity and diverse textured meshes generated by Unique3D from single-view wild images in 30 seconds.
+
+ ## More features
+
+ The repo is still under construction; thanks for your patience.
+ - [x] Upload weights.
+ - [x] Local gradio demo.
+ - [x] Detailed tutorial.
+ - [x] Huggingface demo.
+ - [ ] Detailed local demo.
+ - [x] Comfyui support.
+ - [x] Windows support.
+ - [x] Docker support.
+ - [ ] More stable reconstruction with normal.
+ - [ ] Training code release.
+
+ ## Preparation for inference
+
+ * [Detailed linux installation guide](Installation.md).
+
+ ### Linux System Setup.
+
+ Adapted for Ubuntu 22.04.4 LTS and CUDA 12.1.
+ ```bash
+ conda create -n unique3d python=3.11
+ conda activate unique3d
+
+ pip install ninja
+ pip install diffusers==0.27.2
+
+ pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html
+
+ pip install -r requirements.txt
+ ```
+
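As a quick check that the resulting environment really targets CUDA 12.1 (an illustrative sketch, not part of the setup script):

```python
import torch

print(torch.__version__)         # e.g. a "+cu121" build
print(torch.version.cuda)        # expect "12.1"
print(torch.cuda.is_available())
```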
+ [oak-barry](https://github.com/oak-barry) provides another setup script for torch210+cu121 [here](https://github.com/oak-barry/Unique3D).
+
+ ### Windows Setup.
+
+ * Many thanks to `jtydhr88` for the Windows installation method! See [issues/15](https://github.com/AiuniAI/Unique3D/issues/15).
+
+ Following [issues/15](https://github.com/AiuniAI/Unique3D/issues/15), a bat script was implemented to run the commands, so you can:
+ 1. You might still need Visual Studio Build Tools; you can find them at [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools).
+ 2. Create a conda env and activate it:
+    1. `conda create -n unique3d-py311 python=3.11`
+    2. `conda activate unique3d-py311`
+ 3. Download the [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl) for py311, and put it into this project.
+ 4. Run **install_windows_win_py311_cu121.bat**.
+ 5. Answer `y` when asked to uninstall onnxruntime and onnxruntime-gpu.
+ 6. Create the output folder **tmp\gradio** under the drive root, e.g. F:\tmp\gradio.
+ 7. Run `python app/gradio_local.py --port 7860`.
+
+ For more details, refer to [issues/15](https://github.com/AiuniAI/Unique3D/issues/15).
+
+ ### Interactive inference: run your local gradio demo.
+
+ 1. Download the weights from [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) or the [Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/), and extract them to `ckpt/*` (a layout-check sketch follows this list).
+ ```
+ Unique3D
+ ├──ckpt
+     ├── controlnet-tile/
+     ├── image2normal/
+     ├── img2mvimg/
+     ├── realesrgan-x4.onnx
+     └── v1-inference.yaml
+ ```
+
+ 2. Run the interactive inference locally.
+ ```bash
+ python app/gradio_local.py --port 7860
97
+
98
+ ## ComfyUI Support
99
+
100
+ Thanks for the [ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D) implementation from [jtydhr88](https://github.com/jtydhr88)!
101
+
102
+ ## Tips to get better results
103
+
104
+ **Important: Because the mesh is normalized by the longest edge of xyz during training, it is desirable that the input image needs to contain the longest edge of the object during inference, or else you may get erroneously squashed results.**
105
+ 1. Unique3D is sensitive to the facing direction of input images. Due to the distribution of the training data, orthographic front-facing images with a rest pose always lead to good reconstructions.
106
+ 2. Images with occlusions will cause worse reconstructions, since four views cannot cover the complete object. Images with fewer occlusions lead to better results.
107
+ 3. Pass an image with as high a resolution as possible to the input when resolution is a factor.
108
+
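To illustrate the normalization mentioned in the bold note above (a sketch of the general idea, not the repo's exact code): the mesh is rescaled so that its longest bounding-box edge has unit length, which is why that edge should be visible in the input view.

```python
import torch

def normalize_by_longest_edge(verts: torch.Tensor) -> torch.Tensor:
    """Scale an (N, 3) vertex tensor so the longest xyz bounding-box edge is 1."""
    extent = verts.max(dim=0).values - verts.min(dim=0).values
    return verts / extent.max()

# Example: a 2 x 1 x 0.5 bounding box becomes a 1 x 0.5 x 0.25 box.
box = torch.tensor([[0., 0., 0.], [2., 1., 0.5]])
print(normalize_by_longest_edge(box))
```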
+ ## Acknowledgement
+
+ We have intensively borrowed code from the following repositories. Many thanks to the authors for sharing their code.
+ - [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
+ - [Wonder3d](https://github.com/xxlong0/Wonder3D)
+ - [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
+ - [Continuous Remeshing](https://github.com/Profactor/continuous-remeshing)
+ - [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
+
+ ## Collaborations
+ Our mission is to create a 4D generative model with 3D concepts. This is just our first step, and the road ahead is still long, but we are confident. We warmly invite you to join the discussion and explore potential collaborations in any capacity. <span style="color:red">**If you're interested in connecting or partnering with us, please don't hesitate to reach out via email (wkl22@mails.tsinghua.edu.cn)**</span>.
+
+ - Follow us on twitter for the latest updates: https://x.com/aiuni_ai
+ - Join the AIGC 3D/4D generation community on discord: https://discord.gg/aiuni
+ - For research collaboration, please contact: ai@aiuni.ai
+
+ ## Citation
+
+ If you found Unique3D helpful, please cite our report:
+ ```bibtex
+ @misc{wu2024unique3d,
+     title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image},
+     author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma},
+     year={2024},
+     eprint={2405.20343},
+     archivePrefix={arXiv},
+     primaryClass={cs.CV}
+ }
+ ```
README_jp.md ADDED
@@ -0,0 +1,126 @@
+ **他の言語のバージョン [英語](README.md) [中国語](README_zh.md)**
+
+ # Unique3D
+ Unique3D: 単一画像からの高品質かつ効率的な3Dメッシュ生成の公式実装。
+
+ [Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/)
+
+ ## [論文](https://arxiv.org/abs/2405.20343) | [プロジェクトページ](https://wukailu.github.io/Unique3D/) | [Huggingfaceデモ](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradioデモ](http://unique3d.demo.avar.cn/) | [オンラインデモ](https://www.aiuni.ai/)
+
+ * デモ推論速度: Gradioデモ > Huggingfaceデモ > Huggingfaceデモ2 > オンラインデモ
+
+ **Gradioデモが残念ながらハングアップしたり、非常に混雑している場合は、[aiuni.ai](https://www.aiuni.ai/)のオンラインデモを使用できます。これは無料で試すことができます(登録招待コードを取得するには、Discordに参加してください: https://discord.gg/aiuni)。ただし、オンラインデモはGradioデモとは少し異なり、推論速度が遅く、生成結果が安定していない可能性がありますが、素材の品質は良いです。**
+
+ <p align="center">
+ <img src="assets/teaser_safe.jpg">
+ </p>
+
+ Unique3Dは、野生の単一画像から高忠実度および多様なテクスチャメッシュを30秒で生成します。
+
+ ## より多くの機能
+
+ リポジトリはまだ構築中です。ご理解いただきありがとうございます。
+ - [x] 重みのアップロード。
+ - [x] ローカルGradioデモ。
+ - [ ] 詳細なチュートリアル。
+ - [x] Huggingfaceデモ。
+ - [ ] 詳細なローカルデモ。
+ - [x] Comfyuiサポート。
+ - [x] Windowsサポート。
+ - [ ] Dockerサポート。
+ - [ ] ノーマルでより安定した再構築。
+ - [ ] トレーニングコードのリリース。
+
+ ## 推論の準備
+
+ ### Linuxシステムセットアップ
+
+ Ubuntu 22.04.4 LTSおよびCUDA 12.1に適応。
+ ```bash
+ conda create -n unique3d python=3.11
+ conda activate unique3d
+
+ pip install ninja
+ pip install diffusers==0.27.2
+
+ pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html
+
+ pip install -r requirements.txt
+ ```
+
+ [oak-barry](https://github.com/oak-barry)は、[こちら](https://github.com/oak-barry/Unique3D)でtorch210+cu121の別のセットアップスクリプトを提供しています。
+
+ ### Windowsセットアップ
+
+ * `jtydhr88`によるWindowsインストール方法に非常に感謝します![issues/15](https://github.com/AiuniAI/Unique3D/issues/15)を参照してください。
+
+ [issues/15](https://github.com/AiuniAI/Unique3D/issues/15)によると、コマンドを実行するバッチスクリプトを実装したので、以下の手順に従ってください。
+ 1. [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools)からVisual Studio Build Toolsが必要になる場合があります。
+ 2. conda envを作成し、アクティブにします。
+    1. `conda create -n unique3d-py311 python=3.11`
+    2. `conda activate unique3d-py311`
+ 3. [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl)をダウンロードし、このプロジェクトに配置します。
+ 4. **install_windows_win_py311_cu121.bat**を実行します。
+ 5. onnxruntimeおよびonnxruntime-gpuのアンインストールを求められた場合は、yと回答します。
+ 6. ドライブルートの下に**tmp\gradio**フォルダを作成します(例:F:\tmp\gradio)。
+ 7. python app/gradio_local.py --port 7860
+
+ 詳細は[issues/15](https://github.com/AiuniAI/Unique3D/issues/15)を参照してください。
+
+ ### インタラクティブ推論:ローカルGradioデモを実行する
+
+ 1. [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt)または[Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)から重みをダウンロードし、`ckpt/*`に抽出します。
+ ```
+ Unique3D
+ ├──ckpt
+     ├── controlnet-tile/
+     ├── image2normal/
+     ├── img2mvimg/
+     ├── realesrgan-x4.onnx
+     └── v1-inference.yaml
+ ```
+
+ 2. インタラクティブ推論をローカルで実行します。
+ ```bash
+ python app/gradio_local.py --port 7860
+ ```
+
+ ## ComfyUIサポート
+
+ [jtydhr88](https://github.com/jtydhr88)からの[ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D)の実装に感謝します!
+
+ ## より良い結果を得るためのヒント
+
+ 1. Unique3Dは入力画像の向きに敏感です。トレーニングデータの分布により、正面を向いた直交画像は常に良い再構築につながります。
+ 2. 遮蔽のある画像は、4つのビューがオブジェクトを完全にカバーできないため、再構築が悪化します。遮蔽の少ない画像は、より良い結果につながります。
+ 3. 可能な限り高解像度の画像を入力として使用してください。
+
+ ## 謝辞
+
+ 以下のリポジトリからコードを大量に借用しました。コードを共有してくれた著者に感謝します。
+ - [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
+ - [Wonder3d](https://github.com/xxlong0/Wonder3D)
+ - [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
+ - [Continuous Remeshing](https://github.com/Profactor/continuous-remeshing)
+ - [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
+
+ ## コラボレーション
+ 私たちの使命は、3Dの概念を持つ4D生成モデルを作成することです。これは私たちの最初のステップであり、前途はまだ長いですが、私たちは自信を持っています。あらゆる形態の潜在的なコラボレーションを探求し、議論に参加することを心から歓迎します。<span style="color:red">**私たちと連絡を取りたい、またはパートナーシップを結びたい方は、メールでお気軽にお問い合わせください (wkl22@mails.tsinghua.edu.cn)**</span>。
+
+ - 最新情報を入手するには、Twitterをフォローしてください: https://x.com/aiuni_ai
+ - DiscordでAIGC 3D/4D生成コミュニティに参加してください: https://discord.gg/aiuni
+ - 研究協力については、ai@aiuni.aiまでご連絡ください。
+
+ ## 引用
+
+ Unique3Dが役立つと思われる場合は、私たちのレポートを引用してください:
+ ```bibtex
+ @misc{wu2024unique3d,
+     title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image},
+     author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma},
+     year={2024},
+     eprint={2405.20343},
+     archivePrefix={arXiv},
+     primaryClass={cs.CV}
+ }
+ ```
README_zh.md ADDED
@@ -0,0 +1,62 @@
+ **其他语言版本 [English](README.md)**
+
+ # Unique3D
+ High-Quality and Efficient 3D Mesh Generation from a Single Image
+
+ [Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/)
+
+ ## [论文](https://arxiv.org/abs/2405.20343) | [项目页面](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradio Demo](http://unique3d.demo.avar.cn/) | [在线演示](https://www.aiuni.ai/)
+
+
+
+ <p align="center">
+ <img src="assets/teaser_safe.jpg">
+ </p>
+
+ Unique3D从单视图图像生成高保真度和多样化纹理的网格,在4090上大约需要30秒。
+
+ ### 推理准备
+
+ #### Linux系统设置
+ ```bash
+ conda create -n unique3d
+ conda activate unique3d
+ pip install -r requirements.txt
+ ```
+
+ #### 交互式推理:运行您的本地gradio演示
+
+ 1. 从 [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) 下载或者从[清华云盘](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)下载权重,并将其解压到`ckpt/*`。
+ ```
+ Unique3D
+ ├──ckpt
+     ├── controlnet-tile/
+     ├── image2normal/
+     ├── img2mvimg/
+     ├── realesrgan-x4.onnx
+     └── v1-inference.yaml
+ ```
+
+ 2. 在本地运行交互式推理。
+ ```bash
+ python app/gradio_local.py --port 7860
+ ```
+
+ ## 获取更好结果的提示
+
+ 1. Unique3D对输入图像的朝向非常敏感。由于训练数据的分布,**正交正视图像**通常总是能带来良好的重建。对于人物而言,最好是 A-pose 或者 T-pose,因为目前训练数据很少含有其他类型姿态。
+ 2. 有遮挡的图像会导致更差的重建,因为4个视图无法覆盖完整的对象。遮挡较少的图像会带来更好的结果。
+ 3. 尽可能将高分辨率的图像用作输入。
+
+ ## 致谢
+
+ 我们借用了以下代码库的代码。非常感谢作者们分享他们的代码。
+ - [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
+ - [Wonder3d](https://github.com/xxlong0/Wonder3D)
+ - [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
+ - [Continuous Remeshing](https://github.com/Profactor/continuous-remeshing)
+ - [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
+
+ ## 合作
+
+ 我们的使命是创建一个具有3D概念的4D生成模型。这只是我们的第一步,前方的道路仍然很长,但我们有信心。我们热情邀请您加入讨论,并探索任何形式的潜在合作。<span style="color:red">**如果您有兴趣联系或与我们合作,欢迎通过电子邮件(wkl22@mails.tsinghua.edu.cn)与我们联系**</span>。
app/__init__.py ADDED
File without changes
app/all_models.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ from scripts.sd_model_zoo import load_common_sd15_pipe
+ from diffusers import StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
+
+
+ class MyModelZoo:
+     _pipe_disney_controlnet_lineart_ipadapter_i2i: StableDiffusionControlNetImg2ImgPipeline = None
+
+     base_model = "runwayml/stable-diffusion-v1-5"
+
+     def __init__(self, base_model=None) -> None:
+         if base_model is not None:
+             self.base_model = base_model
+
+     @property
+     def pipe_disney_controlnet_tile_ipadapter_i2i(self):
+         return self._pipe_disney_controlnet_lineart_ipadapter_i2i
+
+     def init_models(self):
+         self._pipe_disney_controlnet_lineart_ipadapter_i2i = load_common_sd15_pipe(base_model=self.base_model, ip_adapter=True, plus_model=False, controlnet="./ckpt/controlnet-tile", pipeline_class=StableDiffusionControlNetImg2ImgPipeline)
+
+ model_zoo = MyModelZoo()
app/custom_models/image2mvimage.yaml ADDED
@@ -0,0 +1,63 @@
+ pretrained_model_name_or_path: "./ckpt/img2mvimg"
+ mixed_precision: "bf16"
+
+ init_config:
+   # enable controls
+   enable_cross_attn_lora: False
+   enable_cross_attn_ip: False
+   enable_self_attn_lora: False
+   enable_self_attn_ref: False
+   enable_multiview_attn: True
+
+   # for cross attention
+   init_cross_attn_lora: False
+   init_cross_attn_ip: False
+   cross_attn_lora_rank: 256 # 0 for not enabled
+   cross_attn_lora_only_kv: False
+   ipadapter_pretrained_name: "h94/IP-Adapter"
+   ipadapter_subfolder_name: "models"
+   ipadapter_weight_name: "ip-adapter_sd15.safetensors"
+   ipadapter_effect_on: "all" # all, first
+
+   # for self attention
+   init_self_attn_lora: False
+   self_attn_lora_rank: 256
+   self_attn_lora_only_kv: False
+
+   # for self attention ref
+   init_self_attn_ref: False
+   self_attn_ref_position: "attn1"
+   self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
+   self_attn_ref_pixel_wise_crosspond: False
+   self_attn_ref_effect_on: "all"
+
+   # for multiview attention
+   init_multiview_attn: True
+   multiview_attn_position: "attn1"
+   use_mv_joint_attn: True
+   num_modalities: 1
+
+   # for unet
+   init_unet_path: "${pretrained_model_name_or_path}"
+   cat_condition: True # cat condition to input
+
+   # for cls embedding
+   init_num_cls_label: 8 # for initialize
+   cls_labels: [0, 1, 2, 3] # for current task
+
+ trainers:
+   - trainer_type: "image2mvimage_trainer"
+     trainer:
+       pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
+       attn_config:
+         cls_labels: [0, 1, 2, 3] # for current task
+         enable_cross_attn_lora: False
+         enable_cross_attn_ip: False
+         enable_self_attn_lora: False
+         enable_self_attn_ref: False
+         enable_multiview_attn: True
+       resolution: "256"
+       condition_image_resolution: "256"
+       normal_cls_offset: 4
+       condition_image_column_name: "conditioning_image"
+       image_column_name: "image"
app/custom_models/image2normal.yaml ADDED
@@ -0,0 +1,61 @@
+ pretrained_model_name_or_path: "lambdalabs/sd-image-variations-diffusers"
+ mixed_precision: "bf16"
+
+ init_config:
+   # enable controls
+   enable_cross_attn_lora: False
+   enable_cross_attn_ip: False
+   enable_self_attn_lora: False
+   enable_self_attn_ref: True
+   enable_multiview_attn: False
+
+   # for cross attention
+   init_cross_attn_lora: False
+   init_cross_attn_ip: False
+   cross_attn_lora_rank: 512 # 0 for not enabled
+   cross_attn_lora_only_kv: False
+   ipadapter_pretrained_name: "h94/IP-Adapter"
+   ipadapter_subfolder_name: "models"
+   ipadapter_weight_name: "ip-adapter_sd15.safetensors"
+   ipadapter_effect_on: "all" # all, first
+
+   # for self attention
+   init_self_attn_lora: False
+   self_attn_lora_rank: 512
+   self_attn_lora_only_kv: False
+
+   # for self attention ref
+   init_self_attn_ref: True
+   self_attn_ref_position: "attn1"
+   self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
+   self_attn_ref_pixel_wise_crosspond: True
+   self_attn_ref_effect_on: "all"
+
+   # for multiview attention
+   init_multiview_attn: False
+   multiview_attn_position: "attn1"
+   num_modalities: 1
+
+   # for unet
+   init_unet_path: "${pretrained_model_name_or_path}"
+   init_num_cls_label: 0 # for initialize
+   cls_labels: [] # for current task
+
+ trainers:
+   - trainer_type: "image2image_trainer"
+     trainer:
+       pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
+       attn_config:
+         cls_labels: [] # for current task
+         enable_cross_attn_lora: False
+         enable_cross_attn_ip: False
+         enable_self_attn_lora: False
+         enable_self_attn_ref: True
+         enable_multiview_attn: False
+       resolution: "512"
+       condition_image_resolution: "512"
+       condition_image_column_name: "conditioning_image"
+       image_column_name: "image"
+
+
+
app/custom_models/mvimg_prediction.py ADDED
@@ -0,0 +1,57 @@
+ import sys
+ import torch
+ import gradio as gr
+ from PIL import Image
+ import numpy as np
+ from rembg import remove
+ from app.utils import change_rgba_bg, rgba_to_rgb
+ from app.custom_models.utils import load_pipeline
+ from scripts.all_typing import *
+ from scripts.utils import session, simple_preprocess
+
+ training_config = "app/custom_models/image2mvimage.yaml"
+ checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"
+ trainer, pipeline = load_pipeline(training_config, checkpoint_path)
+ # pipeline.enable_model_cpu_offload()
+
+ def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
+     if isinstance(img_list, Image.Image):
+         img_list = [img_list]
+     img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
+     ret = []
+     for img in img_list:
+         images = trainer.pipeline_forward(
+             pipeline=pipeline,
+             image=img,
+             guidance_scale=guidance_scale,
+             **kwargs
+         ).images
+         ret.extend(images)
+     return ret
+
+
+ def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145):
+     if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.:
+         # still do remove using rembg, since simple_preprocess requires RGBA image
+         print("RGB image not RGBA! still remove bg!")
+         remove_bg = True
+
+     if remove_bg:
+         input_image = remove(input_image, session=session)
+
+     # make front_pil RGBA with white bg
+     input_image = change_rgba_bg(input_image, "white")
+     single_image = simple_preprocess(input_image)
+
+     generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed >= 0 else None
+
+     rgb_pils = predict(
+         single_image,
+         generator=generator,
+         guidance_scale=guidance_scale,
+         width=256,
+         height=256,
+         num_inference_steps=30,
+     )
+
+     return rgb_pils, single_image
app/custom_models/normal_prediction.py ADDED
@@ -0,0 +1,26 @@
+ import sys
+ from PIL import Image
+ from app.utils import rgba_to_rgb, simple_remove
+ from app.custom_models.utils import load_pipeline
+ from scripts.utils import rotate_normals_torch
+ from scripts.all_typing import *
+
+ training_config = "app/custom_models/image2normal.yaml"
+ checkpoint_path = "ckpt/image2normal/unet_state_dict.pth"
+ trainer, pipeline = load_pipeline(training_config, checkpoint_path)
+ # pipeline.enable_model_cpu_offload()
+
+ def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs):
+     img_list = image if isinstance(image, list) else [image]
+     img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
+     images = trainer.pipeline_forward(
+         pipeline=pipeline,
+         image=img_list,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         **kwargs
+     ).images
+     images = simple_remove(images)
+     if do_rotate and len(images) > 1:
+         images = rotate_normals_torch(images, return_types='pil')
+     return images
app/custom_models/utils.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ from typing import List
+ from dataclasses import dataclass
+ from app.utils import rgba_to_rgb
+ from custum_3d_diffusion.trainings.config_classes import ExprimentConfig, TrainerSubConfig
+ from custum_3d_diffusion import modules
+ from custum_3d_diffusion.custum_modules.unifield_processor import AttnConfig, ConfigurableUNet2DConditionModel
+ from custum_3d_diffusion.trainings.base import BasicTrainer
+ from custum_3d_diffusion.trainings.utils import load_config
+
+
+ @dataclass
+ class FakeAccelerator:
+     device: torch.device = torch.device("cuda")
+
+
+ def init_trainers(cfg_path: str, weight_dtype: torch.dtype, extras: dict):
+     accelerator = FakeAccelerator()
+     cfg: ExprimentConfig = load_config(ExprimentConfig, cfg_path, extras)
+     init_config: AttnConfig = load_config(AttnConfig, cfg.init_config)
+     configurable_unet = ConfigurableUNet2DConditionModel(init_config, weight_dtype)
+     configurable_unet.enable_xformers_memory_efficient_attention()
+     trainer_cfgs: List[TrainerSubConfig] = [load_config(TrainerSubConfig, trainer) for trainer in cfg.trainers]
+     trainers: List[BasicTrainer] = [modules.find(trainer.trainer_type)(accelerator, None, configurable_unet, trainer.trainer, weight_dtype, i) for i, trainer in enumerate(trainer_cfgs)]
+     return trainers, configurable_unet
+
+ from app.utils import make_image_grid, split_image
+ def process_image(function, img, guidance_scale=2., merged_image=False, remove_bg=True):
+     from rembg import remove
+     if remove_bg:
+         img = remove(img)
+     img = rgba_to_rgb(img)
+     if merged_image:
+         img = split_image(img, rows=2)
+     images = function(
+         image=img,
+         guidance_scale=guidance_scale,
+     )
+     if len(images) > 1:
+         return make_image_grid(images, rows=2)
+     else:
+         return images[0]
+
+
+ def process_text(trainer, pipeline, img, guidance_scale=2.):
+     pipeline.cfg.validation_prompts = [img]
+     titles, images = trainer.batched_validation_forward(pipeline, guidance_scale=[guidance_scale])
+     return images[0]
+
+
+ def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight_dtype = torch.bfloat16):
+     training_config = config_path
+     load_from_checkpoint = ckpt_path
+     extras = []
+     device = "cuda"
+     trainers, configurable_unet = init_trainers(training_config, weight_dtype, extras)
+     shared_modules = dict()
+     for trainer in trainers:
+         shared_modules = trainer.init_shared_modules(shared_modules)
+
+     if load_from_checkpoint is not None:
+         state_dict = torch.load(load_from_checkpoint)
+         configurable_unet.unet.load_state_dict(state_dict, strict=False)
+     # Move unet, vae and text_encoder to device and cast to weight_dtype
+     configurable_unet.unet.to(device, dtype=weight_dtype)
+
+     pipeline = None
+     trainer_out = None
+     for trainer in trainers:
+         if pipeline_filter(trainer.cfg.trainer_name):
+             pipeline = trainer.construct_pipeline(shared_modules, configurable_unet.unet)
+             pipeline.set_progress_bar_config(disable=False)
+             trainer_out = trainer
+     pipeline = pipeline.to(device)
+     return trainer_out, pipeline
app/examples/Groot.png ADDED
app/examples/aaa.png ADDED
app/examples/abma.png ADDED
app/examples/akun.png ADDED
app/examples/anya.png ADDED
app/examples/bag.png ADDED

Git LFS Details

  • SHA256: ac798ea1f112091c04f5bdfa47c490806fb433a02fe17758aa1f8c55cd64b66e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.54 MB
app/examples/ex1.png ADDED

Git LFS Details

  • SHA256: d49ccccd40fe0317c2886b0d36a11667003d17a49cc49d9244208d250de9fe31
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
app/examples/ex2.png ADDED
app/examples/ex3.jpg ADDED
app/examples/ex4.png ADDED
app/examples/generated_1715761545_frame0.png ADDED
app/examples/generated_1715762357_frame0.png ADDED
app/examples/generated_1715763329_frame0.png ADDED
app/examples/hatsune_miku.png ADDED
app/examples/princess-large.png ADDED
app/gradio_3dgen.py ADDED
@@ -0,0 +1,71 @@
+ import os
+ import gradio as gr
+ from PIL import Image
+ from pytorch3d.structures import Meshes
+ from app.utils import clean_up
+ from app.custom_models.mvimg_prediction import run_mvprediction
+ from app.custom_models.normal_prediction import predict_normals
+ from scripts.refine_lr_to_sr import run_sr_fast
+ from scripts.utils import save_glb_and_video
+ from scripts.multiview_inference import geo_reconstruct
+
+ def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
+     if preview_img is None:
+         raise gr.Error("preview_img is none")
+     if isinstance(preview_img, str):
+         preview_img = Image.open(preview_img)
+
+     if preview_img.size[0] <= 512:
+         preview_img = run_sr_fast([preview_img])[0]
+     rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s
+     new_meshes = geo_reconstruct(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)
+     vertices = new_meshes.verts_packed()
+     vertices = vertices / 2 * 1.35
+     vertices[..., [0, 2]] = - vertices[..., [0, 2]]
+     new_meshes = Meshes(verts=[vertices], faces=new_meshes.faces_list(), textures=new_meshes.textures)
+
+     ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video)
+     return ret_mesh, video
+
+ #######################################
+ def create_ui(concurrency_id="wkl"):
+     with gr.Row():
+         with gr.Column(scale=2):
+             input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
+
+             example_folder = os.path.join(os.path.dirname(__file__), "./examples")
+             example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)])
+             gr.Examples(
+                 examples=example_fns,
+                 inputs=[input_image],
+                 cache_examples=False,
+                 label='Examples (click one of the images below to start)',
+                 examples_per_page=12
+             )
+
+
+         with gr.Column(scale=3):
+             # export mesh display
+             output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320)
+             output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)
+
+             input_processing = gr.Checkbox(
+                 value=True,
+                 label='Remove Background',
+                 visible=True,
+             )
+             do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
+             expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
+             init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
+             setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
+             render_video = gr.Checkbox(value=False, visible=False, label="generate video")
+             fullrunv2_btn = gr.Button('Generate 3D', interactive=True)
+
+             fullrunv2_btn.click(
+                 fn = generate3dv2,
+                 inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type],
+                 outputs=[output_mesh, output_video],
+                 concurrency_id=concurrency_id,
+                 api_name="generate3dv2",
+             ).success(clean_up, api_name=False)
+     return input_image
app/gradio_3dgen_steps.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+
4
+ from app.custom_models.mvimg_prediction import run_mvprediction
5
+ from app.utils import make_image_grid, split_image
6
+ from scripts.utils import save_glb_and_video
7
+
8
+ def concept_to_multiview(preview_img, input_processing, seed, guidance=1.):
9
+ seed = int(seed)
10
+ if preview_img is None:
11
+ raise gr.Error("preview_img is none.")
12
+ if isinstance(preview_img, str):
13
+ preview_img = Image.open(preview_img)
14
+
15
+ rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=seed, guidance_scale=guidance)
16
+ rgb_pil = make_image_grid(rgb_pils, rows=2)
17
+ return rgb_pil, front_pil
18
+
19
+ def concept_to_multiview_ui(concurrency_id="wkl"):
20
+ with gr.Row():
21
+ with gr.Column(scale=2):
22
+ preview_img = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
23
+ input_processing = gr.Checkbox(
24
+ value=True,
25
+ label='Remove Background',
26
+ )
27
+ seed = gr.Slider(minimum=-1, maximum=1000000000, value=-1, step=1.0, label="seed")
28
+ guidance = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, label="Guidance Scale", step=0.5)
29
+ run_btn = gr.Button('Generate Multiview', interactive=True)
30
+ with gr.Column(scale=3):
31
+ # export mesh display
32
+ output_rgb = gr.Image(type='pil', label="RGB", show_label=True)
33
+ output_front = gr.Image(type='pil', image_mode='RGBA', label="Frontview", show_label=True)
34
+ run_btn.click(
35
+ fn = concept_to_multiview,
36
+ inputs=[preview_img, input_processing, seed, guidance],
37
+ outputs=[output_rgb, output_front],
38
+ concurrency_id=concurrency_id,
39
+ api_name=False,
40
+ )
41
+ return output_rgb, output_front
42
+
43
+ from app.custom_models.normal_prediction import predict_normals
44
+ from scripts.multiview_inference import geo_reconstruct
45
+ def multiview_to_mesh_v2(rgb_pil, normal_pil, front_pil, do_refine=False, expansion_weight=0.1, init_type="std"):
46
+ rgb_pils = split_image(rgb_pil, rows=2)
47
+ if normal_pil is not None:
48
+ normal_pil = split_image(normal_pil, rows=2)
49
+ if front_pil is None:
50
+ front_pil = rgb_pils[0]
51
+ new_meshes = geo_reconstruct(rgb_pils, normal_pil, front_pil, do_refine=do_refine, predict_normal=normal_pil is None, expansion_weight=expansion_weight, init_type=init_type)
52
+ ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=False)
53
+ return ret_mesh
54
+
55
+ def new_multiview_to_mesh_ui(concurrency_id="wkl"):
56
+ with gr.Row():
57
+ with gr.Column(scale=2):
58
+ rgb_pil = gr.Image(type='pil', image_mode='RGB', label='RGB')
59
+ front_pil = gr.Image(type='pil', image_mode='RGBA', label='Frontview(Optinal)')
60
+ normal_pil = gr.Image(type='pil', image_mode='RGBA', label='Normal(Optinal)')
61
+ do_refine = gr.Checkbox(
62
+ value=False,
63
+ label='Refine rgb',
64
+ visible=False,
65
+ )
66
+ expansion_weight = gr.Slider(minimum=-1.0, maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
67
+ init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh initialization", value="std", visible=False)
68
+ run_btn = gr.Button('Generate 3D', interactive=True)
69
+ with gr.Column(scale=3):
70
+ # export mesh display
71
+ output_mesh = gr.Model3D(value=None, label="mesh model", show_label=True)
72
+ run_btn.click(
73
+ fn = multiview_to_mesh_v2,
74
+ inputs=[rgb_pil, normal_pil, front_pil, do_refine, expansion_weight, init_type],
75
+ outputs=[output_mesh],
76
+ concurrency_id=concurrency_id,
77
+ api_name="multiview_to_mesh",
78
+ )
79
+ return rgb_pil, front_pil, output_mesh
80
+
81
+
82
+ #######################################
83
+ def create_step_ui(concurrency_id="wkl"):
84
+ with gr.Tab(label="3D:concept_to_multiview"):
85
+ concept_to_multiview_ui(concurrency_id)
86
+ with gr.Tab(label="3D:new_multiview_to_mesh"):
87
+ new_multiview_to_mesh_ui(concurrency_id)
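For reference, these step-by-step tabs can also be mounted on their own, following the same pattern as `app/gradio_local.py` below. This is a minimal sketch assuming the repository root is on `sys.path` and that `model_zoo.init_models()` must run before any prediction, as in the main app.
```python
# Minimal sketch: serve only the step-by-step tabs defined in this file.
import gradio as gr

from app.all_models import model_zoo
from app.gradio_3dgen_steps import create_step_ui

if __name__ == "__main__":
    model_zoo.init_models()                   # load the multiview/normal models once
    with gr.Blocks(title="Unique3D step-by-step") as demo:
        create_step_ui(concurrency_id="wkl")  # concept->multiview and multiview->mesh tabs
    demo.queue(default_concurrency_limit=1).launch()
```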
app/gradio_local.py ADDED
@@ -0,0 +1,76 @@
1
+ if __name__ == "__main__":
2
+ import os
3
+ import sys
4
+ sys.path.append(os.curdir)
5
+ if 'CUDA_VISIBLE_DEVICES' not in os.environ:
6
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
7
+ os.environ['TRANSFORMERS_OFFLINE']='0'
8
+ os.environ['DIFFUSERS_OFFLINE']='0'
9
+ os.environ['HF_HUB_OFFLINE']='0'
10
+ os.environ['GRADIO_ANALYTICS_ENABLED']='False'
11
+ os.environ['HF_ENDPOINT']='https://hf-mirror.com'
12
+ import torch
13
+ torch.set_float32_matmul_precision('medium')
14
+ torch.backends.cuda.matmul.allow_tf32 = True
15
+ torch.set_grad_enabled(False)
16
+
17
+ import gradio as gr
18
+ import argparse
19
+
20
+ from app.gradio_3dgen import create_ui as create_3d_ui
21
+ # from app.gradio_3dgen_steps import create_step_ui
22
+ from app.all_models import model_zoo
23
+
24
+
25
+ _TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
26
+ _DESCRIPTION = '''
27
+ [Project page](https://wukailu.github.io/Unique3D/)
28
+
29
+ * High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
30
+
31
+ **If the Gradio demo is overcrowded or fails to produce stable results, you can use the online demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (to get a registration invitation code, join the Discord: https://discord.gg/aiuni). Note that the online demo differs slightly from the Gradio demo: inference is slower, but generation is much more stable.**
32
+ '''
33
+
34
+ def launch(
35
+ port,
36
+ listen=False,
37
+ share=False,
38
+ gradio_root="",
39
+ ):
40
+ model_zoo.init_models()
41
+
42
+ with gr.Blocks(
43
+ title=_TITLE,
44
+ theme=gr.themes.Monochrome(),
45
+ ) as demo:
46
+ with gr.Row():
47
+ with gr.Column(scale=1):
48
+ gr.Markdown('# ' + _TITLE)
49
+ gr.Markdown(_DESCRIPTION)
50
+ create_3d_ui("wkl")
51
+
52
+ launch_args = {}
53
+ if listen:
54
+ launch_args["server_name"] = "0.0.0.0"
55
+
56
+ demo.queue(default_concurrency_limit=1).launch(
57
+ server_port=None if port == 0 else port,
58
+ share=share,
59
+ root_path=gradio_root if gradio_root != "" else None, # "/myapp"
60
+ **launch_args,
61
+ )
62
+
63
+ if __name__ == "__main__":
64
+ parser = argparse.ArgumentParser()
65
+ args, extra = parser.parse_known_args()
66
+ parser.add_argument("--listen", action="store_true")
67
+ parser.add_argument("--port", type=int, default=0)
68
+ parser.add_argument("--share", action="store_true")
69
+ parser.add_argument("--gradio_root", default="")
70
+ args = parser.parse_args()
71
+ launch(
72
+ args.port,
73
+ listen=args.listen,
74
+ share=args.share,
75
+ gradio_root=args.gradio_root,
76
+ )
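The script only understands the flags parsed above (`--listen`, `--port`, `--share`, `--gradio_root`), so it is normally launched directly. A small sketch of starting it from Python is shown below, with an arbitrary port.
```python
# Sketch: launch the local Gradio app in a subprocess using the flags defined above.
import subprocess
import sys

subprocess.run(
    [sys.executable, "app/gradio_local.py", "--listen", "--port", "7860"],
    check=True,  # surface a non-zero exit code as an exception
)
```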
app/utils.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import gc
8
+ from scripts.refine_lr_to_sr import run_sr_fast
9
+
10
+ GRADIO_CACHE = "/tmp/gradio/"
11
+
12
+ def clean_up():
13
+ torch.cuda.empty_cache()
14
+ gc.collect()
15
+
16
+ def remove_color(arr):
17
+ if arr.shape[-1] == 4:
18
+ arr = arr[..., :3]
19
+ # calc diffs
20
+ base = arr[0, 0]
21
+ diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1)
22
+ alpha = (diffs <= 80)
23
+
24
+ arr[alpha] = 255
25
+ alpha = ~alpha
26
+ arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1)
27
+ return arr
28
+
29
+ def simple_remove(imgs, run_sr=True):
30
+ """Only works for normal"""
31
+ if not isinstance(imgs, list):
32
+ imgs = [imgs]
33
+ single_input = True
34
+ else:
35
+ single_input = False
36
+ if run_sr:
37
+ imgs = run_sr_fast(imgs)
38
+ rets = []
39
+ for img in imgs:
40
+ arr = np.array(img)
41
+ arr = remove_color(arr)
42
+ rets.append(Image.fromarray(arr.astype(np.uint8)))
43
+ if single_input:
44
+ return rets[0]
45
+ return rets
46
+
47
+ def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
48
+ new_image = Image.new("RGBA", rgba.size, bkgd)
49
+ new_image.paste(rgba, (0, 0), rgba)
50
+ new_image = new_image.convert('RGB')
51
+ return new_image
52
+
53
+ def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
54
+ rgb_white = rgba_to_rgb(rgba, bkgd)
55
+ new_rgba = Image.fromarray(np.concatenate([np.array(rgb_white), np.array(rgba)[:, :, 3:4]], axis=-1))
56
+ return new_rgba
57
+
58
+ def split_image(image, rows=None, cols=None):
59
+ """
60
+ inverse function of make_image_grid
61
+ """
62
+ # assume each sub-image is square
63
+ if rows is None and cols is None:
64
+ # image.size [W, H]
65
+ rows = 1
66
+ cols = image.size[0] // image.size[1]
67
+ assert cols * image.size[1] == image.size[0]
68
+ subimg_size = image.size[1]
69
+ elif rows is None:
70
+ subimg_size = image.size[0] // cols
71
+ rows = image.size[1] // subimg_size
72
+ assert rows * subimg_size == image.size[1]
73
+ elif cols is None:
74
+ subimg_size = image.size[1] // rows
75
+ cols = image.size[0] // subimg_size
76
+ assert cols * subimg_size == image.size[0]
77
+ else:
78
+ subimg_size = image.size[1] // rows
79
+ assert cols * subimg_size == image.size[0]
80
+ subimgs = []
81
+ for i in range(rows):
82
+ for j in range(cols):
83
+ subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size))
84
+ subimgs.append(subimg)
85
+ return subimgs
86
+
87
+ def make_image_grid(images, rows=None, cols=None, resize=None):
88
+ if rows is None and cols is None:
89
+ rows = 1
90
+ cols = len(images)
91
+ if rows is None:
92
+ rows = len(images) // cols
93
+ if len(images) % cols != 0:
94
+ rows += 1
95
+ if cols is None:
96
+ cols = len(images) // rows
97
+ if len(images) % rows != 0:
98
+ cols += 1
99
+ total_imgs = rows * cols
100
+ if total_imgs > len(images):
101
+ images += [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))]
102
+
103
+ if resize is not None:
104
+ images = [img.resize((resize, resize)) for img in images]
105
+
106
+ w, h = images[0].size
107
+ grid = Image.new(images[0].mode, size=(cols * w, rows * h))
108
+
109
+ for i, img in enumerate(images):
110
+ grid.paste(img, box=(i % cols * w, i // cols * h))
111
+ return grid
112
+
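Since `split_image` is documented as the inverse of `make_image_grid`, a quick round-trip check with arbitrary 64x64 tiles looks like this:
```python
# Round-trip sketch for make_image_grid / split_image with four dummy tiles.
from PIL import Image

from app.utils import make_image_grid, split_image

tiles = [Image.new("RGB", (64, 64), color) for color in ("red", "green", "blue", "white")]
grid = make_image_grid(tiles, rows=2)   # 2x2 grid, 128x128 pixels
parts = split_image(grid, rows=2)       # back to four 64x64 tiles
assert len(parts) == 4 and parts[0].size == (64, 64)
```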
assets/teaser.jpg ADDED
assets/teaser_safe.jpg ADDED
Git LFS Details
  • SHA256: 5eb9060bc45c1d334f988e8053f1de40cf60df907750dfef89d81cdbe86ffc79
  • Pointer size: 132 Bytes
  • Size of remote file: 2.82 MB
custum_3d_diffusion/custum_modules/attention_processors.py ADDED
@@ -0,0 +1,385 @@
1
+ from typing import Any, Dict, Optional
2
+ import torch
3
+ from diffusers.models.attention_processor import Attention
4
+
5
+ def construct_pix2pix_attention(hidden_states_dim, norm_type="none"):
6
+ if norm_type == "layernorm":
7
+ norm = torch.nn.LayerNorm(hidden_states_dim)
8
+ else:
9
+ norm = torch.nn.Identity()
10
+ attention = Attention(
11
+ query_dim=hidden_states_dim,
12
+ heads=8,
13
+ dim_head=hidden_states_dim // 8,
14
+ bias=True,
15
+ )
16
+ # NOTE: xformers 0.22 does not support batchsize >= 4096
17
+ attention.xformers_not_supported = True # hacky solution
18
+ return norm, attention
19
+
20
+ class ExtraAttnProc(torch.nn.Module):
21
+ def __init__(
22
+ self,
23
+ chained_proc,
24
+ enabled=False,
25
+ name=None,
26
+ mode='extract',
27
+ with_proj_in=False,
28
+ proj_in_dim=768,
29
+ target_dim=None,
30
+ pixel_wise_crosspond=False,
31
+ norm_type="none", # none or layernorm
32
+ crosspond_effect_on="all", # all or first
33
+ crosspond_chain_pos="parralle", # before or parralle or after
34
+ simple_3d=False,
35
+ views=4,
36
+ ) -> None:
37
+ super().__init__()
38
+ self.enabled = enabled
39
+ self.chained_proc = chained_proc
40
+ self.name = name
41
+ self.mode = mode
42
+ self.with_proj_in=with_proj_in
43
+ self.proj_in_dim = proj_in_dim
44
+ self.target_dim = target_dim or proj_in_dim
45
+ self.hidden_states_dim = self.target_dim
46
+ self.pixel_wise_crosspond = pixel_wise_crosspond
47
+ self.crosspond_effect_on = crosspond_effect_on
48
+ self.crosspond_chain_pos = crosspond_chain_pos
49
+ self.views = views
50
+ self.simple_3d = simple_3d
51
+ if self.with_proj_in and self.enabled:
52
+ self.in_linear = torch.nn.Linear(self.proj_in_dim, self.target_dim, bias=False)
53
+ if self.target_dim == self.proj_in_dim:
54
+ self.in_linear.weight.data = torch.eye(proj_in_dim)
55
+ else:
56
+ self.in_linear = None
57
+ if self.pixel_wise_crosspond and self.enabled:
58
+ self.crosspond_norm, self.crosspond_attention = construct_pix2pix_attention(self.hidden_states_dim, norm_type=norm_type)
59
+
60
+ def do_crosspond_attention(self, hidden_states: torch.FloatTensor, other_states: torch.FloatTensor):
61
+ hidden_states = self.crosspond_norm(hidden_states)
62
+
63
+ batch, L, D = hidden_states.shape
64
+ assert hidden_states.shape == other_states.shape, f"got {hidden_states.shape} and {other_states.shape}"
65
+ # to -> batch * L, 1, D
66
+ hidden_states = hidden_states.reshape(batch * L, 1, D)
67
+ other_states = other_states.reshape(batch * L, 1, D)
68
+ hidden_states_catted = other_states
69
+ hidden_states = self.crosspond_attention(
70
+ hidden_states,
71
+ encoder_hidden_states=hidden_states_catted,
72
+ )
73
+ return hidden_states.reshape(batch, L, D)
74
+
75
+ def __call__(
76
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
77
+ ref_dict: dict = None, mode=None, **kwargs
78
+ ) -> Any:
79
+ if not self.enabled:
80
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
81
+ if encoder_hidden_states is None:
82
+ encoder_hidden_states = hidden_states
83
+ assert ref_dict is not None
84
+ if (mode or self.mode) == 'extract':
85
+ ref_dict[self.name] = hidden_states
86
+ hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
87
+ if self.pixel_wise_crosspond and self.crosspond_chain_pos == "after":
88
+ ref_dict[self.name] = hidden_states1
89
+ return hidden_states1
90
+ elif (mode or self.mode) == 'inject':
91
+ ref_state = ref_dict.pop(self.name)
92
+ if self.with_proj_in:
93
+ ref_state = self.in_linear(ref_state)
94
+
95
+ B, L, D = ref_state.shape
96
+ if hidden_states.shape[0] == B:
97
+ modalities = 1
98
+ views = 1
99
+ else:
100
+ modalities = hidden_states.shape[0] // B // self.views
101
+ views = self.views
102
+ if self.pixel_wise_crosspond:
103
+ if self.crosspond_effect_on == "all":
104
+ ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, *ref_state.shape[-2:])
105
+
106
+ if self.crosspond_chain_pos == "before":
107
+ hidden_states = hidden_states + self.do_crosspond_attention(hidden_states, ref_state)
108
+
109
+ hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
110
+
111
+ if self.crosspond_chain_pos == "parralle":
112
+ hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states, ref_state)
113
+
114
+ if self.crosspond_chain_pos == "after":
115
+ hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states1, ref_state)
116
+ return hidden_states1
117
+ else:
118
+ assert self.crosspond_effect_on == "first"
119
+ # hidden_states [B * modalities * views, L, D]
120
+ # ref_state [B, L, D]
121
+ ref_state = ref_state[:, None].expand(-1, modalities, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1]) # [B * modalities, L, D]
122
+
123
+ def do_paritial_crosspond(hidden_states, ref_state):
124
+ first_view_hidden_states = hidden_states.view(-1, views, hidden_states.shape[1], hidden_states.shape[2])[:, 0] # [B * modalities, L, D]
125
+ hidden_states2 = self.do_crosspond_attention(first_view_hidden_states, ref_state) # [B * modalities, L, D]
126
+ hidden_states2_padded = torch.zeros_like(hidden_states).reshape(-1, views, hidden_states.shape[1], hidden_states.shape[2])
127
+ hidden_states2_padded[:, 0] = hidden_states2
128
+ hidden_states2_padded = hidden_states2_padded.reshape(-1, hidden_states.shape[1], hidden_states.shape[2])
129
+ return hidden_states2_padded
130
+
131
+ if self.crosspond_chain_pos == "before":
132
+ hidden_states = hidden_states + do_paritial_crosspond(hidden_states, ref_state)
133
+
134
+ hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) # [B * modalities * views, L, D]
135
+ if self.crosspond_chain_pos == "parralle":
136
+ hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states, ref_state)
137
+ if self.crosspond_chain_pos == "after":
138
+ hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states1, ref_state)
139
+ return hidden_states1
140
+ elif self.simple_3d:
141
+ B, L, C = encoder_hidden_states.shape
142
+ mv = self.views
143
+ encoder_hidden_states = encoder_hidden_states.reshape(B // mv, mv, L, C)
144
+ ref_state = ref_state[:, None]
145
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
146
+ encoder_hidden_states = encoder_hidden_states.reshape(B // mv, 1, (mv+1) * L, C)
147
+ encoder_hidden_states = encoder_hidden_states.repeat(1, mv, 1, 1).reshape(-1, (mv+1) * L, C)
148
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
149
+ else:
150
+ ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1])
151
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
152
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
153
+ else:
154
+ raise NotImplementedError("mode or self.mode is required to be 'extract' or 'inject'")
155
+
156
+ def add_extra_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
157
+ return_dict = torch.nn.ModuleDict()
158
+ proj_in_dim = kwargs.get('proj_in_dim', False)
159
+ kwargs.pop('proj_in_dim', None)
160
+
161
+ def recursive_add_processors(name: str, module: torch.nn.Module):
162
+ for sub_name, child in module.named_children():
163
+ if "ref_unet" not in (sub_name + name):
164
+ recursive_add_processors(f"{name}.{sub_name}", child)
165
+
166
+ if isinstance(module, Attention):
167
+ new_processor = ExtraAttnProc(
168
+ chained_proc=module.get_processor(),
169
+ enabled=enable_filter(f"{name}.processor"),
170
+ name=f"{name}.processor",
171
+ proj_in_dim=proj_in_dim if proj_in_dim else module.cross_attention_dim,
172
+ target_dim=module.cross_attention_dim,
173
+ **kwargs
174
+ )
175
+ module.set_processor(new_processor)
176
+ return_dict[f"{name}.processor".replace(".", "__")] = new_processor
177
+
178
+ for name, module in model.named_children():
179
+ recursive_add_processors(name, module)
180
+ return return_dict
181
+
182
+ def switch_extra_processor(model, enable_filter=lambda x:True):
183
+ def recursive_add_processors(name: str, module: torch.nn.Module):
184
+ for sub_name, child in module.named_children():
185
+ recursive_add_processors(f"{name}.{sub_name}", child)
186
+
187
+ if isinstance(module, ExtraAttnProc):
188
+ module.enabled = enable_filter(name)
189
+
190
+ for name, module in model.named_children():
191
+ recursive_add_processors(name, module)
192
+
193
+ class multiviewAttnProc(torch.nn.Module):
194
+ def __init__(
195
+ self,
196
+ chained_proc,
197
+ enabled=False,
198
+ name=None,
199
+ hidden_states_dim=None,
200
+ chain_pos="parralle", # before or parralle or after
201
+ num_modalities=1,
202
+ views=4,
203
+ base_img_size=64,
204
+ ) -> None:
205
+ super().__init__()
206
+ self.enabled = enabled
207
+ self.chained_proc = chained_proc
208
+ self.name = name
209
+ self.hidden_states_dim = hidden_states_dim
210
+ self.num_modalities = num_modalities
211
+ self.views = views
212
+ self.base_img_size = base_img_size
213
+ self.chain_pos = chain_pos
214
+ self.diff_joint_attn = True
215
+
216
+ def __call__(
217
+ self,
218
+ attn: Attention,
219
+ hidden_states: torch.FloatTensor,
220
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
221
+ attention_mask: Optional[torch.FloatTensor] = None,
222
+ **kwargs
223
+ ) -> torch.Tensor:
224
+ if not self.enabled:
225
+ return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
226
+
227
+ B, L, C = hidden_states.shape
228
+ mv = self.views
229
+ hidden_states = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C)
230
+ hidden_states = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
231
+ return hidden_states.reshape(B // mv, mv, L, C).reshape(-1, L, C)
232
+
233
+ def add_multiview_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
234
+ return_dict = torch.nn.ModuleDict()
235
+ def recursive_add_processors(name: str, module: torch.nn.Module):
236
+ for sub_name, child in module.named_children():
237
+ if "ref_unet" not in (sub_name + name):
238
+ recursive_add_processors(f"{name}.{sub_name}", child)
239
+
240
+ if isinstance(module, Attention):
241
+ new_processor = multiviewAttnProc(
242
+ chained_proc=module.get_processor(),
243
+ enabled=enable_filter(f"{name}.processor"),
244
+ name=f"{name}.processor",
245
+ hidden_states_dim=module.inner_dim,
246
+ **kwargs
247
+ )
248
+ module.set_processor(new_processor)
249
+ return_dict[f"{name}.processor".replace(".", "__")] = new_processor
250
+
251
+ for name, module in model.named_children():
252
+ recursive_add_processors(name, module)
253
+
254
+ return return_dict
255
+
256
+ def switch_multiview_processor(model, enable_filter=lambda x:True):
257
+ def recursive_add_processors(name: str, module: torch.nn.Module):
258
+ for sub_name, child in module.named_children():
259
+ recursive_add_processors(f"{name}.{sub_name}", child)
260
+
261
+ if isinstance(module, Attention):
262
+ processor = module.get_processor()
263
+ if isinstance(processor, multiviewAttnProc):
264
+ processor.enabled = enable_filter(f"{name}.processor")
265
+
266
+ for name, module in model.named_children():
267
+ recursive_add_processors(name, module)
268
+
269
+ class NNModuleWrapper(torch.nn.Module):
270
+ def __init__(self, module):
271
+ super().__init__()
272
+ self.module = module
273
+
274
+ def forward(self, *args, **kwargs):
275
+ return self.module(*args, **kwargs)
276
+
277
+ def __getattr__(self, name: str):
278
+ try:
279
+ return super().__getattr__(name)
280
+ except AttributeError:
281
+ return getattr(self.module, name)
282
+
283
+ class AttnProcessorSwitch(torch.nn.Module):
284
+ def __init__(
285
+ self,
286
+ proc_dict: dict,
287
+ enabled_proc="default",
288
+ name=None,
289
+ switch_name="default_switch",
290
+ ):
291
+ super().__init__()
292
+ self.proc_dict = torch.nn.ModuleDict({k: (v if isinstance(v, torch.nn.Module) else NNModuleWrapper(v)) for k, v in proc_dict.items()})
293
+ self.enabled_proc = enabled_proc
294
+ self.name = name
295
+ self.switch_name = switch_name
296
+ self.choose_module(enabled_proc)
297
+
298
+ def choose_module(self, enabled_proc):
299
+ self.enabled_proc = enabled_proc
300
+ assert enabled_proc in self.proc_dict.keys()
301
+
302
+ def __call__(
303
+ self,
304
+ *args,
305
+ **kwargs
306
+ ) -> torch.FloatTensor:
307
+ used_proc = self.proc_dict[self.enabled_proc]
308
+ return used_proc(*args, **kwargs)
309
+
310
+ def add_switch(model: torch.nn.Module, module_filter=lambda x:True, switch_dict_fn=lambda x: {"default": x}, switch_name="default_switch", enabled_proc="default"):
311
+ return_dict = torch.nn.ModuleDict()
312
+ def recursive_add_processors(name: str, module: torch.nn.Module):
313
+ for sub_name, child in module.named_children():
314
+ if "ref_unet" not in (sub_name + name):
315
+ recursive_add_processors(f"{name}.{sub_name}", child)
316
+
317
+ if isinstance(module, Attention):
318
+ processor = module.get_processor()
319
+ if module_filter(processor):
320
+ proc_dict = switch_dict_fn(processor)
321
+ new_processor = AttnProcessorSwitch(
322
+ proc_dict=proc_dict,
323
+ enabled_proc=enabled_proc,
324
+ name=f"{name}.processor",
325
+ switch_name=switch_name,
326
+ )
327
+ module.set_processor(new_processor)
328
+ return_dict[f"{name}.processor".replace(".", "__")] = new_processor
329
+
330
+ for name, module in model.named_children():
331
+ recursive_add_processors(name, module)
332
+
333
+ return return_dict
334
+
335
+ def change_switch(model: torch.nn.Module, switch_name="default_switch", enabled_proc="default"):
336
+ def recursive_change_processors(name: str, module: torch.nn.Module):
337
+ for sub_name, child in module.named_children():
338
+ recursive_change_processors(f"{name}.{sub_name}", child)
339
+
340
+ if isinstance(module, Attention):
341
+ processor = module.get_processor()
342
+ if isinstance(processor, AttnProcessorSwitch) and processor.switch_name == switch_name:
343
+ processor.choose_module(enabled_proc)
344
+
345
+ for name, module in model.named_children():
346
+ recursive_change_processors(name, module)
347
+
348
+ ########## Hack: Attention fix #############
349
+ from diffusers.models.attention import Attention
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states: torch.FloatTensor,
354
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
355
+ attention_mask: Optional[torch.FloatTensor] = None,
356
+ **cross_attention_kwargs,
357
+ ) -> torch.Tensor:
358
+ r"""
359
+ The forward method of the `Attention` class.
360
+
361
+ Args:
362
+ hidden_states (`torch.Tensor`):
363
+ The hidden states of the query.
364
+ encoder_hidden_states (`torch.Tensor`, *optional*):
365
+ The hidden states of the encoder.
366
+ attention_mask (`torch.Tensor`, *optional*):
367
+ The attention mask to use. If `None`, no mask is applied.
368
+ **cross_attention_kwargs:
369
+ Additional keyword arguments to pass along to the cross attention.
370
+
371
+ Returns:
372
+ `torch.Tensor`: The output of the attention layer.
373
+ """
374
+ # The `Attention` class can call different attention processors / attention functions
375
+ # here we simply pass along all tensors to the selected processor class
376
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
377
+ return self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ Attention.forward = forward
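To make the extract/inject flow of `ExtraAttnProc` (and the patched `Attention.forward` above, which forwards `ref_dict`/`mode` to the processor) concrete, here is a toy sketch with made-up dimensions; it is illustrative only and not part of the training or inference code.
```python
# Toy sketch of reference-attention extract/inject with ExtraAttnProc.
# Dimensions (dim=64, seq_len=16, views=4) are arbitrary; importing the module
# also applies the Attention.forward patch defined at the end of this file.
import torch
from diffusers.models.attention_processor import Attention, AttnProcessor

from custum_3d_diffusion.custum_modules.attention_processors import ExtraAttnProc

dim, seq_len, views = 64, 16, 4
attn = Attention(query_dim=dim, heads=4, dim_head=dim // 4)
attn.set_processor(ExtraAttnProc(chained_proc=AttnProcessor(), enabled=True,
                                 name="demo", proj_in_dim=dim, target_dim=dim, views=views))

ref_dict = {}
cond = torch.randn(2, seq_len, dim)                  # reference branch, batch of 2
_ = attn(cond, ref_dict=ref_dict, mode="extract")    # stash hidden states under "demo"

mv = torch.randn(2 * views, seq_len, dim)            # multiview branch, 4 views per sample
out = attn(mv, ref_dict=ref_dict, mode="inject")     # reference states appended to keys/values
print(out.shape)                                     # torch.Size([8, 16, 64])
```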
custum_3d_diffusion/custum_modules/unifield_processor.py ADDED
@@ -0,0 +1,460 @@
1
+ from types import FunctionType
2
+ from typing import Any, Dict, List
3
+ from diffusers import UNet2DConditionModel
4
+ import torch
5
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, ImageProjection
6
+ from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
7
+ from dataclasses import dataclass, field
8
+ from diffusers.loaders import IPAdapterMixin
9
+ from custum_3d_diffusion.custum_modules.attention_processors import add_extra_processor, switch_extra_processor, add_multiview_processor, switch_multiview_processor, add_switch, change_switch
10
+
11
+ @dataclass
12
+ class AttnConfig:
13
+ """
14
+ * CrossAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), IPAdapter module (achieves conceptual control).
15
+ * SelfAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), Reference Attention module (achieves pixel-level control).
16
+ * Multiview Attention module: Multiview Attention module (achieves multi-view consistency).
17
+ * Cross Modality Attention module: Cross Modality Attention module (achieves multi-modality consistency).
18
+
19
+ For setups:
20
+ train_xxx_lr is implemented in the U-Net architecture.
21
+ enable_xxx_lora is implemented in the U-Net architecture.
22
+ enable_xxx_ip is implemented in the processor and U-Net architecture.
23
+ enable_xxx_ref_proj_in is implemented in the processor.
24
+ """
25
+ latent_size: int = 64
26
+
27
+ train_lr: float = 0
28
+ # for cross attention
29
+ # 0 learning rate for not training
30
+ train_cross_attn_lr: float = 0
31
+ train_cross_attn_lora_lr: float = 0
32
+ train_cross_attn_ip_lr: float = 0 # 0 for not trained
33
+ init_cross_attn_lora: bool = False
34
+ enable_cross_attn_lora: bool = False
35
+ init_cross_attn_ip: bool = False
36
+ enable_cross_attn_ip: bool = False
37
+ cross_attn_lora_rank: int = 64 # 0 for not enabled
38
+ cross_attn_lora_only_kv: bool = False
39
+ ipadapter_pretrained_name: str = "h94/IP-Adapter"
40
+ ipadapter_subfolder_name: str = "models"
41
+ ipadapter_weight_name: str = "ip-adapter-plus_sd15.safetensors"
42
+ ipadapter_effect_on: str = "all" # all, first
43
+
44
+ # for self attention
45
+ train_self_attn_lr: float = 0
46
+ train_self_attn_lora_lr: float = 0
47
+ init_self_attn_lora: bool = False
48
+ enable_self_attn_lora: bool = False
49
+ self_attn_lora_rank: int = 64
50
+ self_attn_lora_only_kv: bool = False
51
+
52
+ train_self_attn_ref_lr: float = 0
53
+ train_ref_unet_lr: float = 0
54
+ init_self_attn_ref: bool = False
55
+ enable_self_attn_ref: bool = False
56
+ self_attn_ref_other_model_name: str = ""
57
+ self_attn_ref_position: str = "attn1"
58
+ self_attn_ref_pixel_wise_crosspond: bool = False # enable pixel_wise_crosspond in refattn
59
+ self_attn_ref_chain_pos: str = "parralle" # before or parralle or after
60
+ self_attn_ref_effect_on: str = "all" # all or first, for _crosspond attn
61
+ self_attn_ref_zero_init: bool = True
62
+ use_simple3d_attn: bool = False
63
+
64
+ # for multiview attention
65
+ init_multiview_attn: bool = False
66
+ enable_multiview_attn: bool = False
67
+ multiview_attn_position: str = "attn1"
68
+ multiview_chain_pose: str = "parralle" # before or parralle or after
69
+ num_modalities: int = 1
70
+ use_mv_joint_attn: bool = False
71
+
72
+ # for unet
73
+ init_unet_path: str = "runwayml/stable-diffusion-v1-5"
74
+ init_num_cls_label: int = 0 # for initialize
75
+ cls_labels: List[int] = field(default_factory=lambda: [])
76
+ cls_label_type: str = "embedding"
77
+ cat_condition: bool = False # cat condition to input
78
+
79
+ class Configurable:
80
+ attn_config: AttnConfig
81
+
82
+ def set_config(self, attn_config: AttnConfig):
83
+ raise NotImplementedError()
84
+
85
+ def update_config(self, attn_config: AttnConfig):
86
+ self.attn_config = attn_config
87
+
88
+ def do_set_config(self, attn_config: AttnConfig):
89
+ self.set_config(attn_config)
90
+ for name, module in self.named_modules():
91
+ if isinstance(module, Configurable):
92
+ if hasattr(module, "do_set_config"):
93
+ module.do_set_config(attn_config)
94
+ else:
95
+ print(f"Warning: {name} has no attribute do_set_config, but is an instance of Configurable")
96
+ module.attn_config = attn_config
97
+
98
+ def do_update_config(self, attn_config: AttnConfig):
99
+ self.update_config(attn_config)
100
+ for name, module in self.named_modules():
101
+ if isinstance(module, Configurable):
102
+ if hasattr(module, "do_update_config"):
103
+ module.do_update_config(attn_config)
104
+ else:
105
+ print(f"Warning: {name} has no attribute do_update_config, but is an instance of Configurable")
106
+ module.attn_config = attn_config
107
+
108
+ from diffusers import ModelMixin # Must import ModelMixin for CompiledUNet
109
+ class UnifieldWrappedUNet(UNet2DConditionModel):
110
+ forward_hook: FunctionType
111
+
112
+ def forward(self, *args, **kwargs):
113
+ if hasattr(self, 'forward_hook'):
114
+ return self.forward_hook(super().forward, *args, **kwargs)
115
+ return super().forward(*args, **kwargs)
116
+
117
+
118
+ class ConfigurableUNet2DConditionModel(Configurable, IPAdapterMixin):
119
+ unet: UNet2DConditionModel
120
+
121
+ cls_embedding_param_dict = {}
122
+ cross_attn_lora_param_dict = {}
123
+ self_attn_lora_param_dict = {}
124
+ cross_attn_param_dict = {}
125
+ self_attn_param_dict = {}
126
+ ipadapter_param_dict = {}
127
+ ref_attn_param_dict = {}
128
+ ref_unet_param_dict = {}
129
+ multiview_attn_param_dict = {}
130
+ other_param_dict = {}
131
+
132
+ rev_param_name_mapping = {}
133
+
134
+ class_labels = []
135
+ def set_class_labels(self, class_labels: torch.Tensor):
136
+ if self.attn_config.init_num_cls_label != 0:
137
+ self.class_labels = class_labels.to(self.unet.device).long()
138
+
139
+ def __init__(self, init_config: AttnConfig, weight_dtype) -> None:
140
+ super().__init__()
141
+ self.weight_dtype = weight_dtype
142
+ self.set_config(init_config)
143
+
144
+ def enable_xformers_memory_efficient_attention(self):
145
+ self.unet.enable_xformers_memory_efficient_attention
146
+ def recursive_add_processors(name: str, module: torch.nn.Module):
147
+ for sub_name, child in module.named_children():
148
+ recursive_add_processors(f"{name}.{sub_name}", child)
149
+
150
+ if isinstance(module, Attention):
151
+ if hasattr(module, 'xformers_not_supported'):
152
+ return
153
+ old_processor = module.get_processor()
154
+ if isinstance(old_processor, (AttnProcessor, AttnProcessor2_0)):
155
+ module.set_use_memory_efficient_attention_xformers(True)
156
+
157
+ for name, module in self.unet.named_children():
158
+ recursive_add_processors(name, module)
159
+
160
+ def __getattr__(self, name: str) -> Any:
161
+ try:
162
+ return super().__getattr__(name)
163
+ except AttributeError:
164
+ return getattr(self.unet, name)
165
+
166
+ # --- for IPAdapterMixin
167
+
168
+ def register_modules(self, **kwargs):
169
+ for name, module in kwargs.items():
170
+ # set models
171
+ setattr(self, name, module)
172
+
173
+ def register_to_config(self, **kwargs):
174
+ pass
175
+
176
+ def unload_ip_adapter(self):
177
+ raise NotImplementedError()
178
+
179
+ # --- for Configurable
180
+
181
+ def get_refunet(self):
182
+ if self.attn_config.self_attn_ref_other_model_name == "self":
183
+ return self.unet
184
+ else:
185
+ return self.unet.ref_unet
186
+
187
+ def set_config(self, attn_config: AttnConfig):
188
+ self.attn_config = attn_config
189
+
190
+ unet_type = UnifieldWrappedUNet
191
+ # class_embed_type = "projection" for 'camera'
192
+ # class_embed_type = None for 'embedding'
193
+ unet_kwargs = {}
194
+ if attn_config.init_num_cls_label > 0:
195
+ if attn_config.cls_label_type == "embedding":
196
+ unet_kwargs = {
197
+ "num_class_embeds": attn_config.init_num_cls_label,
198
+ "device_map": None,
199
+ "low_cpu_mem_usage": False,
200
+ "class_embed_type": None,
201
+ }
202
+ else:
203
+ raise ValueError(f"cls_label_type {attn_config.cls_label_type} is not supported")
204
+
205
+ self.unet: UnifieldWrappedUNet = unet_type.from_pretrained(
206
+ attn_config.init_unet_path, subfolder="unet", torch_dtype=self.weight_dtype,
207
+ ignore_mismatched_sizes=True, # Added this line
208
+ **unet_kwargs
209
+ )
210
+ assert isinstance(self.unet, UnifieldWrappedUNet)
211
+ self.unet.forward_hook = self.unet_forward_hook
212
+
213
+ if self.attn_config.cat_condition:
214
+ # double in_channels
215
+ if self.unet.config.in_channels != 8:
216
+ self.unet.register_to_config(in_channels=self.unet.config.in_channels * 2)
217
+ # widen unet.conv_in: keep the original weights and zero-initialize the extra input channels
218
+ doubled_conv_in = torch.nn.Conv2d(self.unet.conv_in.in_channels * 2, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
219
+ doubled_conv_in.weight.data = torch.cat([self.unet.conv_in.weight.data, torch.zeros_like(self.unet.conv_in.weight.data)], dim=1)
220
+ doubled_conv_in.bias.data = self.unet.conv_in.bias.data
221
+ self.unet.conv_in = doubled_conv_in
222
+
223
+ used_param_ids = set()
224
+
225
+ if attn_config.init_cross_attn_lora:
226
+ # setup lora
227
+ from peft import LoraConfig
228
+ from peft.utils import get_peft_model_state_dict
229
+ if attn_config.cross_attn_lora_only_kv:
230
+ target_modules=["attn2.to_k", "attn2.to_v"]
231
+ else:
232
+ target_modules=["attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0"]
233
+ lora_config: LoraConfig = LoraConfig(
234
+ r=attn_config.cross_attn_lora_rank,
235
+ lora_alpha=attn_config.cross_attn_lora_rank,
236
+ init_lora_weights="gaussian",
237
+ target_modules=target_modules,
238
+ )
239
+ adapter_name="cross_attn_lora"
240
+ self.unet.add_adapter(lora_config, adapter_name=adapter_name)
241
+ # update cross_attn_lora_param_dict
242
+ self.cross_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
243
+ used_param_ids.update(self.cross_attn_lora_param_dict.keys())
244
+
245
+ if attn_config.init_self_attn_lora:
246
+ # setup lora
247
+ from peft import LoraConfig
248
+ if attn_config.self_attn_lora_only_kv:
249
+ target_modules=["attn1.to_k", "attn1.to_v"]
250
+ else:
251
+ target_modules=["attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0"]
252
+ lora_config: LoraConfig = LoraConfig(
253
+ r=attn_config.self_attn_lora_rank,
254
+ lora_alpha=attn_config.self_attn_lora_rank,
255
+ init_lora_weights="gaussian",
256
+ target_modules=target_modules,
257
+ )
258
+ adapter_name="self_attn_lora"
259
+ self.unet.add_adapter(lora_config, adapter_name=adapter_name)
260
+ # update cross_self_lora_param_dict
261
+ self.self_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
262
+ used_param_ids.update(self.self_attn_lora_param_dict.keys())
263
+
264
+ if attn_config.init_num_cls_label != 0:
265
+ self.cls_embedding_param_dict = {id(param): param for param in self.unet.class_embedding.parameters()}
266
+ used_param_ids.update(self.cls_embedding_param_dict.keys())
267
+ self.set_class_labels(torch.tensor(attn_config.cls_labels).long())
268
+
269
+ if attn_config.init_cross_attn_ip:
270
+ self.image_encoder = None
271
+ # setup ipadapter
272
+ self.load_ip_adapter(
273
+ attn_config.ipadapter_pretrained_name,
274
+ subfolder=attn_config.ipadapter_subfolder_name,
275
+ weight_name=attn_config.ipadapter_weight_name
276
+ )
277
+ # wrap the IP-Adapter attention processors with a switch
278
+ from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0
279
+ add_switch(self.unet, module_filter=lambda x: isinstance(x, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)), switch_dict_fn=lambda x: {"ipadapter": x, "default": XFormersAttnProcessor()}, switch_name="ipadapter_switch", enabled_proc="ipadapter")
280
+ # update ipadapter_param_dict
281
+ # weights are in attention processors and unet.encoder_hid_proj
282
+ self.ipadapter_param_dict = {id(param): param for param in self.unet.encoder_hid_proj.parameters() if id(param) not in used_param_ids}
283
+ used_param_ids.update(self.ipadapter_param_dict.keys())
284
+ print("DEBUG: ipadapter_param_dict len in encoder_hid_proj", len(self.ipadapter_param_dict))
285
+ for name, processor in self.unet.attn_processors.items():
286
+ if hasattr(processor, "to_k_ip"):
287
+ self.ipadapter_param_dict.update({id(param): param for param in processor.parameters()})
288
+ print(f"DEBUG: ipadapter_param_dict len in all", len(self.ipadapter_param_dict))
289
+
290
+ ref_unet = None
291
+ if attn_config.init_self_attn_ref:
292
+ # setup reference attention processor
293
+ if attn_config.self_attn_ref_other_model_name == "self":
294
+ raise NotImplementedError("self reference is not fully implemented")
295
+ else:
296
+ ref_unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
297
+ attn_config.self_attn_ref_other_model_name, subfolder="unet", torch_dtype=self.unet.dtype
298
+ )
299
+ ref_unet.to(self.unet.device)
300
+ if self.attn_config.train_ref_unet_lr == 0:
301
+ ref_unet.eval()
302
+ ref_unet.requires_grad_(False)
303
+ else:
304
+ ref_unet.train()
305
+
306
+ add_extra_processor(
307
+ model=ref_unet,
308
+ enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
309
+ mode='extract',
310
+ with_proj_in=False,
311
+ pixel_wise_crosspond=False,
312
+ )
313
+ # NOTE: this requires that cross_attention_dim of the self-attention be the same in both UNets
314
+ processor_dict = add_extra_processor(
315
+ model=self.unet,
316
+ enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
317
+ mode='inject',
318
+ with_proj_in=False,
319
+ pixel_wise_crosspond=attn_config.self_attn_ref_pixel_wise_crosspond,
320
+ crosspond_effect_on=attn_config.self_attn_ref_effect_on,
321
+ crosspond_chain_pos=attn_config.self_attn_ref_chain_pos,
322
+ simple_3d=attn_config.use_simple3d_attn,
323
+ )
324
+ self.ref_unet_param_dict = {id(param): param for name, param in ref_unet.named_parameters() if id(param) not in used_param_ids and (attn_config.self_attn_ref_position in name)}
325
+ if attn_config.self_attn_ref_chain_pos != "after":
326
+ # drop untrainable parameters
327
+ for name, param in ref_unet.named_parameters():
328
+ if id(param) in self.ref_unet_param_dict and ('up_blocks.3.attentions.2.transformer_blocks.0.' in name):
329
+ self.ref_unet_param_dict.pop(id(param))
330
+ used_param_ids.update(self.ref_unet_param_dict.keys())
331
+ # update ref_attn_param_dict
332
+ self.ref_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
333
+ used_param_ids.update(self.ref_attn_param_dict.keys())
334
+
335
+ if attn_config.init_multiview_attn:
336
+ processor_dict = add_multiview_processor(
337
+ model = self.unet,
338
+ enable_filter = lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"),
339
+ num_modalities = attn_config.num_modalities,
340
+ base_img_size = attn_config.latent_size,
341
+ chain_pos = attn_config.multiview_chain_pose,
342
+ )
343
+ # update multiview_attn_param_dict
344
+ self.multiview_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
345
+ used_param_ids.update(self.multiview_attn_param_dict.keys())
346
+
347
+ # initialize cross_attn_param_dict parameters
348
+ self.cross_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn2" in name and id(param) not in used_param_ids}
349
+ used_param_ids.update(self.cross_attn_param_dict.keys())
350
+
351
+ # initialize self_attn_param_dict parameters
352
+ self.self_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn1" in name and id(param) not in used_param_ids}
353
+ used_param_ids.update(self.self_attn_param_dict.keys())
354
+
355
+ # initialize other_param_dict parameters
356
+ self.other_param_dict = {id(param): param for name, param in self.unet.named_parameters() if id(param) not in used_param_ids}
357
+
358
+ if ref_unet is not None:
359
+ self.unet.ref_unet = ref_unet
360
+
361
+ self.rev_param_name_mapping = {id(param): name for name, param in self.unet.named_parameters()}
362
+
363
+ self.update_config(attn_config, force_update=True)
364
+ return self.unet
365
+
366
+ _attn_keys_to_update = ["enable_cross_attn_lora", "enable_cross_attn_ip", "enable_self_attn_lora", "enable_self_attn_ref", "enable_multiview_attn", "cls_labels"]
367
+
368
+ def update_config(self, attn_config: AttnConfig, force_update=False):
369
+ assert isinstance(self.unet, UNet2DConditionModel), "unet must be an instance of UNet2DConditionModel"
370
+
371
+ need_to_update = False
372
+ # update cls_labels
373
+ for key in self._attn_keys_to_update:
374
+ if getattr(self.attn_config, key) != getattr(attn_config, key):
375
+ need_to_update = True
376
+ break
377
+ if not force_update and not need_to_update:
378
+ return
379
+
380
+ self.set_class_labels(torch.tensor(attn_config.cls_labels).long())
381
+
382
+ # setup loras
383
+ if self.attn_config.init_cross_attn_lora or self.attn_config.init_self_attn_lora:
384
+ if attn_config.enable_cross_attn_lora or attn_config.enable_self_attn_lora:
385
+ cross_attn_lora_weight = 1. if attn_config.enable_cross_attn_lora > 0 else 0
386
+ self_attn_lora_weight = 1. if attn_config.enable_self_attn_lora > 0 else 0
387
+ self.unet.set_adapters(["cross_attn_lora", "self_attn_lora"], weights=[cross_attn_lora_weight, self_attn_lora_weight])
388
+ else:
389
+ self.unet.disable_adapters()
390
+
391
+ # setup ipadapter
392
+ if self.attn_config.init_cross_attn_ip:
393
+ if attn_config.enable_cross_attn_ip:
394
+ change_switch(self.unet, "ipadapter_switch", "ipadapter")
395
+ else:
396
+ change_switch(self.unet, "ipadapter_switch", "default")
397
+
398
+ # setup reference attention processor
399
+ if self.attn_config.init_self_attn_ref:
400
+ if attn_config.enable_self_attn_ref:
401
+ switch_extra_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"))
402
+ else:
403
+ switch_extra_processor(self.unet, enable_filter=lambda name: False)
404
+
405
+ # setup multiview attention processor
406
+ if self.attn_config.init_multiview_attn:
407
+ if attn_config.enable_multiview_attn:
408
+ switch_multiview_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"))
409
+ else:
410
+ switch_multiview_processor(self.unet, enable_filter=lambda name: False)
411
+
412
+ # update cls_labels
413
+ for key in self._attn_keys_to_update:
414
+ setattr(self.attn_config, key, getattr(attn_config, key))
415
+
416
+ def unet_forward_hook(self, raw_forward, sample: torch.FloatTensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, cross_attention_kwargs=None, condition_latents=None, class_labels=None, noisy_condition_input=False, cond_pixels_clip=None, **kwargs):
417
+ if class_labels is None and len(self.class_labels) > 0:
418
+ class_labels = self.class_labels.repeat(sample.shape[0] // self.class_labels.shape[0]).to(sample.device)
419
+ elif self.attn_config.init_num_cls_label != 0:
420
+ assert class_labels is not None, "class_labels should be passed if self.class_labels is empty and self.attn_config.init_num_cls_label is not 0"
421
+ if class_labels is not None:
422
+ if self.attn_config.cls_label_type == "embedding":
423
+ pass
424
+ else:
425
+ raise ValueError(f"cls_label_type {self.attn_config.cls_label_type} is not supported")
426
+ if self.attn_config.init_self_attn_ref and self.attn_config.enable_self_attn_ref:
427
+ # NOTE: extra step, extract condition
428
+ ref_dict = {}
429
+ ref_unet = self.get_refunet().to(sample.device)
430
+ assert condition_latents is not None
431
+ if self.attn_config.self_attn_ref_other_model_name == "self":
432
+ raise NotImplementedError()
433
+ else:
434
+ with torch.no_grad():
435
+ cond_encoder_hidden_states = encoder_hidden_states.reshape(condition_latents.shape[0], -1, *encoder_hidden_states.shape[1:])[:, 0]
436
+ if timestep.dim() == 0:
437
+ cond_timestep = timestep
438
+ else:
439
+ cond_timestep = timestep.reshape(condition_latents.shape[0], -1)[:, 0]
440
+ ref_unet(condition_latents, cond_timestep, cond_encoder_hidden_states, cross_attention_kwargs=dict(ref_dict=ref_dict))
441
+ # NOTE: extra step, inject condition
442
+ # Predict the noise residual and compute loss
443
+ if cross_attention_kwargs is None:
444
+ cross_attention_kwargs = {}
445
+ cross_attention_kwargs.update(ref_dict=ref_dict, mode='inject')
446
+ elif condition_latents is not None:
447
+ if not hasattr(self, 'condition_latents_raised'):
448
+ print("Warning! condition_latents is not None, but self_attn_ref is not enabled! This warning will only be raised once.")
449
+ self.condition_latents_raised = True
450
+
451
+ if self.attn_config.init_cross_attn_ip:
452
+ raise NotImplementedError()
453
+
454
+ if self.attn_config.cat_condition:
455
+ assert condition_latents is not None
456
+ B = condition_latents.shape[0]
457
+ cat_latents = condition_latents.reshape(B, 1, *condition_latents.shape[1:]).repeat(1, sample.shape[0] // B, 1, 1, 1).reshape(*sample.shape)
458
+ sample = torch.cat([sample, cat_latents], dim=1)
459
+
460
+ return raw_forward(sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, **kwargs)
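As a rough illustration of how `AttnConfig` and `ConfigurableUNet2DConditionModel` fit together, the sketch below enables class-label embeddings and multiview attention, then toggles the multiview branch off at runtime via `update_config`. The field values are placeholders rather than the trained setup, and constructing the wrapper downloads the base UNet named in `init_unet_path`.
```python
# Illustrative configuration sketch; values are placeholders, not the trained setup.
from dataclasses import replace

import torch

from custum_3d_diffusion.custum_modules.unifield_processor import (
    AttnConfig, ConfigurableUNet2DConditionModel,
)

config = AttnConfig(
    init_unet_path="runwayml/stable-diffusion-v1-5",  # base UNet (dataclass default)
    init_num_cls_label=8,                             # allocate 8 class embeddings
    cls_labels=[0, 1, 2, 3],                          # one label per generated view
    init_multiview_attn=True,
    enable_multiview_attn=True,
    multiview_attn_position="attn1",
)
wrapper = ConfigurableUNet2DConditionModel(config, weight_dtype=torch.float16)

# Runtime switches only touch the enable_* flags and cls_labels:
wrapper.update_config(replace(config, enable_multiview_attn=False))
```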
custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py ADDED
@@ -0,0 +1,298 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # modified by Wuvin
15
+
16
+
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
23
+ from diffusers.schedulers import KarrasDiffusionSchedulers
24
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
25
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
26
+ from PIL import Image
27
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
28
+
29
+
30
+
31
+ class StableDiffusionImageCustomPipeline(
32
+ StableDiffusionImageVariationPipeline
33
+ ):
34
+ def __init__(
35
+ self,
36
+ vae: AutoencoderKL,
37
+ image_encoder: CLIPVisionModelWithProjection,
38
+ unet: UNet2DConditionModel,
39
+ scheduler: KarrasDiffusionSchedulers,
40
+ safety_checker: StableDiffusionSafetyChecker,
41
+ feature_extractor: CLIPImageProcessor,
42
+ requires_safety_checker: bool = True,
43
+ latents_offset=None,
44
+ noisy_cond_latents=False,
45
+ ):
46
+ super().__init__(
47
+ vae=vae,
48
+ image_encoder=image_encoder,
49
+ unet=unet,
50
+ scheduler=scheduler,
51
+ safety_checker=safety_checker,
52
+ feature_extractor=feature_extractor,
53
+ requires_safety_checker=requires_safety_checker
54
+ )
55
+ latents_offset = tuple(latents_offset) if latents_offset is not None else None
56
+ self.latents_offset = latents_offset
57
+ if latents_offset is not None:
58
+ self.register_to_config(latents_offset=latents_offset)
59
+ self.noisy_cond_latents = noisy_cond_latents
60
+ self.register_to_config(noisy_cond_latents=noisy_cond_latents)
61
+
62
+ def encode_latents(self, image, device, dtype, height, width):
63
+ # support batchsize > 1
64
+ if isinstance(image, Image.Image):
65
+ image = [image]
66
+ image = [img.convert("RGB") for img in image]
67
+ images = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype)
68
+ latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
69
+ if self.latents_offset is not None:
70
+ return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
71
+ else:
72
+ return latents
73
+
74
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
75
+ dtype = next(self.image_encoder.parameters()).dtype
76
+
77
+ if not isinstance(image, torch.Tensor):
78
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
79
+
80
+ image = image.to(device=device, dtype=dtype)
81
+ image_embeddings = self.image_encoder(image).image_embeds
82
+ image_embeddings = image_embeddings.unsqueeze(1)
83
+
84
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
85
+ bs_embed, seq_len, _ = image_embeddings.shape
86
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
87
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
88
+
89
+ if do_classifier_free_guidance:
90
+ # NOTE: the same as original code
91
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
92
+ # For classifier free guidance, we need to do two forward passes.
93
+ # Here we concatenate the unconditional and text embeddings into a single batch
94
+ # to avoid doing two forward passes
95
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
96
+
97
+ return image_embeddings
98
+
99
+ @torch.no_grad()
100
+ def __call__(
101
+ self,
102
+ image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
103
+ height: Optional[int] = 1024,
104
+ width: Optional[int] = 1024,
105
+ height_cond: Optional[int] = 512,
106
+ width_cond: Optional[int] = 512,
107
+ num_inference_steps: int = 50,
108
+ guidance_scale: float = 7.5,
109
+ num_images_per_prompt: Optional[int] = 1,
110
+ eta: float = 0.0,
111
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
112
+ latents: Optional[torch.FloatTensor] = None,
113
+ output_type: Optional[str] = "pil",
114
+ return_dict: bool = True,
115
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
116
+ callback_steps: int = 1,
117
+ upper_left_feature: bool = False,
118
+ ):
119
+ r"""
120
+ The call function to the pipeline for generation.
121
+
122
+ Args:
123
+ image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
124
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
125
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
126
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
127
+ The height in pixels of the generated image.
128
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
129
+ The width in pixels of the generated image.
130
+ num_inference_steps (`int`, *optional*, defaults to 50):
131
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
132
+ expense of slower inference. This parameter is modulated by `strength`.
133
+ guidance_scale (`float`, *optional*, defaults to 7.5):
134
+ A higher guidance scale value encourages the model to generate images closely linked to the text
135
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
136
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
137
+ The number of images to generate per prompt.
138
+ eta (`float`, *optional*, defaults to 0.0):
139
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
140
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
141
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
142
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
143
+ generation deterministic.
144
+ latents (`torch.FloatTensor`, *optional*):
145
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
146
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
147
+ tensor is generated by sampling using the supplied random `generator`.
148
+ output_type (`str`, *optional*, defaults to `"pil"`):
149
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
150
+ return_dict (`bool`, *optional*, defaults to `True`):
151
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
152
+ plain tuple.
153
+ callback (`Callable`, *optional*):
154
+ A function that calls every `callback_steps` steps during inference. The function is called with the
155
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
156
+ callback_steps (`int`, *optional*, defaults to 1):
157
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
158
+ every step.
159
+
160
+ Returns:
161
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
162
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
163
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
164
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
165
+ "not-safe-for-work" (nsfw) content.
166
+
167
+ Examples:
168
+
169
+ ```py
170
+ from diffusers import StableDiffusionImageVariationPipeline
171
+ from PIL import Image
172
+ from io import BytesIO
173
+ import requests
174
+
175
+ pipe = StableDiffusionImageVariationPipeline.from_pretrained(
176
+ "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
177
+ )
178
+ pipe = pipe.to("cuda")
179
+
180
+ url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
181
+
182
+ response = requests.get(url)
183
+ image = Image.open(BytesIO(response.content)).convert("RGB")
184
+
185
+ out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
186
+ out["images"][0].save("result.jpg")
187
+ ```
188
+ """
189
+ # 0. Default height and width to unet
190
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
191
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
192
+
193
+ # 1. Check inputs. Raise error if not correct
194
+ self.check_inputs(image, height, width, callback_steps)
195
+
196
+ # 2. Define call parameters
197
+ if isinstance(image, Image.Image):
198
+ batch_size = 1
199
+ elif isinstance(image, list):
200
+ batch_size = len(image)
201
+ else:
202
+ batch_size = image.shape[0]
203
+ device = self._execution_device
204
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
205
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
206
+ # corresponds to doing no classifier free guidance.
207
+ do_classifier_free_guidance = guidance_scale > 1.0
208
+
209
+ # 3. Encode input image
210
+ if isinstance(image, Image.Image) and upper_left_feature:
211
+ # use only the first (upper-left) of the four tiled views for the image embedding
212
+ emb_image = image.crop((0, 0, image.size[0] // 2, image.size[1] // 2))
213
+ else:
214
+ emb_image = image
215
+
216
+ image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
217
+ cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
218
+
219
+ # 4. Prepare timesteps
220
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
221
+ timesteps = self.scheduler.timesteps
222
+
223
+ # 5. Prepare latent variables
224
+ num_channels_latents = self.unet.config.out_channels
225
+ latents = self.prepare_latents(
226
+ batch_size * num_images_per_prompt,
227
+ num_channels_latents,
228
+ height,
229
+ width,
230
+ image_embeddings.dtype,
231
+ device,
232
+ generator,
233
+ latents,
234
+ )
235
+
236
+ # 6. Prepare extra step kwargs.
237
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
238
+
239
+ # 7. Denoising loop
240
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
241
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
242
+ for i, t in enumerate(timesteps):
243
+ if self.noisy_cond_latents:
244
+ raise ValueError("Noisy condition latents is not recommended.")
245
+ else:
246
+ noisy_cond_latents = cond_latents
247
+
248
+ noisy_cond_latents = torch.cat([torch.zeros_like(noisy_cond_latents), noisy_cond_latents]) if do_classifier_free_guidance else noisy_cond_latents
249
+ # expand the latents if we are doing classifier free guidance
250
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
251
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
252
+
253
+ # predict the noise residual
254
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=noisy_cond_latents).sample
255
+
256
+ # perform guidance
257
+ if do_classifier_free_guidance:
258
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
259
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
260
+
261
+ # compute the previous noisy sample x_t -> x_t-1
262
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
263
+
264
+ # call the callback, if provided
265
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
266
+ progress_bar.update()
267
+ if callback is not None and i % callback_steps == 0:
268
+ step_idx = i // getattr(self.scheduler, "order", 1)
269
+ callback(step_idx, t, latents)
270
+
271
+ self.maybe_free_model_hooks()
272
+
273
+ if self.latents_offset is not None:
274
+ latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
275
+
276
+ if not output_type == "latent":
277
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
278
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
279
+ else:
280
+ image = latents
281
+ has_nsfw_concept = None
282
+
283
+ if has_nsfw_concept is None:
284
+ do_denormalize = [True] * image.shape[0]
285
+ else:
286
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
287
+
288
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
289
+
290
+ self.maybe_free_model_hooks()
291
+
292
+ if not return_dict:
293
+ return (image, has_nsfw_concept)
294
+
295
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
296
+
297
+ if __name__ == "__main__":
298
+ pass
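The denoising loop above applies classifier-free guidance by running an unconditional and an image-conditioned prediction and blending them. A small standalone illustration of that blending step (not repository code; shapes are arbitrary):

```py
# Sketch of the classifier-free guidance combination used in the denoising loop above.
# With guidance_scale == 1.0 the result reduces to the conditional prediction itself.
import torch

noise_pred_uncond = torch.zeros(1, 4, 8, 8)  # hypothetical unconditional prediction
noise_pred_cond = torch.ones(1, 4, 8, 8)     # hypothetical image-conditioned prediction

for guidance_scale in (1.0, 7.5):
    guided = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    print(guidance_scale, guided.mean().item())  # 1.0 -> 1.0, 7.5 -> 7.5
```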
custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # modified by Wuvin
15
+
16
+
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
23
+ from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
24
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
25
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
26
+ from PIL import Image
27
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
28
+
29
+
30
+
31
+ class StableDiffusionImage2MVCustomPipeline(
32
+ StableDiffusionImageVariationPipeline
33
+ ):
34
+ def __init__(
35
+ self,
36
+ vae: AutoencoderKL,
37
+ image_encoder: CLIPVisionModelWithProjection,
38
+ unet: UNet2DConditionModel,
39
+ scheduler: KarrasDiffusionSchedulers,
40
+ safety_checker: StableDiffusionSafetyChecker,
41
+ feature_extractor: CLIPImageProcessor,
42
+ requires_safety_checker: bool = True,
43
+ latents_offset=None,
44
+ noisy_cond_latents=False,
45
+ condition_offset=True,
46
+ ):
47
+ super().__init__(
48
+ vae=vae,
49
+ image_encoder=image_encoder,
50
+ unet=unet,
51
+ scheduler=scheduler,
52
+ safety_checker=safety_checker,
53
+ feature_extractor=feature_extractor,
54
+ requires_safety_checker=requires_safety_checker
55
+ )
56
+ latents_offset = tuple(latents_offset) if latents_offset is not None else None
57
+ self.latents_offset = latents_offset
58
+ if latents_offset is not None:
59
+ self.register_to_config(latents_offset=latents_offset)
60
+ if noisy_cond_latents:
61
+ raise NotImplementedError("Noisy condition latents not supported Now.")
62
+ self.condition_offset = condition_offset
63
+ self.register_to_config(condition_offset=condition_offset)
64
+
65
+ def encode_latents(self, image: Image.Image, device, dtype, height, width):
66
+ images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
67
+ # NOTE: use the latent distribution mode (not a random sample) for the condition latents
68
+ latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
69
+ if self.latents_offset is not None and self.condition_offset:
70
+ return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
71
+ else:
72
+ return latents
73
+
74
+ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
75
+ dtype = next(self.image_encoder.parameters()).dtype
76
+
77
+ if not isinstance(image, torch.Tensor):
78
+ image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
79
+
80
+ image = image.to(device=device, dtype=dtype)
81
+ image_embeddings = self.image_encoder(image).image_embeds
82
+ image_embeddings = image_embeddings.unsqueeze(1)
83
+
84
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
85
+ bs_embed, seq_len, _ = image_embeddings.shape
86
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
87
+ image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
88
+
89
+ if do_classifier_free_guidance:
90
+ # NOTE: the same as original code
91
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
92
+ # For classifier free guidance, we need to do two forward passes.
93
+ # Here we concatenate the unconditional and text embeddings into a single batch
94
+ # to avoid doing two forward passes
95
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
96
+
97
+ return image_embeddings
98
+
99
+ @torch.no_grad()
100
+ def __call__(
101
+ self,
102
+ image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
103
+ height: Optional[int] = 1024,
104
+ width: Optional[int] = 1024,
105
+ height_cond: Optional[int] = 512,
106
+ width_cond: Optional[int] = 512,
107
+ num_inference_steps: int = 50,
108
+ guidance_scale: float = 7.5,
109
+ num_images_per_prompt: Optional[int] = 1,
110
+ eta: float = 0.0,
111
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
112
+ latents: Optional[torch.FloatTensor] = None,
113
+ output_type: Optional[str] = "pil",
114
+ return_dict: bool = True,
115
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
116
+ callback_steps: int = 1,
117
+ ):
118
+ r"""
119
+ The call function to the pipeline for generation.
120
+
121
+ Args:
122
+ image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
123
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
124
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
125
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
126
+ The height in pixels of the generated image.
127
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
128
+ The width in pixels of the generated image.
129
+ num_inference_steps (`int`, *optional*, defaults to 50):
130
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
131
+ expense of slower inference.
132
+ guidance_scale (`float`, *optional*, defaults to 7.5):
133
+ A higher guidance scale value encourages the model to generate images closely linked to the text
134
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
135
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
136
+ The number of images to generate per prompt.
137
+ eta (`float`, *optional*, defaults to 0.0):
138
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
139
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
140
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
141
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
142
+ generation deterministic.
143
+ latents (`torch.FloatTensor`, *optional*):
144
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
145
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
146
+ tensor is generated by sampling using the supplied random `generator`.
147
+ output_type (`str`, *optional*, defaults to `"pil"`):
148
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
149
+ return_dict (`bool`, *optional*, defaults to `True`):
150
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
151
+ plain tuple.
152
+ callback (`Callable`, *optional*):
153
+ A function that is called every `callback_steps` steps during inference. The function is called with the
154
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
155
+ callback_steps (`int`, *optional*, defaults to 1):
156
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
157
+ every step.
158
+
159
+ Returns:
160
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
161
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
162
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
163
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
164
+ "not-safe-for-work" (nsfw) content.
165
+
166
+ Examples:
167
+
168
+ ```py
169
+ from diffusers import StableDiffusionImageVariationPipeline
170
+ from PIL import Image
171
+ from io import BytesIO
172
+ import requests
173
+
174
+ pipe = StableDiffusionImageVariationPipeline.from_pretrained(
175
+ "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
176
+ )
177
+ pipe = pipe.to("cuda")
178
+
179
+ url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
180
+
181
+ response = requests.get(url)
182
+ image = Image.open(BytesIO(response.content)).convert("RGB")
183
+
184
+ out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
185
+ out["images"][0].save("result.jpg")
186
+ ```
187
+ """
188
+ # 0. Default height and width to unet
189
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
190
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
191
+
192
+ # 1. Check inputs. Raise error if not correct
193
+ self.check_inputs(image, height, width, callback_steps)
194
+
195
+ # 2. Define call parameters
196
+ if isinstance(image, Image.Image):
197
+ batch_size = 1
198
+ elif len(image) == 1:
199
+ image = image[0]
200
+ batch_size = 1
201
+ else:
202
+ raise NotImplementedError()
203
+ # elif isinstance(image, list):
204
+ # batch_size = len(image)
205
+ # else:
206
+ # batch_size = image.shape[0]
207
+ device = self._execution_device
208
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
209
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
210
+ # corresponds to doing no classifier free guidance.
211
+ do_classifier_free_guidance = guidance_scale > 1.0
212
+
213
+ # 3. Encode input image
214
+ emb_image = image
215
+
216
+ image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
217
+ cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
218
+ cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
219
+ image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
220
+ if do_classifier_free_guidance:
221
+ image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)
222
+
223
+ # 4. Prepare timesteps
224
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
225
+ timesteps = self.scheduler.timesteps
226
+
227
+ # 5. Prepare latent variables
228
+ num_channels_latents = self.unet.config.out_channels
229
+ latents = self.prepare_latents(
230
+ batch_size * num_images_per_prompt,
231
+ num_channels_latents,
232
+ height,
233
+ width,
234
+ image_embeddings.dtype,
235
+ device,
236
+ generator,
237
+ latents,
238
+ )
239
+
240
+
241
+ # 6. Prepare extra step kwargs.
242
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
243
+ # 7. Denoising loop
244
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
245
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
246
+ for i, t in enumerate(timesteps):
247
+ # expand the latents if we are doing classifier free guidance
248
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
249
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
250
+
251
+ # predict the noise residual
252
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=cond_latents, noisy_condition_input=False, cond_pixels_clip=image_pixels).sample
253
+
254
+ # perform guidance
255
+ if do_classifier_free_guidance:
256
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
257
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
258
+
259
+ # compute the previous noisy sample x_t -> x_t-1
260
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
261
+
262
+ # call the callback, if provided
263
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
264
+ progress_bar.update()
265
+ if callback is not None and i % callback_steps == 0:
266
+ step_idx = i // getattr(self.scheduler, "order", 1)
267
+ callback(step_idx, t, latents)
268
+
269
+ self.maybe_free_model_hooks()
270
+
271
+ if self.latents_offset is not None:
272
+ latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
273
+
274
+ if not output_type == "latent":
275
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
276
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
277
+ else:
278
+ image = latents
279
+ has_nsfw_concept = None
280
+
281
+ if has_nsfw_concept is None:
282
+ do_denormalize = [True] * image.shape[0]
283
+ else:
284
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
285
+
286
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
287
+
288
+ self.maybe_free_model_hooks()
289
+
290
+ if not return_dict:
291
+ return (image, has_nsfw_concept)
292
+
293
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
294
+
295
+ if __name__ == "__main__":
296
+ pass
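For orientation, a hedged sketch of calling `StableDiffusionImage2MVCustomPipeline` directly; the checkpoint path is a placeholder and assumes a directory saved by the trainers below, whose UNet accepts `condition_latents` and `cond_pixels_clip`:

```py
# Hypothetical usage sketch; "path/to/image2mvimage-checkpoint" is a placeholder, not a real model id.
import torch
from PIL import Image
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline

pipe = StableDiffusionImage2MVCustomPipeline.from_pretrained(
    "path/to/image2mvimage-checkpoint", torch_dtype=torch.float16
).to("cuda")

cond = Image.open("input.png").convert("RGB")  # single-view condition image
out = pipe(
    cond,
    height=512, width=512,            # resolution of each generated view
    height_cond=512, width_cond=512,  # resolution the condition image is encoded at
    num_images_per_prompt=4,          # one latent per view
    guidance_scale=3.0,
)
out.images[0].save("view0.png")
```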
custum_3d_diffusion/modules.py ADDED
@@ -0,0 +1,14 @@
1
+ __modules__ = {}
2
+
3
+ def register(name):
4
+ def decorator(cls):
5
+ __modules__[name] = cls
6
+ return cls
7
+
8
+ return decorator
9
+
10
+
11
+ def find(name):
12
+ return __modules__[name]
13
+
14
+ from custum_3d_diffusion.trainings import base, image2mvimage_trainer, image2image_trainer
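The registry above is how trainer classes are looked up by name from config files. A minimal illustration of the pattern (the class and name below are hypothetical):

```py
# Minimal sketch of the register/find pattern defined in custum_3d_diffusion/modules.py.
from custum_3d_diffusion.modules import register, find

@register("toy_trainer")  # hypothetical name, used only for this illustration
class ToyTrainer:
    pass

assert find("toy_trainer") is ToyTrainer  # configs refer to trainers by this string
```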
custum_3d_diffusion/trainings/__init__.py ADDED
File without changes
custum_3d_diffusion/trainings/base.py ADDED
@@ -0,0 +1,208 @@
1
+ import torch
2
+ from accelerate import Accelerator
3
+ from accelerate.logging import MultiProcessAdapter
4
+ from dataclasses import dataclass, field
5
+ from typing import Optional, Union
6
+ from datasets import load_dataset
7
+ import json
8
+ import abc
9
+ from diffusers.utils import make_image_grid
10
+ import numpy as np
11
+ import wandb
12
+
13
+ from custum_3d_diffusion.trainings.utils import load_config
14
+ from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig
15
+
16
+ class BasicTrainer(torch.nn.Module, abc.ABC):
17
+ accelerator: Accelerator
18
+ logger: MultiProcessAdapter
19
+ unet: ConfigurableUNet2DConditionModel
20
+ train_dataloader: torch.utils.data.DataLoader
21
+ test_dataset: torch.utils.data.Dataset
22
+ attn_config: AttnConfig
23
+
24
+ @dataclass
25
+ class TrainerConfig:
26
+ trainer_name: str = "basic"
27
+ pretrained_model_name_or_path: str = ""
28
+
29
+ attn_config: dict = field(default_factory=dict)
30
+ dataset_name: str = ""
31
+ dataset_config_name: Optional[str] = None
32
+ resolution: str = "1024"
33
+ dataloader_num_workers: int = 4
34
+ pair_sampler_group_size: int = 1
35
+ num_views: int = 4
36
+
37
+ max_train_steps: int = -1 # -1 means infinity, otherwise [0, max_train_steps)
38
+ training_step_interval: int = 1 # train on step i*interval, stop at max_train_steps
39
+ max_train_samples: Optional[int] = None
40
+ seed: Optional[int] = None # For dataset related operations and validation stuff
41
+ train_batch_size: int = 1
42
+
43
+ validation_interval: int = 5000
44
+ debug: bool = False
45
+
46
+ cfg: TrainerConfig # only enable_xxx is used
47
+
48
+ def __init__(
49
+ self,
50
+ accelerator: Accelerator,
51
+ logger: MultiProcessAdapter,
52
+ unet: ConfigurableUNet2DConditionModel,
53
+ config: Union[dict, str],
54
+ weight_dtype: torch.dtype,
55
+ index: int,
56
+ ):
57
+ super().__init__()
58
+ self.index = index # index in all trainers
59
+ self.accelerator = accelerator
60
+ self.logger = logger
61
+ self.unet = unet
62
+ self.weight_dtype = weight_dtype
63
+ self.ext_logs = {}
64
+ self.cfg = load_config(self.TrainerConfig, config)
65
+ self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
66
+ self.test_dataset = None
67
+ self.validate_trainer_config()
68
+ self.configure()
69
+
70
+ def get_HW(self):
71
+ resolution = json.loads(self.cfg.resolution)
72
+ if isinstance(resolution, int):
73
+ H = W = resolution
74
+ elif isinstance(resolution, list):
75
+ H, W = resolution
76
+ return H, W
77
+
78
+ def unet_update(self):
79
+ self.unet.update_config(self.attn_config)
80
+
81
+ def validate_trainer_config(self):
82
+ pass
83
+
84
+ def is_train_finished(self, current_step):
85
+ assert isinstance(self.cfg.max_train_steps, int)
86
+ return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps
87
+
88
+ def next_train_step(self, current_step):
89
+ if self.is_train_finished(current_step):
90
+ return None
91
+ return current_step + self.cfg.training_step_interval
92
+
93
+ @classmethod
94
+ def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
95
+ catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)]
96
+ return make_image_grid(catted, rows=1, cols=len(catted))
97
+
98
+ def configure(self) -> None:
99
+ pass
100
+
101
+ @abc.abstractmethod
102
+ def init_shared_modules(self, shared_modules: dict) -> dict:
103
+ pass
104
+
105
+ def load_dataset(self):
106
+ dataset = load_dataset(
107
+ self.cfg.dataset_name,
108
+ self.cfg.dataset_config_name,
109
+ trust_remote_code=True
110
+ )
111
+ return dataset
112
+
113
+ @abc.abstractmethod
114
+ def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
115
+ """Both init train_dataloader and test_dataset, but returns train_dataloader only"""
116
+ pass
117
+
118
+ @abc.abstractmethod
119
+ def forward_step(
120
+ self,
121
+ *args,
122
+ **kwargs
123
+ ) -> torch.Tensor:
124
+ """
125
+ input a batch
126
+ return a loss
127
+ """
128
+ self.unet_update()
129
+ pass
130
+
131
+ @abc.abstractmethod
132
+ def construct_pipeline(self, shared_modules, unet):
133
+ pass
134
+
135
+ @abc.abstractmethod
136
+ def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
137
+ """
138
+ For inference time forward.
139
+ """
140
+ pass
141
+
142
+ @abc.abstractmethod
143
+ def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
144
+ pass
145
+
146
+ def do_validation(
147
+ self,
148
+ shared_modules,
149
+ unet,
150
+ global_step,
151
+ ):
152
+ self.unet_update()
153
+ self.logger.info("Running validation... ")
154
+ pipeline = self.construct_pipeline(shared_modules, unet)
155
+ pipeline.set_progress_bar_config(disable=True)
156
+ titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
157
+ for tracker in self.accelerator.trackers:
158
+ if tracker.name == "tensorboard":
159
+ np_images = np.stack([np.asarray(img) for img in images])
160
+ tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
161
+ elif tracker.name == "wandb":
162
+ [image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title] # inplace operation
163
+ tracker.log({"validation": [
164
+ wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
165
+ for i, image in enumerate(images)]})
166
+ else:
167
+ self.logger.warn(f"image logging not implemented for {tracker.name}")
168
+ del pipeline
169
+ torch.cuda.empty_cache()
170
+ return images
171
+
172
+
173
+ @torch.no_grad()
174
+ def log_validation(
175
+ self,
176
+ shared_modules,
177
+ unet,
178
+ global_step,
179
+ force=False
180
+ ):
181
+ if self.accelerator.is_main_process:
182
+ for tracker in self.accelerator.trackers:
183
+ if tracker.name == "wandb":
184
+ tracker.log(self.ext_logs)
185
+ self.ext_logs = {}
186
+ if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
187
+ self.unet_update()
188
+ if self.accelerator.is_main_process:
189
+ self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)
190
+
191
+ def save_model(self, unwrap_unet, shared_modules, save_dir):
192
+ if self.accelerator.is_main_process:
193
+ pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
194
+ pipeline.save_pretrained(save_dir)
195
+ self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")
196
+
197
+ def save_debug_info(self, save_name="debug", **kwargs):
198
+ if self.cfg.debug:
199
+ to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()}
200
+ import pickle
201
+ import os
202
+ if os.path.exists(f"{save_name}.pkl"):
203
+ for i in range(100):
204
+ if not os.path.exists(f"{save_name}_v{i}.pkl"):
205
+ save_name = f"{save_name}_v{i}"
206
+ break
207
+ with open(f"{save_name}.pkl", "wb") as f:
208
+ pickle.dump(to_saves, f)
custum_3d_diffusion/trainings/config_classes.py ADDED
@@ -0,0 +1,35 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Optional
3
+
4
+
5
+ @dataclass
6
+ class TrainerSubConfig:
7
+ trainer_type: str = ""
8
+ trainer: dict = field(default_factory=dict)
9
+
10
+
11
+ @dataclass
12
+ class ExprimentConfig:
13
+ trainers: List[dict] = field(default_factory=lambda: [])
14
+ init_config: dict = field(default_factory=dict)
15
+ pretrained_model_name_or_path: str = ""
16
+ pretrained_unet_state_dict_path: str = ""
17
+ # experiment-related parameters
18
+ linear_beta_schedule: bool = False
19
+ zero_snr: bool = False
20
+ prediction_type: Optional[str] = None
21
+ seed: Optional[int] = None
22
+ max_train_steps: int = 1000000
23
+ gradient_accumulation_steps: int = 1
24
+ learning_rate: float = 1e-4
25
+ lr_scheduler: str = "constant"
26
+ lr_warmup_steps: int = 500
27
+ use_8bit_adam: bool = False
28
+ adam_beta1: float = 0.9
29
+ adam_beta2: float = 0.999
30
+ adam_weight_decay: float = 1e-2
31
+ adam_epsilon: float = 1e-08
32
+ max_grad_norm: float = 1.0
33
+ mixed_precision: Optional[str] = None # ["no", "fp16", "bf16", "fp8"]
34
+ skip_training: bool = False
35
+ debug: bool = False
custum_3d_diffusion/trainings/image2image_trainer.py ADDED
@@ -0,0 +1,86 @@
1
+ import json
2
+ import torch
3
+ from diffusers import EulerAncestralDiscreteScheduler, DDPMScheduler
4
+ from dataclasses import dataclass
5
+
6
+ from custum_3d_diffusion.modules import register
7
+ from custum_3d_diffusion.trainings.image2mvimage_trainer import Image2MVImageTrainer
8
+ from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline
9
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
10
+
11
+ def get_HW(resolution):
12
+ if isinstance(resolution, str):
13
+ resolution = json.loads(resolution)
14
+ if isinstance(resolution, int):
15
+ H = W = resolution
16
+ elif isinstance(resolution, list):
17
+ H, W = resolution
18
+ return H, W
19
+
20
+
21
+ @register("image2image_trainer")
22
+ class Image2ImageTrainer(Image2MVImageTrainer):
23
+ """
24
+ Trainer for simple image-to-image generation.
25
+ """
26
+ @dataclass
27
+ class TrainerConfig(Image2MVImageTrainer.TrainerConfig):
28
+ trainer_name: str = "image2image"
29
+
30
+ cfg: TrainerConfig
31
+
32
+ def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
33
+ raise NotImplementedError()
34
+
35
+ def construct_pipeline(self, shared_modules, unet, old_version=False):
36
+ MyPipeline = StableDiffusionImageCustomPipeline
37
+ pipeline = MyPipeline.from_pretrained(
38
+ self.cfg.pretrained_model_name_or_path,
39
+ vae=shared_modules['vae'],
40
+ image_encoder=shared_modules['image_encoder'],
41
+ feature_extractor=shared_modules['feature_extractor'],
42
+ unet=unet,
43
+ safety_checker=None,
44
+ torch_dtype=self.weight_dtype,
45
+ latents_offset=self.cfg.latents_offset,
46
+ noisy_cond_latents=self.cfg.noisy_condition_input,
47
+ )
48
+ pipeline.set_progress_bar_config(disable=True)
49
+ scheduler_dict = {}
50
+ if self.cfg.zero_snr:
51
+ scheduler_dict.update(rescale_betas_zero_snr=True)
52
+ if self.cfg.linear_beta_schedule:
53
+ scheduler_dict.update(beta_schedule='linear')
54
+
55
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
56
+ return pipeline
57
+
58
+ def get_forward_args(self):
59
+ if self.cfg.seed is None:
60
+ generator = None
61
+ else:
62
+ generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
63
+
64
+ H, W = get_HW(self.cfg.resolution)
65
+ H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
66
+
67
+ forward_args = dict(
68
+ num_images_per_prompt=1,
69
+ num_inference_steps=20,
70
+ height=H,
71
+ width=W,
72
+ height_cond=H_cond,
73
+ width_cond=W_cond,
74
+ generator=generator,
75
+ )
76
+ if self.cfg.zero_snr:
77
+ forward_args.update(guidance_rescale=0.7)
78
+ return forward_args
79
+
80
+ def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
81
+ forward_args = self.get_forward_args()
82
+ forward_args.update(pipeline_call_kwargs)
83
+ return pipeline(**forward_args)
84
+
85
+ def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
86
+ raise NotImplementedError()
custum_3d_diffusion/trainings/image2mvimage_trainer.py ADDED
@@ -0,0 +1,139 @@
1
+ import torch
2
+ from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler
3
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature
4
+
5
+ import json
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional
8
+
9
+ from custum_3d_diffusion.modules import register
10
+ from custum_3d_diffusion.trainings.base import BasicTrainer
11
+ from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline
12
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
13
+
14
+ def get_HW(resolution):
15
+ if isinstance(resolution, str):
16
+ resolution = json.loads(resolution)
17
+ if isinstance(resolution, int):
18
+ H = W = resolution
19
+ elif isinstance(resolution, list):
20
+ H, W = resolution
21
+ return H, W
22
+
23
+ @register("image2mvimage_trainer")
24
+ class Image2MVImageTrainer(BasicTrainer):
25
+ """
26
+ Trainer for simple image to multiview images.
27
+ """
28
+ @dataclass
29
+ class TrainerConfig(BasicTrainer.TrainerConfig):
30
+ trainer_name: str = "image2mvimage"
31
+ condition_image_column_name: str = "conditioning_image"
32
+ image_column_name: str = "image"
33
+ condition_dropout: float = 0.
34
+ condition_image_resolution: str = "512"
35
+ validation_images: Optional[List[str]] = None
36
+ noise_offset: float = 0.1
37
+ max_loss_drop: float = 0.
38
+ snr_gamma: float = 5.0
39
+ log_distribution: bool = False
40
+ latents_offset: Optional[List[float]] = None
41
+ input_perturbation: float = 0.
42
+ noisy_condition_input: bool = False # whether to add noise for ref unet input
43
+ normal_cls_offset: int = 0
44
+ condition_offset: bool = True
45
+ zero_snr: bool = False
46
+ linear_beta_schedule: bool = False
47
+
48
+ cfg: TrainerConfig
49
+
50
+ def configure(self) -> None:
51
+ return super().configure()
52
+
53
+ def init_shared_modules(self, shared_modules: dict) -> dict:
54
+ if 'vae' not in shared_modules:
55
+ vae = AutoencoderKL.from_pretrained(
56
+ self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
57
+ )
58
+ vae.requires_grad_(False)
59
+ vae.to(self.accelerator.device, dtype=self.weight_dtype)
60
+ shared_modules['vae'] = vae
61
+ if 'image_encoder' not in shared_modules:
62
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
63
+ self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
64
+ )
65
+ image_encoder.requires_grad_(False)
66
+ image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
67
+ shared_modules['image_encoder'] = image_encoder
68
+ if 'feature_extractor' not in shared_modules:
69
+ feature_extractor = CLIPImageProcessor.from_pretrained(
70
+ self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
71
+ )
72
+ shared_modules['feature_extractor'] = feature_extractor
73
+ return shared_modules
74
+
75
+ def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
76
+ raise NotImplementedError()
77
+
78
+ def loss_rescale(self, loss, timesteps=None):
79
+ raise NotImplementedError()
80
+
81
+ def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
82
+ raise NotImplementedError()
83
+
84
+ def construct_pipeline(self, shared_modules, unet, old_version=False):
85
+ MyPipeline = StableDiffusionImage2MVCustomPipeline
86
+ pipeline = MyPipeline.from_pretrained(
87
+ self.cfg.pretrained_model_name_or_path,
88
+ vae=shared_modules['vae'],
89
+ image_encoder=shared_modules['image_encoder'],
90
+ feature_extractor=shared_modules['feature_extractor'],
91
+ unet=unet,
92
+ safety_checker=None,
93
+ torch_dtype=self.weight_dtype,
94
+ latents_offset=self.cfg.latents_offset,
95
+ noisy_cond_latents=self.cfg.noisy_condition_input,
96
+ condition_offset=self.cfg.condition_offset,
97
+ )
98
+ pipeline.set_progress_bar_config(disable=True)
99
+ scheduler_dict = {}
100
+ if self.cfg.zero_snr:
101
+ scheduler_dict.update(rescale_betas_zero_snr=True)
102
+ if self.cfg.linear_beta_schedule:
103
+ scheduler_dict.update(beta_schedule='linear')
104
+
105
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
106
+ return pipeline
107
+
108
+ def get_forward_args(self):
109
+ if self.cfg.seed is None:
110
+ generator = None
111
+ else:
112
+ generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
113
+
114
+ H, W = get_HW(self.cfg.resolution)
115
+ H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
116
+
117
+ sub_img_H = H // 2
118
+ num_imgs = H // sub_img_H * W // sub_img_H
119
+
120
+ forward_args = dict(
121
+ num_images_per_prompt=num_imgs,
122
+ num_inference_steps=50,
123
+ height=sub_img_H,
124
+ width=sub_img_H,
125
+ height_cond=H_cond,
126
+ width_cond=W_cond,
127
+ generator=generator,
128
+ )
129
+ if self.cfg.zero_snr:
130
+ forward_args.update(guidance_rescale=0.7)
131
+ return forward_args
132
+
133
+ def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
134
+ forward_args = self.get_forward_args()
135
+ forward_args.update(pipeline_call_kwargs)
136
+ return pipeline(**forward_args)
137
+
138
+ def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
139
+ raise NotImplementedError()
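`get_forward_args` above tiles the multiview output: each view is half the target height, so a square target resolution yields a 2x2 grid of views. A quick arithmetic check mirroring those expressions:

```py
# Mirrors sub_img_H / num_imgs in get_forward_args(): a 1024x1024 target -> four 512x512 views.
H = W = 1024
sub_img_H = H // 2                          # 512
num_imgs = H // sub_img_H * W // sub_img_H  # evaluates to 4 (a 2x2 grid)
print(sub_img_H, num_imgs)                  # 512 4
```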
custum_3d_diffusion/trainings/utils.py ADDED
@@ -0,0 +1,25 @@
1
+ from omegaconf import DictConfig, OmegaConf
2
+
3
+
4
+ def parse_structured(fields, cfg) -> DictConfig:
5
+ scfg = OmegaConf.structured(fields(**cfg))
6
+ return scfg
7
+
8
+
9
+ def load_config(fields, config, extras=None):
10
+ if extras is not None:
11
+ print("Warning! extra parameter in cli is not verified, may cause erros.")
12
+ if isinstance(config, str):
13
+ cfg = OmegaConf.load(config)
14
+ elif isinstance(config, dict):
15
+ cfg = OmegaConf.create(config)
16
+ elif isinstance(config, DictConfig):
17
+ cfg = config
18
+ else:
19
+ raise NotImplementedError(f"Unsupported config type {type(config)}")
20
+ if extras is not None:
21
+ cli_conf = OmegaConf.from_cli(extras)
22
+ cfg = OmegaConf.merge(cfg, cli_conf)
23
+ OmegaConf.resolve(cfg)
24
+ assert isinstance(cfg, DictConfig)
25
+ return parse_structured(fields, cfg)
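`load_config` accepts a YAML path, a plain dict, or an existing `DictConfig` and validates it against a dataclass schema. A minimal sketch with a made-up schema (not one of the repository's config classes):

```py
# Hedged usage sketch of load_config with a hypothetical dataclass schema.
from dataclasses import dataclass
from custum_3d_diffusion.trainings.utils import load_config

@dataclass
class DemoConfig:  # hypothetical schema, for illustration only
    resolution: str = "1024"
    seed: int = 42

cfg = load_config(DemoConfig, {"resolution": "[1024, 512]"})
print(cfg.resolution, cfg.seed)  # [1024, 512] 42
```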
docker/Dockerfile ADDED
@@ -0,0 +1,54 @@
1
+ # start from the NVIDIA CUDA 12.1 runtime image (Ubuntu 22.04)
2
+ FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
3
+
4
+ LABEL name="unique3d" maintainer="unique3d"
5
+
6
+ # create workspace folder and set it as working directory
7
+ RUN mkdir -p /workspace
8
+ WORKDIR /workspace
9
+
10
+ # update package lists and install build tools, git, git-lfs, wget, vim, unzip, libegl1-mesa-dev, and libglib2.0-0
11
+ RUN apt-get update && apt-get install -y build-essential git wget vim libegl1-mesa-dev libglib2.0-0 unzip git-lfs
12
+
13
+ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev cmake curl mesa-utils-extra
14
+ ENV PYTHONDONTWRITEBYTECODE=1
15
+ ENV PYTHONUNBUFFERED=1
16
+ ENV LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
17
+ ENV PYOPENGL_PLATFORM=egl
18
+
19
+ # install conda
20
+ RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
21
+ chmod +x Miniconda3-latest-Linux-x86_64.sh && \
22
+ ./Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 && \
23
+ rm Miniconda3-latest-Linux-x86_64.sh
24
+
25
+ # update PATH environment variable
26
+ ENV PATH="/workspace/miniconda3/bin:${PATH}"
27
+
28
+ # initialize conda
29
+ RUN conda init bash
30
+
31
+ # create and activate conda environment
32
+ RUN conda create -n unique3d python=3.10 && echo "source activate unique3d" > ~/.bashrc
33
+ ENV PATH=/workspace/miniconda3/envs/unique3d/bin:$PATH
34
+
35
+ RUN conda install Ninja
36
+ RUN conda install cuda -c nvidia/label/cuda-12.1.0 -y
37
+
38
+ RUN pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 xformers triton --index-url https://download.pytorch.org/whl/cu121
39
+ RUN pip install diffusers==0.27.2
40
+
41
+ RUN git clone --depth 1 https://huggingface.co/spaces/Wuvin/Unique3D
42
+
43
+ # change the working directory to the repository
44
+
45
+ WORKDIR /workspace/Unique3D
46
+ # other dependencies
47
+ RUN pip install -r requirements.txt
48
+
49
+ RUN pip install nvidia-pyindex
50
+
51
+ RUN pip install --upgrade nvidia-tensorrt
52
+
53
+ RUN pip install spaces
54
+
docker/README.md ADDED
@@ -0,0 +1,35 @@
1
+ # Docker setup
2
+
3
+ This Docker setup has been tested on Windows 10.
4
+
5
+ Make sure you are in the `yourworkspace/Unique3D/docker` directory.
6
+
7
+ Build the Docker image:
8
+
9
+ ```
10
+ docker build -t unique3d -f Dockerfile .
11
+ ```
12
+
13
+ Run the Docker image for the first time:
14
+
15
+ ```
16
+ docker run -it --name unique3d -p 7860:7860 --gpus all unique3d python app.py
17
+ ```
18
+
19
+ After the first time:
20
+ ```
21
+ docker start unique3d
22
+ docker exec unique3d python app.py
23
+ ```
24
+
25
+ Stop the container:
26
+ ```
27
+ docker stop unique3d
28
+ ```
29
+
30
+ You can find the demo link shown in the terminal, such as `https://94fc1ba77a08526e17.gradio.live/` or similar (it changes each time the container is restarted); open it to use the demo.
31
+
32
+ Some notes:
33
+ 1. This Docker build clones the source from https://huggingface.co/spaces/Wuvin/Unique3D rather than from this repository.
34
+ 2. The total build time may exceed one hour.
35
+ 3. The built image is larger than 70 GB.
gradio_app.py ADDED
@@ -0,0 +1,41 @@
1
+ if __name__ == "__main__":
2
+ import os
3
+ import sys
4
+ sys.path.append(os.curdir)
5
+ import torch
6
+ torch.set_float32_matmul_precision('medium')
7
+ torch.backends.cuda.matmul.allow_tf32 = True
8
+ torch.set_grad_enabled(False)
9
+
10
+ import fire
11
+ import gradio as gr
12
+ from app.gradio_3dgen import create_ui as create_3d_ui
13
+ from app.all_models import model_zoo
14
+
15
+
16
+ _TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
17
+ _DESCRIPTION = '''
18
+ [Project page](https://wukailu.github.io/Unique3D/)
19
+
20
+ * High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
21
+
22
+ * The demo is still under construction, and more features are expected to be implemented soon.
23
+ '''
24
+
25
+ def launch():
26
+ model_zoo.init_models()
27
+
28
+ with gr.Blocks(
29
+ title=_TITLE,
30
+ theme=gr.themes.Monochrome(),
31
+ ) as demo:
32
+ with gr.Row():
33
+ with gr.Column(scale=1):
34
+ gr.Markdown('# ' + _TITLE)
35
+ gr.Markdown(_DESCRIPTION)
36
+ create_3d_ui("wkl")
37
+
38
+ demo.queue().launch(share=True)
39
+
40
+ if __name__ == '__main__':
41
+ fire.Fire(launch)