Spaces:
Build error
Build error
charbel-malo
commited on
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .editorconfig +12 -0
- .gitattributes +3 -0
- .gitignore +217 -0
- Installation.md +170 -0
- LICENSE +21 -0
- README.md +132 -7
- README_jp.md +126 -0
- README_zh.md +62 -0
- app/__init__.py +0 -0
- app/all_models.py +22 -0
- app/custom_models/image2mvimage.yaml +63 -0
- app/custom_models/image2normal.yaml +61 -0
- app/custom_models/mvimg_prediction.py +57 -0
- app/custom_models/normal_prediction.py +26 -0
- app/custom_models/utils.py +75 -0
- app/examples/Groot.png +0 -0
- app/examples/aaa.png +0 -0
- app/examples/abma.png +0 -0
- app/examples/akun.png +0 -0
- app/examples/anya.png +0 -0
- app/examples/bag.png +3 -0
- app/examples/ex1.png +3 -0
- app/examples/ex2.png +0 -0
- app/examples/ex3.jpg +0 -0
- app/examples/ex4.png +0 -0
- app/examples/generated_1715761545_frame0.png +0 -0
- app/examples/generated_1715762357_frame0.png +0 -0
- app/examples/generated_1715763329_frame0.png +0 -0
- app/examples/hatsune_miku.png +0 -0
- app/examples/princess-large.png +0 -0
- app/gradio_3dgen.py +71 -0
- app/gradio_3dgen_steps.py +87 -0
- app/gradio_local.py +76 -0
- app/utils.py +112 -0
- assets/teaser.jpg +0 -0
- assets/teaser_safe.jpg +3 -0
- custum_3d_diffusion/custum_modules/attention_processors.py +385 -0
- custum_3d_diffusion/custum_modules/unifield_processor.py +460 -0
- custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py +298 -0
- custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py +296 -0
- custum_3d_diffusion/modules.py +14 -0
- custum_3d_diffusion/trainings/__init__.py +0 -0
- custum_3d_diffusion/trainings/base.py +208 -0
- custum_3d_diffusion/trainings/config_classes.py +35 -0
- custum_3d_diffusion/trainings/image2image_trainer.py +86 -0
- custum_3d_diffusion/trainings/image2mvimage_trainer.py +139 -0
- custum_3d_diffusion/trainings/utils.py +25 -0
- docker/Dockerfile +54 -0
- docker/README.md +35 -0
- gradio_app.py +41 -0
.editorconfig
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
root = true
|
2 |
+
|
3 |
+
[*.py]
|
4 |
+
charset = utf-8
|
5 |
+
trim_trailing_whitespace = true
|
6 |
+
end_of_line = lf
|
7 |
+
insert_final_newline = true
|
8 |
+
indent_style = space
|
9 |
+
indent_size = 4
|
10 |
+
|
11 |
+
[*.md]
|
12 |
+
trim_trailing_whitespace = false
|
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
app/examples/bag.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
app/examples/ex1.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/teaser_safe.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Created by https://www.toptal.com/developers/gitignore/api/python
|
2 |
+
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
3 |
+
|
4 |
+
### Python ###
|
5 |
+
# Byte-compiled / optimized / DLL files
|
6 |
+
__pycache__/
|
7 |
+
*.py[cod]
|
8 |
+
*$py.class
|
9 |
+
|
10 |
+
# C extensions
|
11 |
+
*.so
|
12 |
+
|
13 |
+
# Distribution / packaging
|
14 |
+
.Python
|
15 |
+
build/
|
16 |
+
develop-eggs/
|
17 |
+
dist/
|
18 |
+
downloads/
|
19 |
+
eggs/
|
20 |
+
.eggs/
|
21 |
+
lib/
|
22 |
+
lib64/
|
23 |
+
parts/
|
24 |
+
sdist/
|
25 |
+
var/
|
26 |
+
wheels/
|
27 |
+
share/python-wheels/
|
28 |
+
*.egg-info/
|
29 |
+
.installed.cfg
|
30 |
+
*.egg
|
31 |
+
MANIFEST
|
32 |
+
|
33 |
+
# PyInstaller
|
34 |
+
# Usually these files are written by a python script from a template
|
35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
36 |
+
*.manifest
|
37 |
+
*.spec
|
38 |
+
|
39 |
+
# Installer logs
|
40 |
+
pip-log.txt
|
41 |
+
pip-delete-this-directory.txt
|
42 |
+
|
43 |
+
# Unit test / coverage reports
|
44 |
+
htmlcov/
|
45 |
+
.tox/
|
46 |
+
.nox/
|
47 |
+
.coverage
|
48 |
+
.coverage.*
|
49 |
+
.cache
|
50 |
+
nosetests.xml
|
51 |
+
coverage.xml
|
52 |
+
*.cover
|
53 |
+
*.py,cover
|
54 |
+
.hypothesis/
|
55 |
+
.pytest_cache/
|
56 |
+
cover/
|
57 |
+
|
58 |
+
# Translations
|
59 |
+
*.mo
|
60 |
+
*.pot
|
61 |
+
|
62 |
+
# Django stuff:
|
63 |
+
*.log
|
64 |
+
local_settings.py
|
65 |
+
db.sqlite3
|
66 |
+
db.sqlite3-journal
|
67 |
+
|
68 |
+
# Flask stuff:
|
69 |
+
instance/
|
70 |
+
.webassets-cache
|
71 |
+
|
72 |
+
# Scrapy stuff:
|
73 |
+
.scrapy
|
74 |
+
|
75 |
+
# Sphinx documentation
|
76 |
+
docs/_build/
|
77 |
+
|
78 |
+
# PyBuilder
|
79 |
+
.pybuilder/
|
80 |
+
target/
|
81 |
+
|
82 |
+
# Jupyter Notebook
|
83 |
+
.ipynb_checkpoints
|
84 |
+
|
85 |
+
# IPython
|
86 |
+
profile_default/
|
87 |
+
ipython_config.py
|
88 |
+
|
89 |
+
# pyenv
|
90 |
+
# For a library or package, you might want to ignore these files since the code is
|
91 |
+
# intended to run in multiple environments; otherwise, check them in:
|
92 |
+
# .python-version
|
93 |
+
|
94 |
+
# pipenv
|
95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
98 |
+
# install all needed dependencies.
|
99 |
+
#Pipfile.lock
|
100 |
+
|
101 |
+
# poetry
|
102 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
103 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
104 |
+
# commonly ignored for libraries.
|
105 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
106 |
+
#poetry.lock
|
107 |
+
|
108 |
+
# pdm
|
109 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
110 |
+
#pdm.lock
|
111 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
112 |
+
# in version control.
|
113 |
+
# https://pdm.fming.dev/#use-with-ide
|
114 |
+
.pdm.toml
|
115 |
+
|
116 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
117 |
+
__pypackages__/
|
118 |
+
|
119 |
+
# Celery stuff
|
120 |
+
celerybeat-schedule
|
121 |
+
celerybeat.pid
|
122 |
+
|
123 |
+
# SageMath parsed files
|
124 |
+
*.sage.py
|
125 |
+
|
126 |
+
# Environments
|
127 |
+
.env
|
128 |
+
.venv
|
129 |
+
env/
|
130 |
+
venv/
|
131 |
+
ENV/
|
132 |
+
env.bak/
|
133 |
+
venv.bak/
|
134 |
+
|
135 |
+
# Spyder project settings
|
136 |
+
.spyderproject
|
137 |
+
.spyproject
|
138 |
+
|
139 |
+
# Rope project settings
|
140 |
+
.ropeproject
|
141 |
+
|
142 |
+
# mkdocs documentation
|
143 |
+
/site
|
144 |
+
|
145 |
+
# mypy
|
146 |
+
.mypy_cache/
|
147 |
+
.dmypy.json
|
148 |
+
dmypy.json
|
149 |
+
|
150 |
+
# Pyre type checker
|
151 |
+
.pyre/
|
152 |
+
|
153 |
+
# pytype static type analyzer
|
154 |
+
.pytype/
|
155 |
+
|
156 |
+
# Cython debug symbols
|
157 |
+
cython_debug/
|
158 |
+
|
159 |
+
# PyCharm
|
160 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
161 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
162 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
163 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
164 |
+
.idea/
|
165 |
+
|
166 |
+
### Python Patch ###
|
167 |
+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
168 |
+
poetry.toml
|
169 |
+
|
170 |
+
# ruff
|
171 |
+
.ruff_cache/
|
172 |
+
|
173 |
+
# LSP config files
|
174 |
+
pyrightconfig.json
|
175 |
+
|
176 |
+
# End of https://www.toptal.com/developers/gitignore/api/python
|
177 |
+
|
178 |
+
.vscode/
|
179 |
+
.threestudio_cache/
|
180 |
+
outputs
|
181 |
+
outputs/
|
182 |
+
outputs-gradio
|
183 |
+
outputs-gradio/
|
184 |
+
lightning_logs/
|
185 |
+
|
186 |
+
# pretrained model weights
|
187 |
+
*.ckpt
|
188 |
+
*.pt
|
189 |
+
*.pth
|
190 |
+
*.bin
|
191 |
+
*.param
|
192 |
+
|
193 |
+
# wandb
|
194 |
+
wandb/
|
195 |
+
|
196 |
+
# obj results
|
197 |
+
*.obj
|
198 |
+
*.glb
|
199 |
+
*.ply
|
200 |
+
|
201 |
+
# ckpts
|
202 |
+
ckpt/*
|
203 |
+
*.pth
|
204 |
+
*.pt
|
205 |
+
|
206 |
+
# tensorrt
|
207 |
+
*.engine
|
208 |
+
*.profile
|
209 |
+
|
210 |
+
# zipfiles
|
211 |
+
*.zip
|
212 |
+
*.tar
|
213 |
+
*.tar.gz
|
214 |
+
|
215 |
+
# others
|
216 |
+
run_30.sh
|
217 |
+
ckpt
|
Installation.md
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 官方安装指南
|
2 |
+
|
3 |
+
* 在 requirements-detail.txt 里,我们提供了详细的各个库的版本,这个对应的环境是 `python3.10 + cuda12.2`。
|
4 |
+
* 本项目依赖于几个重要的pypi包,这几个包安装起来会有一些困难。
|
5 |
+
|
6 |
+
### nvdiffrast 安装
|
7 |
+
|
8 |
+
* nvdiffrast 会在第一次运行时,编译对应的torch插件,这一步需要 ninja 及 cudatoolkit的支持。
|
9 |
+
* 因此需要先确保正确安装了 ninja 以及 cudatoolkit 并正确配置了 CUDA_HOME 环境变量。
|
10 |
+
* cudatoolkit 安装可以参考 [linux-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html), [windows-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html)
|
11 |
+
* ninja 则可用直接 `pip install ninja`
|
12 |
+
* 然后设置 CUDA_HOME 变量为 cudatoolkit 的安装目录,如 `/usr/local/cuda`。
|
13 |
+
* 最后 `pip install nvdiffrast` 即可。
|
14 |
+
* 如果无法在目标服务器上安装 cudatoolkit (如权限不够),可用使用我修改的[预编译版本 nvdiffrast](https://github.com/wukailu/nvdiffrast-torch) 在另一台拥有 cudatoolkit 且环境相似(python, torch, cuda版本相同)的服务器上预编译后安装。
|
15 |
+
|
16 |
+
### onnxruntime-gpu 安装
|
17 |
+
|
18 |
+
* 注意,同时安装 `onnxruntime` 与 `onnxruntime-gpu` 可能导致最终程序无法运行在GPU,而运行在CPU,导致极慢的推理速度。
|
19 |
+
* [onnxruntime 官方安装指南](https://onnxruntime.ai/docs/install/#python-installs)
|
20 |
+
* TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
|
21 |
+
`.
|
22 |
+
* 进一步的,可用安装基于 tensorrt 的 onnxruntime,进一步加快推理速度。
|
23 |
+
* 注意:如果没有安装基于 tensorrt 的 onnxruntime,建议将 `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` 中 `TensorrtExecutionProvider` 删除。
|
24 |
+
* 对于 cuda12.x 可用使用如下命令快速安装带有tensorrt的onnxruntime (注意将 `/root/miniconda3/lib/python3.10/site-packages` 修改为你的python 对应路径,将 `/root/.bashrc` 改为你的用户下路径 `.bashrc` 路劲)
|
25 |
+
```
|
26 |
+
pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
|
27 |
+
pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
|
28 |
+
pip install tensorrt==8.6.0
|
29 |
+
echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
|
30 |
+
```
|
31 |
+
|
32 |
+
### pytorch3d 安装
|
33 |
+
|
34 |
+
* 根据 [pytorch3d 官方的安装建议](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux),建议使用预编译版本
|
35 |
+
```
|
36 |
+
import sys
|
37 |
+
import torch
|
38 |
+
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
|
39 |
+
version_str="".join([
|
40 |
+
f"py3{sys.version_info.minor}_cu",
|
41 |
+
torch.version.cuda.replace(".",""),
|
42 |
+
f"_pyt{pyt_version_str}"
|
43 |
+
])
|
44 |
+
!pip install fvcore iopath
|
45 |
+
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
|
46 |
+
```
|
47 |
+
|
48 |
+
### torch_scatter 安装
|
49 |
+
|
50 |
+
* 在[torch_scatter 官方安装指南](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation) 使用预编译的安装包快速安装。
|
51 |
+
* 或者直接编译安装 `pip install git+https://github.com/rusty1s/pytorch_scatter.git`
|
52 |
+
|
53 |
+
### 其他安装
|
54 |
+
|
55 |
+
* 其他文件 `pip install -r requirements.txt` 即可。
|
56 |
+
|
57 |
+
-----
|
58 |
+
|
59 |
+
# Detailed Installation Guide
|
60 |
+
|
61 |
+
* In `requirements-detail.txt`, we provide detailed versions of all packages, which correspond to the environment of `python3.10 + cuda12.2`.
|
62 |
+
* This project relies on several important PyPI packages, which may be difficult to install.
|
63 |
+
|
64 |
+
### Installation of nvdiffrast
|
65 |
+
|
66 |
+
* nvdiffrast will compile the corresponding torch plugin the first time it runs, which requires support from ninja and cudatoolkit.
|
67 |
+
* Therefore, it is necessary to ensure that ninja and cudatoolkit are correctly installed and that the CUDA_HOME environment variable is properly configured.
|
68 |
+
* For the installation of cudatoolkit, you can refer to the [Linux CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) and [Windows CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html).
|
69 |
+
* Ninja can be directly installed with `pip install ninja`.
|
70 |
+
* Then set the CUDA_HOME variable to the installation directory of cudatoolkit, such as `/usr/local/cuda`.
|
71 |
+
* Finally, `pip install nvdiffrast`.
|
72 |
+
* If you cannot install cudatoolkit on the computer (e.g., insufficient permissions), you can use my modified [pre-compiled version of nvdiffrast](https://github.com/wukailu/nvdiffrast-torch) to pre-compile on another computer that has cudatoolkit and a similar environment (same versions of python, torch, cuda) and then install the `.whl`.
|
73 |
+
|
74 |
+
### Installation of onnxruntime-gpu
|
75 |
+
|
76 |
+
* Note that installing both `onnxruntime` and `onnxruntime-gpu` may result in not running on the GPU but on the CPU, leading to extremely slow inference speed.
|
77 |
+
* [Official ONNX Runtime Installation Guide](https://onnxruntime.ai/docs/install/#python-installs)
|
78 |
+
* TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`.
|
79 |
+
* Furthermore, you can install onnxruntime based on tensorrt to further increase the inference speed.
|
80 |
+
* Note: If you do not correctly installed onnxruntime based on tensorrt, it is recommended to remove `TensorrtExecutionProvider` from `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4`.
|
81 |
+
* For cuda12.x, you can quickly install onnxruntime with tensorrt using the following commands (note to change the path `/root/miniconda3/lib/python3.10/site-packages` to the corresponding path of your python, and change `/root/.bashrc` to the path of `.bashrc` under your user directory):
|
82 |
+
```
|
83 |
+
pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
|
84 |
+
pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
|
85 |
+
pip install tensorrt==8.6.0
|
86 |
+
echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
|
87 |
+
```
|
88 |
+
|
89 |
+
### Installation of pytorch3d
|
90 |
+
|
91 |
+
* According to the [official installation recommendations of pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux), it is recommended to use the pre-compiled version:
|
92 |
+
```
|
93 |
+
import sys
|
94 |
+
import torch
|
95 |
+
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
|
96 |
+
version_str="".join([
|
97 |
+
f"py3{sys.version_info.minor}_cu",
|
98 |
+
torch.version.cuda.replace(".",""),
|
99 |
+
f"_pyt{pyt_version_str}"
|
100 |
+
])
|
101 |
+
!pip install fvcore iopath
|
102 |
+
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
|
103 |
+
```
|
104 |
+
|
105 |
+
### Installation of torch_scatter
|
106 |
+
|
107 |
+
* Use the pre-compiled installation package according to the [official installation guide of torch_scatter](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation) for a quick installation.
|
108 |
+
* Alternatively, you can directly compile and install with `pip install git+https://github.com/rusty1s/pytorch_scatter.git`.
|
109 |
+
|
110 |
+
### Other Installations
|
111 |
+
|
112 |
+
* For other packages, simply `pip install -r requirements.txt`.
|
113 |
+
|
114 |
+
-----
|
115 |
+
|
116 |
+
# 官方インストールガイド
|
117 |
+
|
118 |
+
* `requirements-detail.txt` には、各ライブラリのバージョンが詳細に提供されており、これは Python 3.10 + CUDA 12.2 に対応する環境です。
|
119 |
+
* このプロジェクトは、いくつかの重要な PyPI パッケージに依存しており、これらのパッケージのインストールにはいくつかの困難が伴います。
|
120 |
+
|
121 |
+
### nvdiffrast のインストール
|
122 |
+
|
123 |
+
* nvdiffrast は、最初に実行するときに、torch プラグインの対応バージョンをコンパイルします。このステップには、ninja および cudatoolkit のサポートが必要です。
|
124 |
+
* したがって、ninja および cudatoolkit の正確なインストールと、CUDA_HOME 環境変数の正確な設定を確保する必要があります。
|
125 |
+
* cudatoolkit のインストールについては、[Linux CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)、[Windows CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) を参照してください。
|
126 |
+
* ninja は、直接 `pip install ninja` でインストールできます。
|
127 |
+
* 次に、CUDA_HOME 変数を cudatoolkit のインストールディレクトリに設定します。例えば、`/usr/local/cuda` のように。
|
128 |
+
* 最後に、`pip install nvdiffrast` を実行します。
|
129 |
+
* 目標サーバーで cudatoolkit をインストールできない場合(例えば、権限が不足している場合)、私の修正した[事前コンパイル済みバージョンの nvdiffrast](https://github.com/wukailu/nvdiffrast-torch)を使用できます。これは、cudatoolkit があり、環境が似ている(Python、torch、cudaのバージョンが同じ)別のサーバーで事前コンパイルしてからインストールすることができます。
|
130 |
+
|
131 |
+
### onnxruntime-gpu のインストール
|
132 |
+
|
133 |
+
* 注意:`onnxruntime` と `onnxruntime-gpu` を同時にインストールすると、最終的なプログラムが GPU 上で実行されず、CPU 上で実行される可能性があり、推論速度���非常に遅くなることがあります。
|
134 |
+
* [onnxruntime 公式インストールガイド](https://onnxruntime.ai/docs/install/#python-installs)
|
135 |
+
* TLDR: cuda11.x 用には、`pip install onnxruntime-gpu` を使用します。cuda12.x 用には、`pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/` を使用します。
|
136 |
+
* さらに、TensorRT ベースの onnxruntime をインストールして、推論速度をさらに向上させることができます。
|
137 |
+
* 注意:TensorRT ベースの onnxruntime がインストールされていない場合は、`https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` の `TensorrtExecutionProvider` を削除することをお勧めします。
|
138 |
+
* cuda12.x の場合、次のコマンドを使用して迅速に TensorRT を備えた onnxruntime をインストールできます(`/root/miniconda3/lib/python3.10/site-packages` をあなたの Python に対応するパスに、`/root/.bashrc` をあなたのユーザーのパスの下の `.bashrc` に変更してください)。
|
139 |
+
```bash
|
140 |
+
pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/
|
141 |
+
pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/
|
142 |
+
pip install tensorrt==8.6.0
|
143 |
+
echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc
|
144 |
+
```
|
145 |
+
|
146 |
+
### pytorch3d のインストール
|
147 |
+
|
148 |
+
* [pytorch3d 公式のインストール提案](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux)に従い、事前コンパイル済みバージョンを使用することをお勧めします。
|
149 |
+
```python
|
150 |
+
import sys
|
151 |
+
import torch
|
152 |
+
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
|
153 |
+
version_str="".join([
|
154 |
+
f"py3{sys.version_info.minor}_cu",
|
155 |
+
torch.version.cuda.replace(".",""),
|
156 |
+
f"_pyt{pyt_version_str}"
|
157 |
+
])
|
158 |
+
!pip install fvcore iopath
|
159 |
+
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
|
160 |
+
```
|
161 |
+
|
162 |
+
### torch_scatter のインストール
|
163 |
+
|
164 |
+
* [torch_scatter 公式インストールガイド](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation)に従い、事前コンパイル済みのインストールパッケージを使用して迅速インストールします。
|
165 |
+
* または、直接コンパイルしてインストールする `pip install git+https://github.com/rusty1s/pytorch_scatter.git` も可能です。
|
166 |
+
|
167 |
+
### その他のインストール
|
168 |
+
|
169 |
+
* その他のファイルについては、`pip install -r requirements.txt` を実行するだけです。
|
170 |
+
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 AiuniAI
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,137 @@
|
|
1 |
---
|
2 |
-
title: 3D
|
3 |
-
|
4 |
-
colorFrom: indigo
|
5 |
-
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.5.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: 3D-Genesis
|
3 |
+
app_file: gradio_app.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 5.5.0
|
|
|
|
|
6 |
---
|
7 |
+
**[中文版本](README_zh.md)**
|
8 |
|
9 |
+
**[日本語版](README_jp.md)**
|
10 |
+
|
11 |
+
# Unique3D
|
12 |
+
Official implementation of Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image.
|
13 |
+
|
14 |
+
[Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/)
|
15 |
+
|
16 |
+
## [Paper](https://arxiv.org/abs/2405.20343) | [Project page](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradio Demo](http://unique3d.demo.avar.cn/) | [Online Demo](https://www.aiuni.ai/)
|
17 |
+
|
18 |
+
* Demo inference speed: Gradio Demo > Huggingface Demo > Huggingface Demo2 > Online Demo
|
19 |
+
|
20 |
+
**If the Gradio Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (get the registration invitation code Join Discord: https://discord.gg/aiuni). However, the Online Demo is slightly different from the Gradio Demo, in that the inference speed is slower, but the generation is much more stable.**
|
21 |
+
|
22 |
+
<p align="center">
|
23 |
+
<img src="assets/teaser_safe.jpg">
|
24 |
+
</p>
|
25 |
+
|
26 |
+
High-fidelity and diverse textured meshes generated by Unique3D from single-view wild images in 30 seconds.
|
27 |
+
|
28 |
+
## More features
|
29 |
+
|
30 |
+
The repo is still being under construction, thanks for your patience.
|
31 |
+
- [x] Upload weights.
|
32 |
+
- [x] Local gradio demo.
|
33 |
+
- [x] Detailed tutorial.
|
34 |
+
- [x] Huggingface demo.
|
35 |
+
- [ ] Detailed local demo.
|
36 |
+
- [x] Comfyui support.
|
37 |
+
- [x] Windows support.
|
38 |
+
- [x] Docker support.
|
39 |
+
- [ ] More stable reconstruction with normal.
|
40 |
+
- [ ] Training code release.
|
41 |
+
|
42 |
+
## Preparation for inference
|
43 |
+
|
44 |
+
* [Detailed linux installation guide](Installation.md).
|
45 |
+
|
46 |
+
### Linux System Setup.
|
47 |
+
|
48 |
+
Adapted for Ubuntu 22.04.4 LTS and CUDA 12.1.
|
49 |
+
```angular2html
|
50 |
+
conda create -n unique3d python=3.11
|
51 |
+
conda activate unique3d
|
52 |
+
|
53 |
+
pip install ninja
|
54 |
+
pip install diffusers==0.27.2
|
55 |
+
|
56 |
+
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html
|
57 |
+
|
58 |
+
pip install -r requirements.txt
|
59 |
+
```
|
60 |
+
|
61 |
+
[oak-barry](https://github.com/oak-barry) provide another setup script for torch210+cu121 at [here](https://github.com/oak-barry/Unique3D).
|
62 |
+
|
63 |
+
### Windows Setup.
|
64 |
+
|
65 |
+
* Thank you very much `jtydhr88` for the windows installation method! See [issues/15](https://github.com/AiuniAI/Unique3D/issues/15).
|
66 |
+
|
67 |
+
According to [issues/15](https://github.com/AiuniAI/Unique3D/issues/15), implemented a bat script to run the commands, so you can:
|
68 |
+
1. Might still require Visual Studio Build Tools, you can find it from [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools).
|
69 |
+
2. Create conda env and activate it
|
70 |
+
1. `conda create -n unique3d-py311 python=3.11`
|
71 |
+
2. `conda activate unique3d-py311`
|
72 |
+
3. download [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl) for py311, and put it into this project.
|
73 |
+
4. run **install_windows_win_py311_cu121.bat**
|
74 |
+
5. answer y while asking you uninstall onnxruntime and onnxruntime-gpu
|
75 |
+
6. create the output folder **tmp\gradio** under the driver root, such as F:\tmp\gradio for me.
|
76 |
+
7. python app/gradio_local.py --port 7860
|
77 |
+
|
78 |
+
More details prefer to [issues/15](https://github.com/AiuniAI/Unique3D/issues/15).
|
79 |
+
|
80 |
+
### Interactive inference: run your local gradio demo.
|
81 |
+
|
82 |
+
1. Download the weights from [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) or [Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/), and extract it to `ckpt/*`.
|
83 |
+
```
|
84 |
+
Unique3D
|
85 |
+
├──ckpt
|
86 |
+
├── controlnet-tile/
|
87 |
+
├── image2normal/
|
88 |
+
├── img2mvimg/
|
89 |
+
├── realesrgan-x4.onnx
|
90 |
+
└── v1-inference.yaml
|
91 |
+
```
|
92 |
+
|
93 |
+
2. Run the interactive inference locally.
|
94 |
+
```bash
|
95 |
+
python app/gradio_local.py --port 7860
|
96 |
+
```
|
97 |
+
|
98 |
+
## ComfyUI Support
|
99 |
+
|
100 |
+
Thanks for the [ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D) implementation from [jtydhr88](https://github.com/jtydhr88)!
|
101 |
+
|
102 |
+
## Tips to get better results
|
103 |
+
|
104 |
+
**Important: Because the mesh is normalized by the longest edge of xyz during training, it is desirable that the input image needs to contain the longest edge of the object during inference, or else you may get erroneously squashed results.**
|
105 |
+
1. Unique3D is sensitive to the facing direction of input images. Due to the distribution of the training data, orthographic front-facing images with a rest pose always lead to good reconstructions.
|
106 |
+
2. Images with occlusions will cause worse reconstructions, since four views cannot cover the complete object. Images with fewer occlusions lead to better results.
|
107 |
+
3. Pass an image with as high a resolution as possible to the input when resolution is a factor.
|
108 |
+
|
109 |
+
## Acknowledgement
|
110 |
+
|
111 |
+
We have intensively borrowed code from the following repositories. Many thanks to the authors for sharing their code.
|
112 |
+
- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
|
113 |
+
- [Wonder3d](https://github.com/xxlong0/Wonder3D)
|
114 |
+
- [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
|
115 |
+
- [Continues Remeshing](https://github.com/Profactor/continuous-remeshing)
|
116 |
+
- [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
|
117 |
+
|
118 |
+
## Collaborations
|
119 |
+
Our mission is to create a 4D generative model with 3D concepts. This is just our first step, and the road ahead is still long, but we are confident. We warmly invite you to join the discussion and explore potential collaborations in any capacity. <span style="color:red">**If you're interested in connecting or partnering with us, please don't hesitate to reach out via email (wkl22@mails.tsinghua.edu.cn)**</span>.
|
120 |
+
|
121 |
+
- Follow us on twitter for the latest updates: https://x.com/aiuni_ai
|
122 |
+
- Join AIGC 3D/4D generation community on discord: https://discord.gg/aiuni
|
123 |
+
- Research collaboration, please contact: ai@aiuni.ai
|
124 |
+
|
125 |
+
## Citation
|
126 |
+
|
127 |
+
If you found Unique3D helpful, please cite our report:
|
128 |
+
```bibtex
|
129 |
+
@misc{wu2024unique3d,
|
130 |
+
title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image},
|
131 |
+
author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma},
|
132 |
+
year={2024},
|
133 |
+
eprint={2405.20343},
|
134 |
+
archivePrefix={arXiv},
|
135 |
+
primaryClass={cs.CV}
|
136 |
+
}
|
137 |
+
```
|
README_jp.md
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**他の言語のバージョン [英語](README.md) [中国語](README_zh.md)**
|
2 |
+
|
3 |
+
# Unique3D
|
4 |
+
Unique3D: 単一画像からの高品質かつ効率的な3Dメッシュ生成の公式実装。
|
5 |
+
|
6 |
+
[Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/)
|
7 |
+
|
8 |
+
## [論文](https://arxiv.org/abs/2405.20343) | [プロジェクトページ](https://wukailu.github.io/Unique3D/) | [Huggingfaceデモ](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradioデモ](http://unique3d.demo.avar.cn/) | [オンラインデモ](https://www.aiuni.ai/)
|
9 |
+
|
10 |
+
* デモ推論速度: Gradioデモ > Huggingfaceデモ > Huggingfaceデモ2 > オンラインデモ
|
11 |
+
|
12 |
+
**Gradioデモが残念ながらハングアップしたり、非常に混雑している場合は、[aiuni.ai](https://www.aiuni.ai/)のオンラインデモを使用できます。これは無料で試すことができます(登録招待コードを取得するには、Discordに参加してください: https://discord.gg/aiuni)。ただし、オンラインデモはGradioデモとは少し異なり、推論速度が遅く、生成結果が安定していない可能性がありますが、素材の品質は良いです。**
|
13 |
+
|
14 |
+
<p align="center">
|
15 |
+
<img src="assets/teaser_safe.jpg">
|
16 |
+
</p>
|
17 |
+
|
18 |
+
Unique3Dは、野生の単一画像から高忠実度および多様なテクスチャメッシュを30秒で生成します。
|
19 |
+
|
20 |
+
## より多くの機能
|
21 |
+
|
22 |
+
リポジトリはまだ構築中です。ご理解いただきありがとうございます。
|
23 |
+
- [x] 重みのアップロード。
|
24 |
+
- [x] ローカルGradioデモ。
|
25 |
+
- [ ] 詳細なチュートリアル。
|
26 |
+
- [x] Huggingfaceデモ。
|
27 |
+
- [ ] 詳細なローカルデモ。
|
28 |
+
- [x] Comfyuiサポート。
|
29 |
+
- [x] Windowsサポート。
|
30 |
+
- [ ] Dockerサポート。
|
31 |
+
- [ ] ノーマルでより安定した再構築。
|
32 |
+
- [ ] トレーニングコードのリリース。
|
33 |
+
|
34 |
+
## 推論の準備
|
35 |
+
|
36 |
+
### Linuxシステムセットアップ
|
37 |
+
|
38 |
+
Ubuntu 22.04.4 LTSおよびCUDA 12.1に適応。
|
39 |
+
```angular2html
|
40 |
+
conda create -n unique3d python=3.11
|
41 |
+
conda activate unique3d
|
42 |
+
|
43 |
+
pip install ninja
|
44 |
+
pip install diffusers==0.27.2
|
45 |
+
|
46 |
+
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html
|
47 |
+
|
48 |
+
pip install -r requirements.txt
|
49 |
+
```
|
50 |
+
|
51 |
+
[oak-barry](https://github.com/oak-barry)は、[こちら](https://github.com/oak-barry/Unique3D)でtorch210+cu121の別のセットアップスクリプトを提供しています。
|
52 |
+
|
53 |
+
### Windowsセットアップ
|
54 |
+
|
55 |
+
* `jtydhr88`によるWindowsインストール方法に非常に感謝します![issues/15](https://github.com/AiuniAI/Unique3D/issues/15)を参照してください。
|
56 |
+
|
57 |
+
[issues/15](https://github.com/AiuniAI/Unique3D/issues/15)によると、コマンドを実行するバッチスクリプトを実装したので、以下の手順に従ってください。
|
58 |
+
1. [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools)からVisual Studio Build Toolsが必要になる場合があります。
|
59 |
+
2. conda envを作成し、アクティブにします。
|
60 |
+
1. `conda create -n unique3d-py311 python=3.11`
|
61 |
+
2. `conda activate unique3d-py311`
|
62 |
+
3. [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl)をダウンロードし、このプロジェクトに配置します。
|
63 |
+
4. **install_windows_win_py311_cu121.bat**を実行します。
|
64 |
+
5. onnxruntimeおよびonnxruntime-gpuのアンインストールを求められた場合は、yと回答します。
|
65 |
+
6. ドライバールートの下に**tmp\gradio**フォルダを作成します(例:F:\tmp\gradio)。
|
66 |
+
7. python app/gradio_local.py --port 7860
|
67 |
+
|
68 |
+
詳細は[issues/15](https://github.com/AiuniAI/Unique3D/issues/15)を参照してください。
|
69 |
+
|
70 |
+
### インタラクティブ推論:ローカルGradioデモを実行する
|
71 |
+
|
72 |
+
1. [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt)または[Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)から重みをダウンロードし、`ckpt/*`に抽出します。
|
73 |
+
```
|
74 |
+
Unique3D
|
75 |
+
├──ckpt
|
76 |
+
├── controlnet-tile/
|
77 |
+
├── image2normal/
|
78 |
+
├── img2mvimg/
|
79 |
+
├── realesrgan-x4.onnx
|
80 |
+
└── v1-inference.yaml
|
81 |
+
```
|
82 |
+
|
83 |
+
2. インタラクティブ推論をローカルで実行します。
|
84 |
+
```bash
|
85 |
+
python app/gradio_local.py --port 7860
|
86 |
+
```
|
87 |
+
|
88 |
+
## ComfyUIサポート
|
89 |
+
|
90 |
+
[jtydhr88](https://github.com/jtydhr88)からの[ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D)の実装に感謝します!
|
91 |
+
|
92 |
+
## より良い結果を得るためのヒント
|
93 |
+
|
94 |
+
1. Unique3Dは入力画像の向きに敏感です。トレーニングデータの分布により、正面を向いた直交画像は常に良い再構築につながります。
|
95 |
+
2. 遮���のある画像は、4つのビューがオブジェクトを完全にカバーできないため、再構築が悪化します。遮蔽の少ない画像は、より良い結果につながります。
|
96 |
+
3. 可能な限り高解像度の画像を入力として使用してください。
|
97 |
+
|
98 |
+
## 謝辞
|
99 |
+
|
100 |
+
以下のリポジトリからコードを大量に借用しました。コードを共有してくれた著者に感謝します。
|
101 |
+
- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
|
102 |
+
- [Wonder3d](https://github.com/xxlong0/Wonder3D)
|
103 |
+
- [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
|
104 |
+
- [Continues Remeshing](https://github.com/Profactor/continuous-remeshing)
|
105 |
+
- [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
|
106 |
+
|
107 |
+
## コラボレーション
|
108 |
+
私たちの使命は、3Dの概念を持つ4D生成モデルを作成することです。これは私たちの最初のステップであり、前途はまだ長いですが、私たちは自信を持っています。あらゆる形態の潜在的なコラボレーションを探求し、議論に参加することを心から歓迎します。<span style="color:red">**私たちと連絡を取りたい、またはパートナーシップを結びたい方は、メールでお気軽にお問い合わせください (wkl22@mails.tsinghua.edu.cn)**</span>。
|
109 |
+
|
110 |
+
- 最新情報を入手するには、Twitterをフォローしてください: https://x.com/aiuni_ai
|
111 |
+
- DiscordでAIGC 3D/4D生成コミュニティに参加してください: https://discord.gg/aiuni
|
112 |
+
- 研究協力については、ai@aiuni.aiまでご連絡ください。
|
113 |
+
|
114 |
+
## 引用
|
115 |
+
|
116 |
+
Unique3Dが役立つと思われる場合は、私たちのレポートを引用してください:
|
117 |
+
```bibtex
|
118 |
+
@misc{wu2024unique3d,
|
119 |
+
title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image},
|
120 |
+
author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma},
|
121 |
+
year={2024},
|
122 |
+
eprint={2405.20343},
|
123 |
+
archivePrefix={arXiv},
|
124 |
+
primaryClass={cs.CV}
|
125 |
+
}
|
126 |
+
```
|
README_zh.md
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**其他语言版本 [English](README.md)**
|
2 |
+
|
3 |
+
# Unique3D
|
4 |
+
High-Quality and Efficient 3D Mesh Generation from a Single Image
|
5 |
+
|
6 |
+
[Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/)
|
7 |
+
|
8 |
+
## [论文](https://arxiv.org/abs/2405.20343) | [项目页面](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradio Demo](http://unique3d.demo.avar.cn/) | [在线演示](https://www.aiuni.ai/)
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
<p align="center">
|
13 |
+
<img src="assets/teaser_safe.jpg">
|
14 |
+
</p>
|
15 |
+
|
16 |
+
Unique3D从单视图图像生成高保真度和多样化纹理的网格,在4090上大约需要30秒。
|
17 |
+
|
18 |
+
### 推理准备
|
19 |
+
|
20 |
+
#### Linux系统设置
|
21 |
+
```angular2html
|
22 |
+
conda create -n unique3d
|
23 |
+
conda activate unique3d
|
24 |
+
pip install -r requirements.txt
|
25 |
+
```
|
26 |
+
|
27 |
+
#### 交互式推理:运行您的本地gradio演示
|
28 |
+
|
29 |
+
1. 从 [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) 下载或者从[清华云盘](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)下载权重,并将其解压到`ckpt/*`。
|
30 |
+
```
|
31 |
+
Unique3D
|
32 |
+
├──ckpt
|
33 |
+
├── controlnet-tile/
|
34 |
+
├── image2normal/
|
35 |
+
├── img2mvimg/
|
36 |
+
├── realesrgan-x4.onnx
|
37 |
+
└── v1-inference.yaml
|
38 |
+
```
|
39 |
+
|
40 |
+
2. 在本地运行交互式推理。
|
41 |
+
```bash
|
42 |
+
python app/gradio_local.py --port 7860
|
43 |
+
```
|
44 |
+
|
45 |
+
## 获取更好结果的提示
|
46 |
+
|
47 |
+
1. Unique3D对输入图像的朝向非常敏感。由于训练数据的分布,**正交正视图像**通常总是能带来良好的重建。对于人物而言,最好是 A-pose 或者 T-pose,因为目前训练数据很少含有其他类型姿态。
|
48 |
+
2. 有遮挡的图像会导致更差的重建,因为4个视图无法覆盖完整的对象。遮挡较少的图像会带来更好的结果。
|
49 |
+
3. 尽可能将高分辨率的图像用作输入。
|
50 |
+
|
51 |
+
## 致谢
|
52 |
+
|
53 |
+
我们借用了以下代码库的代码。非常感谢作者们分享他们的代码。
|
54 |
+
- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
|
55 |
+
- [Wonder3d](https://github.com/xxlong0/Wonder3D)
|
56 |
+
- [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus)
|
57 |
+
- [Continues Remeshing](https://github.com/Profactor/continuous-remeshing)
|
58 |
+
- [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals)
|
59 |
+
|
60 |
+
## 合作
|
61 |
+
|
62 |
+
我们使命是创建一个具有3D概念的4D生成模型。这只是我们的第一步,前方的道路仍然很长,但我们有信心。我们热情邀请您加入讨论,并探索任何形式的潜在合作。<span style="color:red">**如果您有兴趣联系或与我们合作,欢迎通过电子邮件(wkl22@mails.tsinghua.edu.cn)与我们联系**</span>。
|
app/__init__.py
ADDED
File without changes
|
app/all_models.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from scripts.sd_model_zoo import load_common_sd15_pipe
|
3 |
+
from diffusers import StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
|
4 |
+
|
5 |
+
|
6 |
+
class MyModelZoo:
|
7 |
+
_pipe_disney_controlnet_lineart_ipadapter_i2i: StableDiffusionControlNetImg2ImgPipeline = None
|
8 |
+
|
9 |
+
base_model = "runwayml/stable-diffusion-v1-5"
|
10 |
+
|
11 |
+
def __init__(self, base_model=None) -> None:
|
12 |
+
if base_model is not None:
|
13 |
+
self.base_model = base_model
|
14 |
+
|
15 |
+
@property
|
16 |
+
def pipe_disney_controlnet_tile_ipadapter_i2i(self):
|
17 |
+
return self._pipe_disney_controlnet_lineart_ipadapter_i2i
|
18 |
+
|
19 |
+
def init_models(self):
|
20 |
+
self._pipe_disney_controlnet_lineart_ipadapter_i2i = load_common_sd15_pipe(base_model=self.base_model, ip_adapter=True, plus_model=False, controlnet="./ckpt/controlnet-tile", pipeline_class=StableDiffusionControlNetImg2ImgPipeline)
|
21 |
+
|
22 |
+
model_zoo = MyModelZoo()
|
app/custom_models/image2mvimage.yaml
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pretrained_model_name_or_path: "./ckpt/img2mvimg"
|
2 |
+
mixed_precision: "bf16"
|
3 |
+
|
4 |
+
init_config:
|
5 |
+
# enable controls
|
6 |
+
enable_cross_attn_lora: False
|
7 |
+
enable_cross_attn_ip: False
|
8 |
+
enable_self_attn_lora: False
|
9 |
+
enable_self_attn_ref: False
|
10 |
+
enable_multiview_attn: True
|
11 |
+
|
12 |
+
# for cross attention
|
13 |
+
init_cross_attn_lora: False
|
14 |
+
init_cross_attn_ip: False
|
15 |
+
cross_attn_lora_rank: 256 # 0 for not enabled
|
16 |
+
cross_attn_lora_only_kv: False
|
17 |
+
ipadapter_pretrained_name: "h94/IP-Adapter"
|
18 |
+
ipadapter_subfolder_name: "models"
|
19 |
+
ipadapter_weight_name: "ip-adapter_sd15.safetensors"
|
20 |
+
ipadapter_effect_on: "all" # all, first
|
21 |
+
|
22 |
+
# for self attention
|
23 |
+
init_self_attn_lora: False
|
24 |
+
self_attn_lora_rank: 256
|
25 |
+
self_attn_lora_only_kv: False
|
26 |
+
|
27 |
+
# for self attention ref
|
28 |
+
init_self_attn_ref: False
|
29 |
+
self_attn_ref_position: "attn1"
|
30 |
+
self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
|
31 |
+
self_attn_ref_pixel_wise_crosspond: False
|
32 |
+
self_attn_ref_effect_on: "all"
|
33 |
+
|
34 |
+
# for multiview attention
|
35 |
+
init_multiview_attn: True
|
36 |
+
multiview_attn_position: "attn1"
|
37 |
+
use_mv_joint_attn: True
|
38 |
+
num_modalities: 1
|
39 |
+
|
40 |
+
# for unet
|
41 |
+
init_unet_path: "${pretrained_model_name_or_path}"
|
42 |
+
cat_condition: True # cat condition to input
|
43 |
+
|
44 |
+
# for cls embedding
|
45 |
+
init_num_cls_label: 8 # for initialize
|
46 |
+
cls_labels: [0, 1, 2, 3] # for current task
|
47 |
+
|
48 |
+
trainers:
|
49 |
+
- trainer_type: "image2mvimage_trainer"
|
50 |
+
trainer:
|
51 |
+
pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
|
52 |
+
attn_config:
|
53 |
+
cls_labels: [0, 1, 2, 3] # for current task
|
54 |
+
enable_cross_attn_lora: False
|
55 |
+
enable_cross_attn_ip: False
|
56 |
+
enable_self_attn_lora: False
|
57 |
+
enable_self_attn_ref: False
|
58 |
+
enable_multiview_attn: True
|
59 |
+
resolution: "256"
|
60 |
+
condition_image_resolution: "256"
|
61 |
+
normal_cls_offset: 4
|
62 |
+
condition_image_column_name: "conditioning_image"
|
63 |
+
image_column_name: "image"
|
app/custom_models/image2normal.yaml
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pretrained_model_name_or_path: "lambdalabs/sd-image-variations-diffusers"
|
2 |
+
mixed_precision: "bf16"
|
3 |
+
|
4 |
+
init_config:
|
5 |
+
# enable controls
|
6 |
+
enable_cross_attn_lora: False
|
7 |
+
enable_cross_attn_ip: False
|
8 |
+
enable_self_attn_lora: False
|
9 |
+
enable_self_attn_ref: True
|
10 |
+
enable_multiview_attn: False
|
11 |
+
|
12 |
+
# for cross attention
|
13 |
+
init_cross_attn_lora: False
|
14 |
+
init_cross_attn_ip: False
|
15 |
+
cross_attn_lora_rank: 512 # 0 for not enabled
|
16 |
+
cross_attn_lora_only_kv: False
|
17 |
+
ipadapter_pretrained_name: "h94/IP-Adapter"
|
18 |
+
ipadapter_subfolder_name: "models"
|
19 |
+
ipadapter_weight_name: "ip-adapter_sd15.safetensors"
|
20 |
+
ipadapter_effect_on: "all" # all, first
|
21 |
+
|
22 |
+
# for self attention
|
23 |
+
init_self_attn_lora: False
|
24 |
+
self_attn_lora_rank: 512
|
25 |
+
self_attn_lora_only_kv: False
|
26 |
+
|
27 |
+
# for self attention ref
|
28 |
+
init_self_attn_ref: True
|
29 |
+
self_attn_ref_position: "attn1"
|
30 |
+
self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers"
|
31 |
+
self_attn_ref_pixel_wise_crosspond: True
|
32 |
+
self_attn_ref_effect_on: "all"
|
33 |
+
|
34 |
+
# for multiview attention
|
35 |
+
init_multiview_attn: False
|
36 |
+
multiview_attn_position: "attn1"
|
37 |
+
num_modalities: 1
|
38 |
+
|
39 |
+
# for unet
|
40 |
+
init_unet_path: "${pretrained_model_name_or_path}"
|
41 |
+
init_num_cls_label: 0 # for initialize
|
42 |
+
cls_labels: [] # for current task
|
43 |
+
|
44 |
+
trainers:
|
45 |
+
- trainer_type: "image2image_trainer"
|
46 |
+
trainer:
|
47 |
+
pretrained_model_name_or_path: "${pretrained_model_name_or_path}"
|
48 |
+
attn_config:
|
49 |
+
cls_labels: [] # for current task
|
50 |
+
enable_cross_attn_lora: False
|
51 |
+
enable_cross_attn_ip: False
|
52 |
+
enable_self_attn_lora: False
|
53 |
+
enable_self_attn_ref: True
|
54 |
+
enable_multiview_attn: False
|
55 |
+
resolution: "512"
|
56 |
+
condition_image_resolution: "512"
|
57 |
+
condition_image_column_name: "conditioning_image"
|
58 |
+
image_column_name: "image"
|
59 |
+
|
60 |
+
|
61 |
+
|
app/custom_models/mvimg_prediction.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import gradio as gr
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
from rembg import remove
|
7 |
+
from app.utils import change_rgba_bg, rgba_to_rgb
|
8 |
+
from app.custom_models.utils import load_pipeline
|
9 |
+
from scripts.all_typing import *
|
10 |
+
from scripts.utils import session, simple_preprocess
|
11 |
+
|
12 |
+
training_config = "app/custom_models/image2mvimage.yaml"
|
13 |
+
checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"
|
14 |
+
trainer, pipeline = load_pipeline(training_config, checkpoint_path)
|
15 |
+
# pipeline.enable_model_cpu_offload()
|
16 |
+
|
17 |
+
def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
|
18 |
+
if isinstance(img_list, Image.Image):
|
19 |
+
img_list = [img_list]
|
20 |
+
img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
|
21 |
+
ret = []
|
22 |
+
for img in img_list:
|
23 |
+
images = trainer.pipeline_forward(
|
24 |
+
pipeline=pipeline,
|
25 |
+
image=img,
|
26 |
+
guidance_scale=guidance_scale,
|
27 |
+
**kwargs
|
28 |
+
).images
|
29 |
+
ret.extend(images)
|
30 |
+
return ret
|
31 |
+
|
32 |
+
|
33 |
+
def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145):
|
34 |
+
if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.:
|
35 |
+
# still do remove using rembg, since simple_preprocess requires RGBA image
|
36 |
+
print("RGB image not RGBA! still remove bg!")
|
37 |
+
remove_bg = True
|
38 |
+
|
39 |
+
if remove_bg:
|
40 |
+
input_image = remove(input_image, session=session)
|
41 |
+
|
42 |
+
# make front_pil RGBA with white bg
|
43 |
+
input_image = change_rgba_bg(input_image, "white")
|
44 |
+
single_image = simple_preprocess(input_image)
|
45 |
+
|
46 |
+
generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed >= 0 else None
|
47 |
+
|
48 |
+
rgb_pils = predict(
|
49 |
+
single_image,
|
50 |
+
generator=generator,
|
51 |
+
guidance_scale=guidance_scale,
|
52 |
+
width=256,
|
53 |
+
height=256,
|
54 |
+
num_inference_steps=30,
|
55 |
+
)
|
56 |
+
|
57 |
+
return rgb_pils, single_image
|
app/custom_models/normal_prediction.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from PIL import Image
|
3 |
+
from app.utils import rgba_to_rgb, simple_remove
|
4 |
+
from app.custom_models.utils import load_pipeline
|
5 |
+
from scripts.utils import rotate_normals_torch
|
6 |
+
from scripts.all_typing import *
|
7 |
+
|
8 |
+
training_config = "app/custom_models/image2normal.yaml"
|
9 |
+
checkpoint_path = "ckpt/image2normal/unet_state_dict.pth"
|
10 |
+
trainer, pipeline = load_pipeline(training_config, checkpoint_path)
|
11 |
+
# pipeline.enable_model_cpu_offload()
|
12 |
+
|
13 |
+
def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs):
|
14 |
+
img_list = image if isinstance(image, list) else [image]
|
15 |
+
img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list]
|
16 |
+
images = trainer.pipeline_forward(
|
17 |
+
pipeline=pipeline,
|
18 |
+
image=img_list,
|
19 |
+
num_inference_steps=num_inference_steps,
|
20 |
+
guidance_scale=guidance_scale,
|
21 |
+
**kwargs
|
22 |
+
).images
|
23 |
+
images = simple_remove(images)
|
24 |
+
if do_rotate and len(images) > 1:
|
25 |
+
images = rotate_normals_torch(images, return_types='pil')
|
26 |
+
return images
|
app/custom_models/utils.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import List
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from app.utils import rgba_to_rgb
|
5 |
+
from custum_3d_diffusion.trainings.config_classes import ExprimentConfig, TrainerSubConfig
|
6 |
+
from custum_3d_diffusion import modules
|
7 |
+
from custum_3d_diffusion.custum_modules.unifield_processor import AttnConfig, ConfigurableUNet2DConditionModel
|
8 |
+
from custum_3d_diffusion.trainings.base import BasicTrainer
|
9 |
+
from custum_3d_diffusion.trainings.utils import load_config
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class FakeAccelerator:
|
14 |
+
device: torch.device = torch.device("cuda")
|
15 |
+
|
16 |
+
|
17 |
+
def init_trainers(cfg_path: str, weight_dtype: torch.dtype, extras: dict):
|
18 |
+
accelerator = FakeAccelerator()
|
19 |
+
cfg: ExprimentConfig = load_config(ExprimentConfig, cfg_path, extras)
|
20 |
+
init_config: AttnConfig = load_config(AttnConfig, cfg.init_config)
|
21 |
+
configurable_unet = ConfigurableUNet2DConditionModel(init_config, weight_dtype)
|
22 |
+
configurable_unet.enable_xformers_memory_efficient_attention()
|
23 |
+
trainer_cfgs: List[TrainerSubConfig] = [load_config(TrainerSubConfig, trainer) for trainer in cfg.trainers]
|
24 |
+
trainers: List[BasicTrainer] = [modules.find(trainer.trainer_type)(accelerator, None, configurable_unet, trainer.trainer, weight_dtype, i) for i, trainer in enumerate(trainer_cfgs)]
|
25 |
+
return trainers, configurable_unet
|
26 |
+
|
27 |
+
from app.utils import make_image_grid, split_image
|
28 |
+
def process_image(function, img, guidance_scale=2., merged_image=False, remove_bg=True):
|
29 |
+
from rembg import remove
|
30 |
+
if remove_bg:
|
31 |
+
img = remove(img)
|
32 |
+
img = rgba_to_rgb(img)
|
33 |
+
if merged_image:
|
34 |
+
img = split_image(img, rows=2)
|
35 |
+
images = function(
|
36 |
+
image=img,
|
37 |
+
guidance_scale=guidance_scale,
|
38 |
+
)
|
39 |
+
if len(images) > 1:
|
40 |
+
return make_image_grid(images, rows=2)
|
41 |
+
else:
|
42 |
+
return images[0]
|
43 |
+
|
44 |
+
|
45 |
+
def process_text(trainer, pipeline, img, guidance_scale=2.):
|
46 |
+
pipeline.cfg.validation_prompts = [img]
|
47 |
+
titles, images = trainer.batched_validation_forward(pipeline, guidance_scale=[guidance_scale])
|
48 |
+
return images[0]
|
49 |
+
|
50 |
+
|
51 |
+
def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight_dtype = torch.bfloat16):
|
52 |
+
training_config = config_path
|
53 |
+
load_from_checkpoint = ckpt_path
|
54 |
+
extras = []
|
55 |
+
device = "cuda"
|
56 |
+
trainers, configurable_unet = init_trainers(training_config, weight_dtype, extras)
|
57 |
+
shared_modules = dict()
|
58 |
+
for trainer in trainers:
|
59 |
+
shared_modules = trainer.init_shared_modules(shared_modules)
|
60 |
+
|
61 |
+
if load_from_checkpoint is not None:
|
62 |
+
state_dict = torch.load(load_from_checkpoint)
|
63 |
+
configurable_unet.unet.load_state_dict(state_dict, strict=False)
|
64 |
+
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
65 |
+
configurable_unet.unet.to(device, dtype=weight_dtype)
|
66 |
+
|
67 |
+
pipeline = None
|
68 |
+
trainer_out = None
|
69 |
+
for trainer in trainers:
|
70 |
+
if pipeline_filter(trainer.cfg.trainer_name):
|
71 |
+
pipeline = trainer.construct_pipeline(shared_modules, configurable_unet.unet)
|
72 |
+
pipeline.set_progress_bar_config(disable=False)
|
73 |
+
trainer_out = trainer
|
74 |
+
pipeline = pipeline.to(device)
|
75 |
+
return trainer_out, pipeline
|
app/examples/Groot.png
ADDED
app/examples/aaa.png
ADDED
app/examples/abma.png
ADDED
app/examples/akun.png
ADDED
app/examples/anya.png
ADDED
app/examples/bag.png
ADDED
Git LFS Details
|
app/examples/ex1.png
ADDED
Git LFS Details
|
app/examples/ex2.png
ADDED
app/examples/ex3.jpg
ADDED
app/examples/ex4.png
ADDED
app/examples/generated_1715761545_frame0.png
ADDED
app/examples/generated_1715762357_frame0.png
ADDED
app/examples/generated_1715763329_frame0.png
ADDED
app/examples/hatsune_miku.png
ADDED
app/examples/princess-large.png
ADDED
app/gradio_3dgen.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
from PIL import Image
|
4 |
+
from pytorch3d.structures import Meshes
|
5 |
+
from app.utils import clean_up
|
6 |
+
from app.custom_models.mvimg_prediction import run_mvprediction
|
7 |
+
from app.custom_models.normal_prediction import predict_normals
|
8 |
+
from scripts.refine_lr_to_sr import run_sr_fast
|
9 |
+
from scripts.utils import save_glb_and_video
|
10 |
+
from scripts.multiview_inference import geo_reconstruct
|
11 |
+
|
12 |
+
def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
|
13 |
+
if preview_img is None:
|
14 |
+
raise gr.Error("preview_img is none")
|
15 |
+
if isinstance(preview_img, str):
|
16 |
+
preview_img = Image.open(preview_img)
|
17 |
+
|
18 |
+
if preview_img.size[0] <= 512:
|
19 |
+
preview_img = run_sr_fast([preview_img])[0]
|
20 |
+
rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s
|
21 |
+
new_meshes = geo_reconstruct(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)
|
22 |
+
vertices = new_meshes.verts_packed()
|
23 |
+
vertices = vertices / 2 * 1.35
|
24 |
+
vertices[..., [0, 2]] = - vertices[..., [0, 2]]
|
25 |
+
new_meshes = Meshes(verts=[vertices], faces=new_meshes.faces_list(), textures=new_meshes.textures)
|
26 |
+
|
27 |
+
ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video)
|
28 |
+
return ret_mesh, video
|
29 |
+
|
30 |
+
#######################################
|
31 |
+
def create_ui(concurrency_id="wkl"):
|
32 |
+
with gr.Row():
|
33 |
+
with gr.Column(scale=2):
|
34 |
+
input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
|
35 |
+
|
36 |
+
example_folder = os.path.join(os.path.dirname(__file__), "./examples")
|
37 |
+
example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)])
|
38 |
+
gr.Examples(
|
39 |
+
examples=example_fns,
|
40 |
+
inputs=[input_image],
|
41 |
+
cache_examples=False,
|
42 |
+
label='Examples (click one of the images below to start)',
|
43 |
+
examples_per_page=12
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
with gr.Column(scale=3):
|
48 |
+
# export mesh display
|
49 |
+
output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320)
|
50 |
+
output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)
|
51 |
+
|
52 |
+
input_processing = gr.Checkbox(
|
53 |
+
value=True,
|
54 |
+
label='Remove Background',
|
55 |
+
visible=True,
|
56 |
+
)
|
57 |
+
do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
|
58 |
+
expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
|
59 |
+
init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
|
60 |
+
setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
|
61 |
+
render_video = gr.Checkbox(value=False, visible=False, label="generate video")
|
62 |
+
fullrunv2_btn = gr.Button('Generate 3D', interactive=True)
|
63 |
+
|
64 |
+
fullrunv2_btn.click(
|
65 |
+
fn = generate3dv2,
|
66 |
+
inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type],
|
67 |
+
outputs=[output_mesh, output_video],
|
68 |
+
concurrency_id=concurrency_id,
|
69 |
+
api_name="generate3dv2",
|
70 |
+
).success(clean_up, api_name=False)
|
71 |
+
return input_image
|
app/gradio_3dgen_steps.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
from app.custom_models.mvimg_prediction import run_mvprediction
|
5 |
+
from app.utils import make_image_grid, split_image
|
6 |
+
from scripts.utils import save_glb_and_video
|
7 |
+
|
8 |
+
def concept_to_multiview(preview_img, input_processing, seed, guidance=1.):
|
9 |
+
seed = int(seed)
|
10 |
+
if preview_img is None:
|
11 |
+
raise gr.Error("preview_img is none.")
|
12 |
+
if isinstance(preview_img, str):
|
13 |
+
preview_img = Image.open(preview_img)
|
14 |
+
|
15 |
+
rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=seed, guidance_scale=guidance)
|
16 |
+
rgb_pil = make_image_grid(rgb_pils, rows=2)
|
17 |
+
return rgb_pil, front_pil
|
18 |
+
|
19 |
+
def concept_to_multiview_ui(concurrency_id="wkl"):
|
20 |
+
with gr.Row():
|
21 |
+
with gr.Column(scale=2):
|
22 |
+
preview_img = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
|
23 |
+
input_processing = gr.Checkbox(
|
24 |
+
value=True,
|
25 |
+
label='Remove Background',
|
26 |
+
)
|
27 |
+
seed = gr.Slider(minimum=-1, maximum=1000000000, value=-1, step=1.0, label="seed")
|
28 |
+
guidance = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, label="Guidance Scale", step=0.5)
|
29 |
+
run_btn = gr.Button('Generate Multiview', interactive=True)
|
30 |
+
with gr.Column(scale=3):
|
31 |
+
# export mesh display
|
32 |
+
output_rgb = gr.Image(type='pil', label="RGB", show_label=True)
|
33 |
+
output_front = gr.Image(type='pil', image_mode='RGBA', label="Frontview", show_label=True)
|
34 |
+
run_btn.click(
|
35 |
+
fn = concept_to_multiview,
|
36 |
+
inputs=[preview_img, input_processing, seed, guidance],
|
37 |
+
outputs=[output_rgb, output_front],
|
38 |
+
concurrency_id=concurrency_id,
|
39 |
+
api_name=False,
|
40 |
+
)
|
41 |
+
return output_rgb, output_front
|
42 |
+
|
43 |
+
from app.custom_models.normal_prediction import predict_normals
|
44 |
+
from scripts.multiview_inference import geo_reconstruct
|
45 |
+
def multiview_to_mesh_v2(rgb_pil, normal_pil, front_pil, do_refine=False, expansion_weight=0.1, init_type="std"):
|
46 |
+
rgb_pils = split_image(rgb_pil, rows=2)
|
47 |
+
if normal_pil is not None:
|
48 |
+
normal_pil = split_image(normal_pil, rows=2)
|
49 |
+
if front_pil is None:
|
50 |
+
front_pil = rgb_pils[0]
|
51 |
+
new_meshes = geo_reconstruct(rgb_pils, normal_pil, front_pil, do_refine=do_refine, predict_normal=normal_pil is None, expansion_weight=expansion_weight, init_type=init_type)
|
52 |
+
ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=False)
|
53 |
+
return ret_mesh
|
54 |
+
|
55 |
+
def new_multiview_to_mesh_ui(concurrency_id="wkl"):
|
56 |
+
with gr.Row():
|
57 |
+
with gr.Column(scale=2):
|
58 |
+
rgb_pil = gr.Image(type='pil', image_mode='RGB', label='RGB')
|
59 |
+
front_pil = gr.Image(type='pil', image_mode='RGBA', label='Frontview(Optinal)')
|
60 |
+
normal_pil = gr.Image(type='pil', image_mode='RGBA', label='Normal(Optinal)')
|
61 |
+
do_refine = gr.Checkbox(
|
62 |
+
value=False,
|
63 |
+
label='Refine rgb',
|
64 |
+
visible=False,
|
65 |
+
)
|
66 |
+
expansion_weight = gr.Slider(minimum=-1.0, maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
|
67 |
+
init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh initialization", value="std", visible=False)
|
68 |
+
run_btn = gr.Button('Generate 3D', interactive=True)
|
69 |
+
with gr.Column(scale=3):
|
70 |
+
# export mesh display
|
71 |
+
output_mesh = gr.Model3D(value=None, label="mesh model", show_label=True)
|
72 |
+
run_btn.click(
|
73 |
+
fn = multiview_to_mesh_v2,
|
74 |
+
inputs=[rgb_pil, normal_pil, front_pil, do_refine, expansion_weight, init_type],
|
75 |
+
outputs=[output_mesh],
|
76 |
+
concurrency_id=concurrency_id,
|
77 |
+
api_name="multiview_to_mesh",
|
78 |
+
)
|
79 |
+
return rgb_pil, front_pil, output_mesh
|
80 |
+
|
81 |
+
|
82 |
+
#######################################
|
83 |
+
def create_step_ui(concurrency_id="wkl"):
|
84 |
+
with gr.Tab(label="3D:concept_to_multiview"):
|
85 |
+
concept_to_multiview_ui(concurrency_id)
|
86 |
+
with gr.Tab(label="3D:new_multiview_to_mesh"):
|
87 |
+
new_multiview_to_mesh_ui(concurrency_id)
|
app/gradio_local.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
if __name__ == "__main__":
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
sys.path.append(os.curdir)
|
5 |
+
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
|
6 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
7 |
+
os.environ['TRANSFORMERS_OFFLINE']='0'
|
8 |
+
os.environ['DIFFUSERS_OFFLINE']='0'
|
9 |
+
os.environ['HF_HUB_OFFLINE']='0'
|
10 |
+
os.environ['GRADIO_ANALYTICS_ENABLED']='False'
|
11 |
+
os.environ['HF_ENDPOINT']='https://hf-mirror.com'
|
12 |
+
import torch
|
13 |
+
torch.set_float32_matmul_precision('medium')
|
14 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
15 |
+
torch.set_grad_enabled(False)
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
from app.gradio_3dgen import create_ui as create_3d_ui
|
21 |
+
# from app.gradio_3dgen_steps import create_step_ui
|
22 |
+
from app.all_models import model_zoo
|
23 |
+
|
24 |
+
|
25 |
+
_TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
|
26 |
+
_DESCRIPTION = '''
|
27 |
+
[Project page](https://wukailu.github.io/Unique3D/)
|
28 |
+
|
29 |
+
* High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
|
30 |
+
|
31 |
+
**If the Gradio Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (get the registration invitation code Join Discord: https://discord.gg/aiuni). However, the Online Demo is slightly different from the Gradio Demo, in that the inference speed is slower, but the generation is much more stable.**
|
32 |
+
'''
|
33 |
+
|
34 |
+
def launch(
|
35 |
+
port,
|
36 |
+
listen=False,
|
37 |
+
share=False,
|
38 |
+
gradio_root="",
|
39 |
+
):
|
40 |
+
model_zoo.init_models()
|
41 |
+
|
42 |
+
with gr.Blocks(
|
43 |
+
title=_TITLE,
|
44 |
+
theme=gr.themes.Monochrome(),
|
45 |
+
) as demo:
|
46 |
+
with gr.Row():
|
47 |
+
with gr.Column(scale=1):
|
48 |
+
gr.Markdown('# ' + _TITLE)
|
49 |
+
gr.Markdown(_DESCRIPTION)
|
50 |
+
create_3d_ui("wkl")
|
51 |
+
|
52 |
+
launch_args = {}
|
53 |
+
if listen:
|
54 |
+
launch_args["server_name"] = "0.0.0.0"
|
55 |
+
|
56 |
+
demo.queue(default_concurrency_limit=1).launch(
|
57 |
+
server_port=None if port == 0 else port,
|
58 |
+
share=share,
|
59 |
+
root_path=gradio_root if gradio_root != "" else None, # "/myapp"
|
60 |
+
**launch_args,
|
61 |
+
)
|
62 |
+
|
63 |
+
if __name__ == "__main__":
|
64 |
+
parser = argparse.ArgumentParser()
|
65 |
+
args, extra = parser.parse_known_args()
|
66 |
+
parser.add_argument("--listen", action="store_true")
|
67 |
+
parser.add_argument("--port", type=int, default=0)
|
68 |
+
parser.add_argument("--share", action="store_true")
|
69 |
+
parser.add_argument("--gradio_root", default="")
|
70 |
+
args = parser.parse_args()
|
71 |
+
launch(
|
72 |
+
args.port,
|
73 |
+
listen=args.listen,
|
74 |
+
share=args.share,
|
75 |
+
gradio_root=args.gradio_root,
|
76 |
+
)
|
app/utils.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import gc
|
5 |
+
import numpy as np
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from scripts.refine_lr_to_sr import run_sr_fast
|
9 |
+
|
10 |
+
GRADIO_CACHE = "/tmp/gradio/"
|
11 |
+
|
12 |
+
def clean_up():
|
13 |
+
torch.cuda.empty_cache()
|
14 |
+
gc.collect()
|
15 |
+
|
16 |
+
def remove_color(arr):
|
17 |
+
if arr.shape[-1] == 4:
|
18 |
+
arr = arr[..., :3]
|
19 |
+
# calc diffs
|
20 |
+
base = arr[0, 0]
|
21 |
+
diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1)
|
22 |
+
alpha = (diffs <= 80)
|
23 |
+
|
24 |
+
arr[alpha] = 255
|
25 |
+
alpha = ~alpha
|
26 |
+
arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1)
|
27 |
+
return arr
|
28 |
+
|
29 |
+
def simple_remove(imgs, run_sr=True):
|
30 |
+
"""Only works for normal"""
|
31 |
+
if not isinstance(imgs, list):
|
32 |
+
imgs = [imgs]
|
33 |
+
single_input = True
|
34 |
+
else:
|
35 |
+
single_input = False
|
36 |
+
if run_sr:
|
37 |
+
imgs = run_sr_fast(imgs)
|
38 |
+
rets = []
|
39 |
+
for img in imgs:
|
40 |
+
arr = np.array(img)
|
41 |
+
arr = remove_color(arr)
|
42 |
+
rets.append(Image.fromarray(arr.astype(np.uint8)))
|
43 |
+
if single_input:
|
44 |
+
return rets[0]
|
45 |
+
return rets
|
46 |
+
|
47 |
+
def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"):
|
48 |
+
new_image = Image.new("RGBA", rgba.size, bkgd)
|
49 |
+
new_image.paste(rgba, (0, 0), rgba)
|
50 |
+
new_image = new_image.convert('RGB')
|
51 |
+
return new_image
|
52 |
+
|
53 |
+
def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"):
|
54 |
+
rgb_white = rgba_to_rgb(rgba, bkgd)
|
55 |
+
new_rgba = Image.fromarray(np.concatenate([np.array(rgb_white), np.array(rgba)[:, :, 3:4]], axis=-1))
|
56 |
+
return new_rgba
|
57 |
+
|
58 |
+
def split_image(image, rows=None, cols=None):
|
59 |
+
"""
|
60 |
+
inverse function of make_image_grid
|
61 |
+
"""
|
62 |
+
# image is in square
|
63 |
+
if rows is None and cols is None:
|
64 |
+
# image.size [W, H]
|
65 |
+
rows = 1
|
66 |
+
cols = image.size[0] // image.size[1]
|
67 |
+
assert cols * image.size[1] == image.size[0]
|
68 |
+
subimg_size = image.size[1]
|
69 |
+
elif rows is None:
|
70 |
+
subimg_size = image.size[0] // cols
|
71 |
+
rows = image.size[1] // subimg_size
|
72 |
+
assert rows * subimg_size == image.size[1]
|
73 |
+
elif cols is None:
|
74 |
+
subimg_size = image.size[1] // rows
|
75 |
+
cols = image.size[0] // subimg_size
|
76 |
+
assert cols * subimg_size == image.size[0]
|
77 |
+
else:
|
78 |
+
subimg_size = image.size[1] // rows
|
79 |
+
assert cols * subimg_size == image.size[0]
|
80 |
+
subimgs = []
|
81 |
+
for i in range(rows):
|
82 |
+
for j in range(cols):
|
83 |
+
subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size))
|
84 |
+
subimgs.append(subimg)
|
85 |
+
return subimgs
|
86 |
+
|
87 |
+
def make_image_grid(images, rows=None, cols=None, resize=None):
|
88 |
+
if rows is None and cols is None:
|
89 |
+
rows = 1
|
90 |
+
cols = len(images)
|
91 |
+
if rows is None:
|
92 |
+
rows = len(images) // cols
|
93 |
+
if len(images) % cols != 0:
|
94 |
+
rows += 1
|
95 |
+
if cols is None:
|
96 |
+
cols = len(images) // rows
|
97 |
+
if len(images) % rows != 0:
|
98 |
+
cols += 1
|
99 |
+
total_imgs = rows * cols
|
100 |
+
if total_imgs > len(images):
|
101 |
+
images += [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))]
|
102 |
+
|
103 |
+
if resize is not None:
|
104 |
+
images = [img.resize((resize, resize)) for img in images]
|
105 |
+
|
106 |
+
w, h = images[0].size
|
107 |
+
grid = Image.new(images[0].mode, size=(cols * w, rows * h))
|
108 |
+
|
109 |
+
for i, img in enumerate(images):
|
110 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
111 |
+
return grid
|
112 |
+
|
assets/teaser.jpg
ADDED
assets/teaser_safe.jpg
ADDED
Git LFS Details
|
custum_3d_diffusion/custum_modules/attention_processors.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, Optional
|
2 |
+
import torch
|
3 |
+
from diffusers.models.attention_processor import Attention
|
4 |
+
|
5 |
+
def construct_pix2pix_attention(hidden_states_dim, norm_type="none"):
|
6 |
+
if norm_type == "layernorm":
|
7 |
+
norm = torch.nn.LayerNorm(hidden_states_dim)
|
8 |
+
else:
|
9 |
+
norm = torch.nn.Identity()
|
10 |
+
attention = Attention(
|
11 |
+
query_dim=hidden_states_dim,
|
12 |
+
heads=8,
|
13 |
+
dim_head=hidden_states_dim // 8,
|
14 |
+
bias=True,
|
15 |
+
)
|
16 |
+
# NOTE: xformers 0.22 does not support batchsize >= 4096
|
17 |
+
attention.xformers_not_supported = True # hacky solution
|
18 |
+
return norm, attention
|
19 |
+
|
20 |
+
class ExtraAttnProc(torch.nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
chained_proc,
|
24 |
+
enabled=False,
|
25 |
+
name=None,
|
26 |
+
mode='extract',
|
27 |
+
with_proj_in=False,
|
28 |
+
proj_in_dim=768,
|
29 |
+
target_dim=None,
|
30 |
+
pixel_wise_crosspond=False,
|
31 |
+
norm_type="none", # none or layernorm
|
32 |
+
crosspond_effect_on="all", # all or first
|
33 |
+
crosspond_chain_pos="parralle", # before or parralle or after
|
34 |
+
simple_3d=False,
|
35 |
+
views=4,
|
36 |
+
) -> None:
|
37 |
+
super().__init__()
|
38 |
+
self.enabled = enabled
|
39 |
+
self.chained_proc = chained_proc
|
40 |
+
self.name = name
|
41 |
+
self.mode = mode
|
42 |
+
self.with_proj_in=with_proj_in
|
43 |
+
self.proj_in_dim = proj_in_dim
|
44 |
+
self.target_dim = target_dim or proj_in_dim
|
45 |
+
self.hidden_states_dim = self.target_dim
|
46 |
+
self.pixel_wise_crosspond = pixel_wise_crosspond
|
47 |
+
self.crosspond_effect_on = crosspond_effect_on
|
48 |
+
self.crosspond_chain_pos = crosspond_chain_pos
|
49 |
+
self.views = views
|
50 |
+
self.simple_3d = simple_3d
|
51 |
+
if self.with_proj_in and self.enabled:
|
52 |
+
self.in_linear = torch.nn.Linear(self.proj_in_dim, self.target_dim, bias=False)
|
53 |
+
if self.target_dim == self.proj_in_dim:
|
54 |
+
self.in_linear.weight.data = torch.eye(proj_in_dim)
|
55 |
+
else:
|
56 |
+
self.in_linear = None
|
57 |
+
if self.pixel_wise_crosspond and self.enabled:
|
58 |
+
self.crosspond_norm, self.crosspond_attention = construct_pix2pix_attention(self.hidden_states_dim, norm_type=norm_type)
|
59 |
+
|
60 |
+
def do_crosspond_attention(self, hidden_states: torch.FloatTensor, other_states: torch.FloatTensor):
|
61 |
+
hidden_states = self.crosspond_norm(hidden_states)
|
62 |
+
|
63 |
+
batch, L, D = hidden_states.shape
|
64 |
+
assert hidden_states.shape == other_states.shape, f"got {hidden_states.shape} and {other_states.shape}"
|
65 |
+
# to -> batch * L, 1, D
|
66 |
+
hidden_states = hidden_states.reshape(batch * L, 1, D)
|
67 |
+
other_states = other_states.reshape(batch * L, 1, D)
|
68 |
+
hidden_states_catted = other_states
|
69 |
+
hidden_states = self.crosspond_attention(
|
70 |
+
hidden_states,
|
71 |
+
encoder_hidden_states=hidden_states_catted,
|
72 |
+
)
|
73 |
+
return hidden_states.reshape(batch, L, D)
|
74 |
+
|
75 |
+
def __call__(
|
76 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
|
77 |
+
ref_dict: dict = None, mode=None, **kwargs
|
78 |
+
) -> Any:
|
79 |
+
if not self.enabled:
|
80 |
+
return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
|
81 |
+
if encoder_hidden_states is None:
|
82 |
+
encoder_hidden_states = hidden_states
|
83 |
+
assert ref_dict is not None
|
84 |
+
if (mode or self.mode) == 'extract':
|
85 |
+
ref_dict[self.name] = hidden_states
|
86 |
+
hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
|
87 |
+
if self.pixel_wise_crosspond and self.crosspond_chain_pos == "after":
|
88 |
+
ref_dict[self.name] = hidden_states1
|
89 |
+
return hidden_states1
|
90 |
+
elif (mode or self.mode) == 'inject':
|
91 |
+
ref_state = ref_dict.pop(self.name)
|
92 |
+
if self.with_proj_in:
|
93 |
+
ref_state = self.in_linear(ref_state)
|
94 |
+
|
95 |
+
B, L, D = ref_state.shape
|
96 |
+
if hidden_states.shape[0] == B:
|
97 |
+
modalities = 1
|
98 |
+
views = 1
|
99 |
+
else:
|
100 |
+
modalities = hidden_states.shape[0] // B // self.views
|
101 |
+
views = self.views
|
102 |
+
if self.pixel_wise_crosspond:
|
103 |
+
if self.crosspond_effect_on == "all":
|
104 |
+
ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, *ref_state.shape[-2:])
|
105 |
+
|
106 |
+
if self.crosspond_chain_pos == "before":
|
107 |
+
hidden_states = hidden_states + self.do_crosspond_attention(hidden_states, ref_state)
|
108 |
+
|
109 |
+
hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
|
110 |
+
|
111 |
+
if self.crosspond_chain_pos == "parralle":
|
112 |
+
hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states, ref_state)
|
113 |
+
|
114 |
+
if self.crosspond_chain_pos == "after":
|
115 |
+
hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states1, ref_state)
|
116 |
+
return hidden_states1
|
117 |
+
else:
|
118 |
+
assert self.crosspond_effect_on == "first"
|
119 |
+
# hidden_states [B * modalities * views, L, D]
|
120 |
+
# ref_state [B, L, D]
|
121 |
+
ref_state = ref_state[:, None].expand(-1, modalities, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1]) # [B * modalities, L, D]
|
122 |
+
|
123 |
+
def do_paritial_crosspond(hidden_states, ref_state):
|
124 |
+
first_view_hidden_states = hidden_states.view(-1, views, hidden_states.shape[1], hidden_states.shape[2])[:, 0] # [B * modalities, L, D]
|
125 |
+
hidden_states2 = self.do_crosspond_attention(first_view_hidden_states, ref_state) # [B * modalities, L, D]
|
126 |
+
hidden_states2_padded = torch.zeros_like(hidden_states).reshape(-1, views, hidden_states.shape[1], hidden_states.shape[2])
|
127 |
+
hidden_states2_padded[:, 0] = hidden_states2
|
128 |
+
hidden_states2_padded = hidden_states2_padded.reshape(-1, hidden_states.shape[1], hidden_states.shape[2])
|
129 |
+
return hidden_states2_padded
|
130 |
+
|
131 |
+
if self.crosspond_chain_pos == "before":
|
132 |
+
hidden_states = hidden_states + do_paritial_crosspond(hidden_states, ref_state)
|
133 |
+
|
134 |
+
hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) # [B * modalities * views, L, D]
|
135 |
+
if self.crosspond_chain_pos == "parralle":
|
136 |
+
hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states, ref_state)
|
137 |
+
if self.crosspond_chain_pos == "after":
|
138 |
+
hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states1, ref_state)
|
139 |
+
return hidden_states1
|
140 |
+
elif self.simple_3d:
|
141 |
+
B, L, C = encoder_hidden_states.shape
|
142 |
+
mv = self.views
|
143 |
+
encoder_hidden_states = encoder_hidden_states.reshape(B // mv, mv, L, C)
|
144 |
+
ref_state = ref_state[:, None]
|
145 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
|
146 |
+
encoder_hidden_states = encoder_hidden_states.reshape(B // mv, 1, (mv+1) * L, C)
|
147 |
+
encoder_hidden_states = encoder_hidden_states.repeat(1, mv, 1, 1).reshape(-1, (mv+1) * L, C)
|
148 |
+
return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
|
149 |
+
else:
|
150 |
+
ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1])
|
151 |
+
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1)
|
152 |
+
return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
|
153 |
+
else:
|
154 |
+
raise NotImplementedError("mode or self.mode is required to be 'extract' or 'inject'")
|
155 |
+
|
156 |
+
def add_extra_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
|
157 |
+
return_dict = torch.nn.ModuleDict()
|
158 |
+
proj_in_dim = kwargs.get('proj_in_dim', False)
|
159 |
+
kwargs.pop('proj_in_dim', None)
|
160 |
+
|
161 |
+
def recursive_add_processors(name: str, module: torch.nn.Module):
|
162 |
+
for sub_name, child in module.named_children():
|
163 |
+
if "ref_unet" not in (sub_name + name):
|
164 |
+
recursive_add_processors(f"{name}.{sub_name}", child)
|
165 |
+
|
166 |
+
if isinstance(module, Attention):
|
167 |
+
new_processor = ExtraAttnProc(
|
168 |
+
chained_proc=module.get_processor(),
|
169 |
+
enabled=enable_filter(f"{name}.processor"),
|
170 |
+
name=f"{name}.processor",
|
171 |
+
proj_in_dim=proj_in_dim if proj_in_dim else module.cross_attention_dim,
|
172 |
+
target_dim=module.cross_attention_dim,
|
173 |
+
**kwargs
|
174 |
+
)
|
175 |
+
module.set_processor(new_processor)
|
176 |
+
return_dict[f"{name}.processor".replace(".", "__")] = new_processor
|
177 |
+
|
178 |
+
for name, module in model.named_children():
|
179 |
+
recursive_add_processors(name, module)
|
180 |
+
return return_dict
|
181 |
+
|
182 |
+
def switch_extra_processor(model, enable_filter=lambda x:True):
|
183 |
+
def recursive_add_processors(name: str, module: torch.nn.Module):
|
184 |
+
for sub_name, child in module.named_children():
|
185 |
+
recursive_add_processors(f"{name}.{sub_name}", child)
|
186 |
+
|
187 |
+
if isinstance(module, ExtraAttnProc):
|
188 |
+
module.enabled = enable_filter(name)
|
189 |
+
|
190 |
+
for name, module in model.named_children():
|
191 |
+
recursive_add_processors(name, module)
|
192 |
+
|
193 |
+
class multiviewAttnProc(torch.nn.Module):
|
194 |
+
def __init__(
|
195 |
+
self,
|
196 |
+
chained_proc,
|
197 |
+
enabled=False,
|
198 |
+
name=None,
|
199 |
+
hidden_states_dim=None,
|
200 |
+
chain_pos="parralle", # before or parralle or after
|
201 |
+
num_modalities=1,
|
202 |
+
views=4,
|
203 |
+
base_img_size=64,
|
204 |
+
) -> None:
|
205 |
+
super().__init__()
|
206 |
+
self.enabled = enabled
|
207 |
+
self.chained_proc = chained_proc
|
208 |
+
self.name = name
|
209 |
+
self.hidden_states_dim = hidden_states_dim
|
210 |
+
self.num_modalities = num_modalities
|
211 |
+
self.views = views
|
212 |
+
self.base_img_size = base_img_size
|
213 |
+
self.chain_pos = chain_pos
|
214 |
+
self.diff_joint_attn = True
|
215 |
+
|
216 |
+
def __call__(
|
217 |
+
self,
|
218 |
+
attn: Attention,
|
219 |
+
hidden_states: torch.FloatTensor,
|
220 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
221 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
222 |
+
**kwargs
|
223 |
+
) -> torch.Tensor:
|
224 |
+
if not self.enabled:
|
225 |
+
return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
|
226 |
+
|
227 |
+
B, L, C = hidden_states.shape
|
228 |
+
mv = self.views
|
229 |
+
hidden_states = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C)
|
230 |
+
hidden_states = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
|
231 |
+
return hidden_states.reshape(B // mv, mv, L, C).reshape(-1, L, C)
|
232 |
+
|
233 |
+
def add_multiview_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs):
|
234 |
+
return_dict = torch.nn.ModuleDict()
|
235 |
+
def recursive_add_processors(name: str, module: torch.nn.Module):
|
236 |
+
for sub_name, child in module.named_children():
|
237 |
+
if "ref_unet" not in (sub_name + name):
|
238 |
+
recursive_add_processors(f"{name}.{sub_name}", child)
|
239 |
+
|
240 |
+
if isinstance(module, Attention):
|
241 |
+
new_processor = multiviewAttnProc(
|
242 |
+
chained_proc=module.get_processor(),
|
243 |
+
enabled=enable_filter(f"{name}.processor"),
|
244 |
+
name=f"{name}.processor",
|
245 |
+
hidden_states_dim=module.inner_dim,
|
246 |
+
**kwargs
|
247 |
+
)
|
248 |
+
module.set_processor(new_processor)
|
249 |
+
return_dict[f"{name}.processor".replace(".", "__")] = new_processor
|
250 |
+
|
251 |
+
for name, module in model.named_children():
|
252 |
+
recursive_add_processors(name, module)
|
253 |
+
|
254 |
+
return return_dict
|
255 |
+
|
256 |
+
def switch_multiview_processor(model, enable_filter=lambda x:True):
|
257 |
+
def recursive_add_processors(name: str, module: torch.nn.Module):
|
258 |
+
for sub_name, child in module.named_children():
|
259 |
+
recursive_add_processors(f"{name}.{sub_name}", child)
|
260 |
+
|
261 |
+
if isinstance(module, Attention):
|
262 |
+
processor = module.get_processor()
|
263 |
+
if isinstance(processor, multiviewAttnProc):
|
264 |
+
processor.enabled = enable_filter(f"{name}.processor")
|
265 |
+
|
266 |
+
for name, module in model.named_children():
|
267 |
+
recursive_add_processors(name, module)
|
268 |
+
|
269 |
+
class NNModuleWrapper(torch.nn.Module):
|
270 |
+
def __init__(self, module):
|
271 |
+
super().__init__()
|
272 |
+
self.module = module
|
273 |
+
|
274 |
+
def forward(self, *args, **kwargs):
|
275 |
+
return self.module(*args, **kwargs)
|
276 |
+
|
277 |
+
def __getattr__(self, name: str):
|
278 |
+
try:
|
279 |
+
return super().__getattr__(name)
|
280 |
+
except AttributeError:
|
281 |
+
return getattr(self.module, name)
|
282 |
+
|
283 |
+
class AttnProcessorSwitch(torch.nn.Module):
|
284 |
+
def __init__(
|
285 |
+
self,
|
286 |
+
proc_dict: dict,
|
287 |
+
enabled_proc="default",
|
288 |
+
name=None,
|
289 |
+
switch_name="default_switch",
|
290 |
+
):
|
291 |
+
super().__init__()
|
292 |
+
self.proc_dict = torch.nn.ModuleDict({k: (v if isinstance(v, torch.nn.Module) else NNModuleWrapper(v)) for k, v in proc_dict.items()})
|
293 |
+
self.enabled_proc = enabled_proc
|
294 |
+
self.name = name
|
295 |
+
self.switch_name = switch_name
|
296 |
+
self.choose_module(enabled_proc)
|
297 |
+
|
298 |
+
def choose_module(self, enabled_proc):
|
299 |
+
self.enabled_proc = enabled_proc
|
300 |
+
assert enabled_proc in self.proc_dict.keys()
|
301 |
+
|
302 |
+
def __call__(
|
303 |
+
self,
|
304 |
+
*args,
|
305 |
+
**kwargs
|
306 |
+
) -> torch.FloatTensor:
|
307 |
+
used_proc = self.proc_dict[self.enabled_proc]
|
308 |
+
return used_proc(*args, **kwargs)
|
309 |
+
|
310 |
+
def add_switch(model: torch.nn.Module, module_filter=lambda x:True, switch_dict_fn=lambda x: {"default": x}, switch_name="default_switch", enabled_proc="default"):
|
311 |
+
return_dict = torch.nn.ModuleDict()
|
312 |
+
def recursive_add_processors(name: str, module: torch.nn.Module):
|
313 |
+
for sub_name, child in module.named_children():
|
314 |
+
if "ref_unet" not in (sub_name + name):
|
315 |
+
recursive_add_processors(f"{name}.{sub_name}", child)
|
316 |
+
|
317 |
+
if isinstance(module, Attention):
|
318 |
+
processor = module.get_processor()
|
319 |
+
if module_filter(processor):
|
320 |
+
proc_dict = switch_dict_fn(processor)
|
321 |
+
new_processor = AttnProcessorSwitch(
|
322 |
+
proc_dict=proc_dict,
|
323 |
+
enabled_proc=enabled_proc,
|
324 |
+
name=f"{name}.processor",
|
325 |
+
switch_name=switch_name,
|
326 |
+
)
|
327 |
+
module.set_processor(new_processor)
|
328 |
+
return_dict[f"{name}.processor".replace(".", "__")] = new_processor
|
329 |
+
|
330 |
+
for name, module in model.named_children():
|
331 |
+
recursive_add_processors(name, module)
|
332 |
+
|
333 |
+
return return_dict
|
334 |
+
|
335 |
+
def change_switch(model: torch.nn.Module, switch_name="default_switch", enabled_proc="default"):
|
336 |
+
def recursive_change_processors(name: str, module: torch.nn.Module):
|
337 |
+
for sub_name, child in module.named_children():
|
338 |
+
recursive_change_processors(f"{name}.{sub_name}", child)
|
339 |
+
|
340 |
+
if isinstance(module, Attention):
|
341 |
+
processor = module.get_processor()
|
342 |
+
if isinstance(processor, AttnProcessorSwitch) and processor.switch_name == switch_name:
|
343 |
+
processor.choose_module(enabled_proc)
|
344 |
+
|
345 |
+
for name, module in model.named_children():
|
346 |
+
recursive_change_processors(name, module)
|
347 |
+
|
348 |
+
########## Hack: Attention fix #############
|
349 |
+
from diffusers.models.attention import Attention
|
350 |
+
|
351 |
+
def forward(
|
352 |
+
self,
|
353 |
+
hidden_states: torch.FloatTensor,
|
354 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
355 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
356 |
+
**cross_attention_kwargs,
|
357 |
+
) -> torch.Tensor:
|
358 |
+
r"""
|
359 |
+
The forward method of the `Attention` class.
|
360 |
+
|
361 |
+
Args:
|
362 |
+
hidden_states (`torch.Tensor`):
|
363 |
+
The hidden states of the query.
|
364 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
365 |
+
The hidden states of the encoder.
|
366 |
+
attention_mask (`torch.Tensor`, *optional*):
|
367 |
+
The attention mask to use. If `None`, no mask is applied.
|
368 |
+
**cross_attention_kwargs:
|
369 |
+
Additional keyword arguments to pass along to the cross attention.
|
370 |
+
|
371 |
+
Returns:
|
372 |
+
`torch.Tensor`: The output of the attention layer.
|
373 |
+
"""
|
374 |
+
# The `Attention` class can call different attention processors / attention functions
|
375 |
+
# here we simply pass along all tensors to the selected processor class
|
376 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
377 |
+
return self.processor(
|
378 |
+
self,
|
379 |
+
hidden_states,
|
380 |
+
encoder_hidden_states=encoder_hidden_states,
|
381 |
+
attention_mask=attention_mask,
|
382 |
+
**cross_attention_kwargs,
|
383 |
+
)
|
384 |
+
|
385 |
+
Attention.forward = forward
|
custum_3d_diffusion/custum_modules/unifield_processor.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from types import FunctionType
|
2 |
+
from typing import Any, Dict, List
|
3 |
+
from diffusers import UNet2DConditionModel
|
4 |
+
import torch
|
5 |
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, ImageProjection
|
6 |
+
from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from diffusers.loaders import IPAdapterMixin
|
9 |
+
from custum_3d_diffusion.custum_modules.attention_processors import add_extra_processor, switch_extra_processor, add_multiview_processor, switch_multiview_processor, add_switch, change_switch
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class AttnConfig:
|
13 |
+
"""
|
14 |
+
* CrossAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), IPAdapter module (achieves conceptual control).
|
15 |
+
* SelfAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), Reference Attention module (achieves pixel-level control).
|
16 |
+
* Multiview Attention module: Multiview Attention module (achieves multi-view consistency).
|
17 |
+
* Cross Modality Attention module: Cross Modality Attention module (achieves multi-modality consistency).
|
18 |
+
|
19 |
+
For setups:
|
20 |
+
train_xxx_lr is implemented in the U-Net architecture.
|
21 |
+
enable_xxx_lora is implemented in the U-Net architecture.
|
22 |
+
enable_xxx_ip is implemented in the processor and U-Net architecture.
|
23 |
+
enable_xxx_ref_proj_in is implemented in the processor.
|
24 |
+
"""
|
25 |
+
latent_size: int = 64
|
26 |
+
|
27 |
+
train_lr: float = 0
|
28 |
+
# for cross attention
|
29 |
+
# 0 learning rate for not training
|
30 |
+
train_cross_attn_lr: float = 0
|
31 |
+
train_cross_attn_lora_lr: float = 0
|
32 |
+
train_cross_attn_ip_lr: float = 0 # 0 for not trained
|
33 |
+
init_cross_attn_lora: bool = False
|
34 |
+
enable_cross_attn_lora: bool = False
|
35 |
+
init_cross_attn_ip: bool = False
|
36 |
+
enable_cross_attn_ip: bool = False
|
37 |
+
cross_attn_lora_rank: int = 64 # 0 for not enabled
|
38 |
+
cross_attn_lora_only_kv: bool = False
|
39 |
+
ipadapter_pretrained_name: str = "h94/IP-Adapter"
|
40 |
+
ipadapter_subfolder_name: str = "models"
|
41 |
+
ipadapter_weight_name: str = "ip-adapter-plus_sd15.safetensors"
|
42 |
+
ipadapter_effect_on: str = "all" # all, first
|
43 |
+
|
44 |
+
# for self attention
|
45 |
+
train_self_attn_lr: float = 0
|
46 |
+
train_self_attn_lora_lr: float = 0
|
47 |
+
init_self_attn_lora: bool = False
|
48 |
+
enable_self_attn_lora: bool = False
|
49 |
+
self_attn_lora_rank: int = 64
|
50 |
+
self_attn_lora_only_kv: bool = False
|
51 |
+
|
52 |
+
train_self_attn_ref_lr: float = 0
|
53 |
+
train_ref_unet_lr: float = 0
|
54 |
+
init_self_attn_ref: bool = False
|
55 |
+
enable_self_attn_ref: bool = False
|
56 |
+
self_attn_ref_other_model_name: str = ""
|
57 |
+
self_attn_ref_position: str = "attn1"
|
58 |
+
self_attn_ref_pixel_wise_crosspond: bool = False # enable pixel_wise_crosspond in refattn
|
59 |
+
self_attn_ref_chain_pos: str = "parralle" # before or parralle or after
|
60 |
+
self_attn_ref_effect_on: str = "all" # all or first, for _crosspond attn
|
61 |
+
self_attn_ref_zero_init: bool = True
|
62 |
+
use_simple3d_attn: bool = False
|
63 |
+
|
64 |
+
# for multiview attention
|
65 |
+
init_multiview_attn: bool = False
|
66 |
+
enable_multiview_attn: bool = False
|
67 |
+
multiview_attn_position: str = "attn1"
|
68 |
+
multiview_chain_pose: str = "parralle" # before or parralle or after
|
69 |
+
num_modalities: int = 1
|
70 |
+
use_mv_joint_attn: bool = False
|
71 |
+
|
72 |
+
# for unet
|
73 |
+
init_unet_path: str = "runwayml/stable-diffusion-v1-5"
|
74 |
+
init_num_cls_label: int = 0 # for initialize
|
75 |
+
cls_labels: List[int] = field(default_factory=lambda: [])
|
76 |
+
cls_label_type: str = "embedding"
|
77 |
+
cat_condition: bool = False # cat condition to input
|
78 |
+
|
79 |
+
class Configurable:
|
80 |
+
attn_config: AttnConfig
|
81 |
+
|
82 |
+
def set_config(self, attn_config: AttnConfig):
|
83 |
+
raise NotImplementedError()
|
84 |
+
|
85 |
+
def update_config(self, attn_config: AttnConfig):
|
86 |
+
self.attn_config = attn_config
|
87 |
+
|
88 |
+
def do_set_config(self, attn_config: AttnConfig):
|
89 |
+
self.set_config(attn_config)
|
90 |
+
for name, module in self.named_modules():
|
91 |
+
if isinstance(module, Configurable):
|
92 |
+
if hasattr(module, "do_set_config"):
|
93 |
+
module.do_set_config(attn_config)
|
94 |
+
else:
|
95 |
+
print(f"Warning: {name} has no attribute do_set_config, but is an instance of Configurable")
|
96 |
+
module.attn_config = attn_config
|
97 |
+
|
98 |
+
def do_update_config(self, attn_config: AttnConfig):
|
99 |
+
self.update_config(attn_config)
|
100 |
+
for name, module in self.named_modules():
|
101 |
+
if isinstance(module, Configurable):
|
102 |
+
if hasattr(module, "do_update_config"):
|
103 |
+
module.do_update_config(attn_config)
|
104 |
+
else:
|
105 |
+
print(f"Warning: {name} has no attribute do_update_config, but is an instance of Configurable")
|
106 |
+
module.attn_config = attn_config
|
107 |
+
|
108 |
+
from diffusers import ModelMixin # Must import ModelMixin for CompiledUNet
|
109 |
+
class UnifieldWrappedUNet(UNet2DConditionModel):
|
110 |
+
forward_hook: FunctionType
|
111 |
+
|
112 |
+
def forward(self, *args, **kwargs):
|
113 |
+
if hasattr(self, 'forward_hook'):
|
114 |
+
return self.forward_hook(super().forward, *args, **kwargs)
|
115 |
+
return super().forward(*args, **kwargs)
|
116 |
+
|
117 |
+
|
118 |
+
class ConfigurableUNet2DConditionModel(Configurable, IPAdapterMixin):
|
119 |
+
unet: UNet2DConditionModel
|
120 |
+
|
121 |
+
cls_embedding_param_dict = {}
|
122 |
+
cross_attn_lora_param_dict = {}
|
123 |
+
self_attn_lora_param_dict = {}
|
124 |
+
cross_attn_param_dict = {}
|
125 |
+
self_attn_param_dict = {}
|
126 |
+
ipadapter_param_dict = {}
|
127 |
+
ref_attn_param_dict = {}
|
128 |
+
ref_unet_param_dict = {}
|
129 |
+
multiview_attn_param_dict = {}
|
130 |
+
other_param_dict = {}
|
131 |
+
|
132 |
+
rev_param_name_mapping = {}
|
133 |
+
|
134 |
+
class_labels = []
|
135 |
+
def set_class_labels(self, class_labels: torch.Tensor):
|
136 |
+
if self.attn_config.init_num_cls_label != 0:
|
137 |
+
self.class_labels = class_labels.to(self.unet.device).long()
|
138 |
+
|
139 |
+
def __init__(self, init_config: AttnConfig, weight_dtype) -> None:
|
140 |
+
super().__init__()
|
141 |
+
self.weight_dtype = weight_dtype
|
142 |
+
self.set_config(init_config)
|
143 |
+
|
144 |
+
def enable_xformers_memory_efficient_attention(self):
|
145 |
+
self.unet.enable_xformers_memory_efficient_attention
|
146 |
+
def recursive_add_processors(name: str, module: torch.nn.Module):
|
147 |
+
for sub_name, child in module.named_children():
|
148 |
+
recursive_add_processors(f"{name}.{sub_name}", child)
|
149 |
+
|
150 |
+
if isinstance(module, Attention):
|
151 |
+
if hasattr(module, 'xformers_not_supported'):
|
152 |
+
return
|
153 |
+
old_processor = module.get_processor()
|
154 |
+
if isinstance(old_processor, (AttnProcessor, AttnProcessor2_0)):
|
155 |
+
module.set_use_memory_efficient_attention_xformers(True)
|
156 |
+
|
157 |
+
for name, module in self.unet.named_children():
|
158 |
+
recursive_add_processors(name, module)
|
159 |
+
|
160 |
+
def __getattr__(self, name: str) -> Any:
|
161 |
+
try:
|
162 |
+
return super().__getattr__(name)
|
163 |
+
except AttributeError:
|
164 |
+
return getattr(self.unet, name)
|
165 |
+
|
166 |
+
# --- for IPAdapterMixin
|
167 |
+
|
168 |
+
def register_modules(self, **kwargs):
|
169 |
+
for name, module in kwargs.items():
|
170 |
+
# set models
|
171 |
+
setattr(self, name, module)
|
172 |
+
|
173 |
+
def register_to_config(self, **kwargs):
|
174 |
+
pass
|
175 |
+
|
176 |
+
def unload_ip_adapter(self):
|
177 |
+
raise NotImplementedError()
|
178 |
+
|
179 |
+
# --- for Configurable
|
180 |
+
|
181 |
+
def get_refunet(self):
|
182 |
+
if self.attn_config.self_attn_ref_other_model_name == "self":
|
183 |
+
return self.unet
|
184 |
+
else:
|
185 |
+
return self.unet.ref_unet
|
186 |
+
|
187 |
+
def set_config(self, attn_config: AttnConfig):
|
188 |
+
self.attn_config = attn_config
|
189 |
+
|
190 |
+
unet_type = UnifieldWrappedUNet
|
191 |
+
# class_embed_type = "projection" for 'camera'
|
192 |
+
# class_embed_type = None for 'embedding'
|
193 |
+
unet_kwargs = {}
|
194 |
+
if attn_config.init_num_cls_label > 0:
|
195 |
+
if attn_config.cls_label_type == "embedding":
|
196 |
+
unet_kwargs = {
|
197 |
+
"num_class_embeds": attn_config.init_num_cls_label,
|
198 |
+
"device_map": None,
|
199 |
+
"low_cpu_mem_usage": False,
|
200 |
+
"class_embed_type": None,
|
201 |
+
}
|
202 |
+
else:
|
203 |
+
raise ValueError(f"cls_label_type {attn_config.cls_label_type} is not supported")
|
204 |
+
|
205 |
+
self.unet: UnifieldWrappedUNet = unet_type.from_pretrained(
|
206 |
+
attn_config.init_unet_path, subfolder="unet", torch_dtype=self.weight_dtype,
|
207 |
+
ignore_mismatched_sizes=True, # Added this line
|
208 |
+
**unet_kwargs
|
209 |
+
)
|
210 |
+
assert isinstance(self.unet, UnifieldWrappedUNet)
|
211 |
+
self.unet.forward_hook = self.unet_forward_hook
|
212 |
+
|
213 |
+
if self.attn_config.cat_condition:
|
214 |
+
# double in_channels
|
215 |
+
if self.unet.config.in_channels != 8:
|
216 |
+
self.unet.register_to_config(in_channels=self.unet.config.in_channels * 2)
|
217 |
+
# repeate unet.conv_in weight twice
|
218 |
+
doubled_conv_in = torch.nn.Conv2d(self.unet.conv_in.in_channels * 2, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
|
219 |
+
doubled_conv_in.weight.data = torch.cat([self.unet.conv_in.weight.data, torch.zeros_like(self.unet.conv_in.weight.data)], dim=1)
|
220 |
+
doubled_conv_in.bias.data = self.unet.conv_in.bias.data
|
221 |
+
self.unet.conv_in = doubled_conv_in
|
222 |
+
|
223 |
+
used_param_ids = set()
|
224 |
+
|
225 |
+
if attn_config.init_cross_attn_lora:
|
226 |
+
# setup lora
|
227 |
+
from peft import LoraConfig
|
228 |
+
from peft.utils import get_peft_model_state_dict
|
229 |
+
if attn_config.cross_attn_lora_only_kv:
|
230 |
+
target_modules=["attn2.to_k", "attn2.to_v"]
|
231 |
+
else:
|
232 |
+
target_modules=["attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0"]
|
233 |
+
lora_config: LoraConfig = LoraConfig(
|
234 |
+
r=attn_config.cross_attn_lora_rank,
|
235 |
+
lora_alpha=attn_config.cross_attn_lora_rank,
|
236 |
+
init_lora_weights="gaussian",
|
237 |
+
target_modules=target_modules,
|
238 |
+
)
|
239 |
+
adapter_name="cross_attn_lora"
|
240 |
+
self.unet.add_adapter(lora_config, adapter_name=adapter_name)
|
241 |
+
# update cross_attn_lora_param_dict
|
242 |
+
self.cross_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
|
243 |
+
used_param_ids.update(self.cross_attn_lora_param_dict.keys())
|
244 |
+
|
245 |
+
if attn_config.init_self_attn_lora:
|
246 |
+
# setup lora
|
247 |
+
from peft import LoraConfig
|
248 |
+
if attn_config.self_attn_lora_only_kv:
|
249 |
+
target_modules=["attn1.to_k", "attn1.to_v"]
|
250 |
+
else:
|
251 |
+
target_modules=["attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0"]
|
252 |
+
lora_config: LoraConfig = LoraConfig(
|
253 |
+
r=attn_config.self_attn_lora_rank,
|
254 |
+
lora_alpha=attn_config.self_attn_lora_rank,
|
255 |
+
init_lora_weights="gaussian",
|
256 |
+
target_modules=target_modules,
|
257 |
+
)
|
258 |
+
adapter_name="self_attn_lora"
|
259 |
+
self.unet.add_adapter(lora_config, adapter_name=adapter_name)
|
260 |
+
# update cross_self_lora_param_dict
|
261 |
+
self.self_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids}
|
262 |
+
used_param_ids.update(self.self_attn_lora_param_dict.keys())
|
263 |
+
|
264 |
+
if attn_config.init_num_cls_label != 0:
|
265 |
+
self.cls_embedding_param_dict = {id(param): param for param in self.unet.class_embedding.parameters()}
|
266 |
+
used_param_ids.update(self.cls_embedding_param_dict.keys())
|
267 |
+
self.set_class_labels(torch.tensor(attn_config.cls_labels).long())
|
268 |
+
|
269 |
+
if attn_config.init_cross_attn_ip:
|
270 |
+
self.image_encoder = None
|
271 |
+
# setup ipadapter
|
272 |
+
self.load_ip_adapter(
|
273 |
+
attn_config.ipadapter_pretrained_name,
|
274 |
+
subfolder=attn_config.ipadapter_subfolder_name,
|
275 |
+
weight_name=attn_config.ipadapter_weight_name
|
276 |
+
)
|
277 |
+
# warp ip_adapter_attn_proc with switch
|
278 |
+
from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0
|
279 |
+
add_switch(self.unet, module_filter=lambda x: isinstance(x, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)), switch_dict_fn=lambda x: {"ipadapter": x, "default": XFormersAttnProcessor()}, switch_name="ipadapter_switch", enabled_proc="ipadapter")
|
280 |
+
# update ipadapter_param_dict
|
281 |
+
# weights are in attention processors and unet.encoder_hid_proj
|
282 |
+
self.ipadapter_param_dict = {id(param): param for param in self.unet.encoder_hid_proj.parameters() if id(param) not in used_param_ids}
|
283 |
+
used_param_ids.update(self.ipadapter_param_dict.keys())
|
284 |
+
print("DEBUG: ipadapter_param_dict len in encoder_hid_proj", len(self.ipadapter_param_dict))
|
285 |
+
for name, processor in self.unet.attn_processors.items():
|
286 |
+
if hasattr(processor, "to_k_ip"):
|
287 |
+
self.ipadapter_param_dict.update({id(param): param for param in processor.parameters()})
|
288 |
+
print(f"DEBUG: ipadapter_param_dict len in all", len(self.ipadapter_param_dict))
|
289 |
+
|
290 |
+
ref_unet = None
|
291 |
+
if attn_config.init_self_attn_ref:
|
292 |
+
# setup reference attention processor
|
293 |
+
if attn_config.self_attn_ref_other_model_name == "self":
|
294 |
+
raise NotImplementedError("self reference is not fully implemented")
|
295 |
+
else:
|
296 |
+
ref_unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
|
297 |
+
attn_config.self_attn_ref_other_model_name, subfolder="unet", torch_dtype=self.unet.dtype
|
298 |
+
)
|
299 |
+
ref_unet.to(self.unet.device)
|
300 |
+
if self.attn_config.train_ref_unet_lr == 0:
|
301 |
+
ref_unet.eval()
|
302 |
+
ref_unet.requires_grad_(False)
|
303 |
+
else:
|
304 |
+
ref_unet.train()
|
305 |
+
|
306 |
+
add_extra_processor(
|
307 |
+
model=ref_unet,
|
308 |
+
enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
|
309 |
+
mode='extract',
|
310 |
+
with_proj_in=False,
|
311 |
+
pixel_wise_crosspond=False,
|
312 |
+
)
|
313 |
+
# NOTE: Here require cross_attention_dim in two unet's self attention should be the same
|
314 |
+
processor_dict = add_extra_processor(
|
315 |
+
model=self.unet,
|
316 |
+
enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"),
|
317 |
+
mode='inject',
|
318 |
+
with_proj_in=False,
|
319 |
+
pixel_wise_crosspond=attn_config.self_attn_ref_pixel_wise_crosspond,
|
320 |
+
crosspond_effect_on=attn_config.self_attn_ref_effect_on,
|
321 |
+
crosspond_chain_pos=attn_config.self_attn_ref_chain_pos,
|
322 |
+
simple_3d=attn_config.use_simple3d_attn,
|
323 |
+
)
|
324 |
+
self.ref_unet_param_dict = {id(param): param for name, param in ref_unet.named_parameters() if id(param) not in used_param_ids and (attn_config.self_attn_ref_position in name)}
|
325 |
+
if attn_config.self_attn_ref_chain_pos != "after":
|
326 |
+
# pop untrainable paramters
|
327 |
+
for name, param in ref_unet.named_parameters():
|
328 |
+
if id(param) in self.ref_unet_param_dict and ('up_blocks.3.attentions.2.transformer_blocks.0.' in name):
|
329 |
+
self.ref_unet_param_dict.pop(id(param))
|
330 |
+
used_param_ids.update(self.ref_unet_param_dict.keys())
|
331 |
+
# update ref_attn_param_dict
|
332 |
+
self.ref_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
|
333 |
+
used_param_ids.update(self.ref_attn_param_dict.keys())
|
334 |
+
|
335 |
+
if attn_config.init_multiview_attn:
|
336 |
+
processor_dict = add_multiview_processor(
|
337 |
+
model = self.unet,
|
338 |
+
enable_filter = lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"),
|
339 |
+
num_modalities = attn_config.num_modalities,
|
340 |
+
base_img_size = attn_config.latent_size,
|
341 |
+
chain_pos = attn_config.multiview_chain_pose,
|
342 |
+
)
|
343 |
+
# update multiview_attn_param_dict
|
344 |
+
self.multiview_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids}
|
345 |
+
used_param_ids.update(self.multiview_attn_param_dict.keys())
|
346 |
+
|
347 |
+
# initialize cross_attn_param_dict parameters
|
348 |
+
self.cross_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn2" in name and id(param) not in used_param_ids}
|
349 |
+
used_param_ids.update(self.cross_attn_param_dict.keys())
|
350 |
+
|
351 |
+
# initialize self_attn_param_dict parameters
|
352 |
+
self.self_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn1" in name and id(param) not in used_param_ids}
|
353 |
+
used_param_ids.update(self.self_attn_param_dict.keys())
|
354 |
+
|
355 |
+
# initialize other_param_dict parameters
|
356 |
+
self.other_param_dict = {id(param): param for name, param in self.unet.named_parameters() if id(param) not in used_param_ids}
|
357 |
+
|
358 |
+
if ref_unet is not None:
|
359 |
+
self.unet.ref_unet = ref_unet
|
360 |
+
|
361 |
+
self.rev_param_name_mapping = {id(param): name for name, param in self.unet.named_parameters()}
|
362 |
+
|
363 |
+
self.update_config(attn_config, force_update=True)
|
364 |
+
return self.unet
|
365 |
+
|
366 |
+
_attn_keys_to_update = ["enable_cross_attn_lora", "enable_cross_attn_ip", "enable_self_attn_lora", "enable_self_attn_ref", "enable_multiview_attn", "cls_labels"]
|
367 |
+
|
368 |
+
def update_config(self, attn_config: AttnConfig, force_update=False):
|
369 |
+
assert isinstance(self.unet, UNet2DConditionModel), "unet must be an instance of UNet2DConditionModel"
|
370 |
+
|
371 |
+
need_to_update = False
|
372 |
+
# update cls_labels
|
373 |
+
for key in self._attn_keys_to_update:
|
374 |
+
if getattr(self.attn_config, key) != getattr(attn_config, key):
|
375 |
+
need_to_update = True
|
376 |
+
break
|
377 |
+
if not force_update and not need_to_update:
|
378 |
+
return
|
379 |
+
|
380 |
+
self.set_class_labels(torch.tensor(attn_config.cls_labels).long())
|
381 |
+
|
382 |
+
# setup loras
|
383 |
+
if self.attn_config.init_cross_attn_lora or self.attn_config.init_self_attn_lora:
|
384 |
+
if attn_config.enable_cross_attn_lora or attn_config.enable_self_attn_lora:
|
385 |
+
cross_attn_lora_weight = 1. if attn_config.enable_cross_attn_lora > 0 else 0
|
386 |
+
self_attn_lora_weight = 1. if attn_config.enable_self_attn_lora > 0 else 0
|
387 |
+
self.unet.set_adapters(["cross_attn_lora", "self_attn_lora"], weights=[cross_attn_lora_weight, self_attn_lora_weight])
|
388 |
+
else:
|
389 |
+
self.unet.disable_adapters()
|
390 |
+
|
391 |
+
# setup ipadapter
|
392 |
+
if self.attn_config.init_cross_attn_ip:
|
393 |
+
if attn_config.enable_cross_attn_ip:
|
394 |
+
change_switch(self.unet, "ipadapter_switch", "ipadapter")
|
395 |
+
else:
|
396 |
+
change_switch(self.unet, "ipadapter_switch", "default")
|
397 |
+
|
398 |
+
# setup reference attention processor
|
399 |
+
if self.attn_config.init_self_attn_ref:
|
400 |
+
if attn_config.enable_self_attn_ref:
|
401 |
+
switch_extra_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"))
|
402 |
+
else:
|
403 |
+
switch_extra_processor(self.unet, enable_filter=lambda name: False)
|
404 |
+
|
405 |
+
# setup multiview attention processor
|
406 |
+
if self.attn_config.init_multiview_attn:
|
407 |
+
if attn_config.enable_multiview_attn:
|
408 |
+
switch_multiview_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"))
|
409 |
+
else:
|
410 |
+
switch_multiview_processor(self.unet, enable_filter=lambda name: False)
|
411 |
+
|
412 |
+
# update cls_labels
|
413 |
+
for key in self._attn_keys_to_update:
|
414 |
+
setattr(self.attn_config, key, getattr(attn_config, key))
|
415 |
+
|
416 |
+
def unet_forward_hook(self, raw_forward, sample: torch.FloatTensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, cross_attention_kwargs=None, condition_latents=None, class_labels=None, noisy_condition_input=False, cond_pixels_clip=None, **kwargs):
|
417 |
+
if class_labels is None and len(self.class_labels) > 0:
|
418 |
+
class_labels = self.class_labels.repeat(sample.shape[0] // self.class_labels.shape[0]).to(sample.device)
|
419 |
+
elif self.attn_config.init_num_cls_label != 0:
|
420 |
+
assert class_labels is not None, "class_labels should be passed if self.class_labels is empty and self.attn_config.init_num_cls_label is not 0"
|
421 |
+
if class_labels is not None:
|
422 |
+
if self.attn_config.cls_label_type == "embedding":
|
423 |
+
pass
|
424 |
+
else:
|
425 |
+
raise ValueError(f"cls_label_type {self.attn_config.cls_label_type} is not supported")
|
426 |
+
if self.attn_config.init_self_attn_ref and self.attn_config.enable_self_attn_ref:
|
427 |
+
# NOTE: extra step, extract condition
|
428 |
+
ref_dict = {}
|
429 |
+
ref_unet = self.get_refunet().to(sample.device)
|
430 |
+
assert condition_latents is not None
|
431 |
+
if self.attn_config.self_attn_ref_other_model_name == "self":
|
432 |
+
raise NotImplementedError()
|
433 |
+
else:
|
434 |
+
with torch.no_grad():
|
435 |
+
cond_encoder_hidden_states = encoder_hidden_states.reshape(condition_latents.shape[0], -1, *encoder_hidden_states.shape[1:])[:, 0]
|
436 |
+
if timestep.dim() == 0:
|
437 |
+
cond_timestep = timestep
|
438 |
+
else:
|
439 |
+
cond_timestep = timestep.reshape(condition_latents.shape[0], -1)[:, 0]
|
440 |
+
ref_unet(condition_latents, cond_timestep, cond_encoder_hidden_states, cross_attention_kwargs=dict(ref_dict=ref_dict))
|
441 |
+
# NOTE: extra step, inject condition
|
442 |
+
# Predict the noise residual and compute loss
|
443 |
+
if cross_attention_kwargs is None:
|
444 |
+
cross_attention_kwargs = {}
|
445 |
+
cross_attention_kwargs.update(ref_dict=ref_dict, mode='inject')
|
446 |
+
elif condition_latents is not None:
|
447 |
+
if not hasattr(self, 'condition_latents_raised'):
|
448 |
+
print("Warning! condition_latents is not None, but self_attn_ref is not enabled! This warning will only be raised once.")
|
449 |
+
self.condition_latents_raised = True
|
450 |
+
|
451 |
+
if self.attn_config.init_cross_attn_ip:
|
452 |
+
raise NotImplementedError()
|
453 |
+
|
454 |
+
if self.attn_config.cat_condition:
|
455 |
+
assert condition_latents is not None
|
456 |
+
B = condition_latents.shape[0]
|
457 |
+
cat_latents = condition_latents.reshape(B, 1, *condition_latents.shape[1:]).repeat(1, sample.shape[0] // B, 1, 1, 1).reshape(*sample.shape)
|
458 |
+
sample = torch.cat([sample, cat_latents], dim=1)
|
459 |
+
|
460 |
+
return raw_forward(sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, **kwargs)
|
custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# modified by Wuvin
|
15 |
+
|
16 |
+
|
17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
|
23 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
24 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
|
25 |
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
26 |
+
from PIL import Image
|
27 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
class StableDiffusionImageCustomPipeline(
|
32 |
+
StableDiffusionImageVariationPipeline
|
33 |
+
):
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
vae: AutoencoderKL,
|
37 |
+
image_encoder: CLIPVisionModelWithProjection,
|
38 |
+
unet: UNet2DConditionModel,
|
39 |
+
scheduler: KarrasDiffusionSchedulers,
|
40 |
+
safety_checker: StableDiffusionSafetyChecker,
|
41 |
+
feature_extractor: CLIPImageProcessor,
|
42 |
+
requires_safety_checker: bool = True,
|
43 |
+
latents_offset=None,
|
44 |
+
noisy_cond_latents=False,
|
45 |
+
):
|
46 |
+
super().__init__(
|
47 |
+
vae=vae,
|
48 |
+
image_encoder=image_encoder,
|
49 |
+
unet=unet,
|
50 |
+
scheduler=scheduler,
|
51 |
+
safety_checker=safety_checker,
|
52 |
+
feature_extractor=feature_extractor,
|
53 |
+
requires_safety_checker=requires_safety_checker
|
54 |
+
)
|
55 |
+
latents_offset = tuple(latents_offset) if latents_offset is not None else None
|
56 |
+
self.latents_offset = latents_offset
|
57 |
+
if latents_offset is not None:
|
58 |
+
self.register_to_config(latents_offset=latents_offset)
|
59 |
+
self.noisy_cond_latents = noisy_cond_latents
|
60 |
+
self.register_to_config(noisy_cond_latents=noisy_cond_latents)
|
61 |
+
|
62 |
+
def encode_latents(self, image, device, dtype, height, width):
|
63 |
+
# support batchsize > 1
|
64 |
+
if isinstance(image, Image.Image):
|
65 |
+
image = [image]
|
66 |
+
image = [img.convert("RGB") for img in image]
|
67 |
+
images = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype)
|
68 |
+
latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
|
69 |
+
if self.latents_offset is not None:
|
70 |
+
return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
|
71 |
+
else:
|
72 |
+
return latents
|
73 |
+
|
74 |
+
def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
|
75 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
76 |
+
|
77 |
+
if not isinstance(image, torch.Tensor):
|
78 |
+
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
|
79 |
+
|
80 |
+
image = image.to(device=device, dtype=dtype)
|
81 |
+
image_embeddings = self.image_encoder(image).image_embeds
|
82 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
83 |
+
|
84 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
85 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
86 |
+
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
|
87 |
+
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
88 |
+
|
89 |
+
if do_classifier_free_guidance:
|
90 |
+
# NOTE: the same as original code
|
91 |
+
negative_prompt_embeds = torch.zeros_like(image_embeddings)
|
92 |
+
# For classifier free guidance, we need to do two forward passes.
|
93 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
94 |
+
# to avoid doing two forward passes
|
95 |
+
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
|
96 |
+
|
97 |
+
return image_embeddings
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def __call__(
|
101 |
+
self,
|
102 |
+
image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
|
103 |
+
height: Optional[int] = 1024,
|
104 |
+
width: Optional[int] = 1024,
|
105 |
+
height_cond: Optional[int] = 512,
|
106 |
+
width_cond: Optional[int] = 512,
|
107 |
+
num_inference_steps: int = 50,
|
108 |
+
guidance_scale: float = 7.5,
|
109 |
+
num_images_per_prompt: Optional[int] = 1,
|
110 |
+
eta: float = 0.0,
|
111 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
112 |
+
latents: Optional[torch.FloatTensor] = None,
|
113 |
+
output_type: Optional[str] = "pil",
|
114 |
+
return_dict: bool = True,
|
115 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
116 |
+
callback_steps: int = 1,
|
117 |
+
upper_left_feature: bool = False,
|
118 |
+
):
|
119 |
+
r"""
|
120 |
+
The call function to the pipeline for generation.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
|
124 |
+
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
125 |
+
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
|
126 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
127 |
+
The height in pixels of the generated image.
|
128 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
129 |
+
The width in pixels of the generated image.
|
130 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
131 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
132 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
133 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
134 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
135 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
136 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
137 |
+
The number of images to generate per prompt.
|
138 |
+
eta (`float`, *optional*, defaults to 0.0):
|
139 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
140 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
141 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
142 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
143 |
+
generation deterministic.
|
144 |
+
latents (`torch.FloatTensor`, *optional*):
|
145 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
146 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
147 |
+
tensor is generated by sampling using the supplied random `generator`.
|
148 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
149 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
150 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
151 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
152 |
+
plain tuple.
|
153 |
+
callback (`Callable`, *optional*):
|
154 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
155 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
156 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
157 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
158 |
+
every step.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
162 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
163 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
164 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
165 |
+
"not-safe-for-work" (nsfw) content.
|
166 |
+
|
167 |
+
Examples:
|
168 |
+
|
169 |
+
```py
|
170 |
+
from diffusers import StableDiffusionImageVariationPipeline
|
171 |
+
from PIL import Image
|
172 |
+
from io import BytesIO
|
173 |
+
import requests
|
174 |
+
|
175 |
+
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
|
176 |
+
"lambdalabs/sd-image-variations-diffusers", revision="v2.0"
|
177 |
+
)
|
178 |
+
pipe = pipe.to("cuda")
|
179 |
+
|
180 |
+
url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
|
181 |
+
|
182 |
+
response = requests.get(url)
|
183 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
184 |
+
|
185 |
+
out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
|
186 |
+
out["images"][0].save("result.jpg")
|
187 |
+
```
|
188 |
+
"""
|
189 |
+
# 0. Default height and width to unet
|
190 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
191 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
192 |
+
|
193 |
+
# 1. Check inputs. Raise error if not correct
|
194 |
+
self.check_inputs(image, height, width, callback_steps)
|
195 |
+
|
196 |
+
# 2. Define call parameters
|
197 |
+
if isinstance(image, Image.Image):
|
198 |
+
batch_size = 1
|
199 |
+
elif isinstance(image, list):
|
200 |
+
batch_size = len(image)
|
201 |
+
else:
|
202 |
+
batch_size = image.shape[0]
|
203 |
+
device = self._execution_device
|
204 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
205 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
206 |
+
# corresponds to doing no classifier free guidance.
|
207 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
208 |
+
|
209 |
+
# 3. Encode input image
|
210 |
+
if isinstance(image, Image.Image) and upper_left_feature:
|
211 |
+
# only use the first one of four images
|
212 |
+
emb_image = image.crop((0, 0, image.size[0] // 2, image.size[1] // 2))
|
213 |
+
else:
|
214 |
+
emb_image = image
|
215 |
+
|
216 |
+
image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
|
217 |
+
cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
|
218 |
+
|
219 |
+
# 4. Prepare timesteps
|
220 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
221 |
+
timesteps = self.scheduler.timesteps
|
222 |
+
|
223 |
+
# 5. Prepare latent variables
|
224 |
+
num_channels_latents = self.unet.config.out_channels
|
225 |
+
latents = self.prepare_latents(
|
226 |
+
batch_size * num_images_per_prompt,
|
227 |
+
num_channels_latents,
|
228 |
+
height,
|
229 |
+
width,
|
230 |
+
image_embeddings.dtype,
|
231 |
+
device,
|
232 |
+
generator,
|
233 |
+
latents,
|
234 |
+
)
|
235 |
+
|
236 |
+
# 6. Prepare extra step kwargs.
|
237 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
238 |
+
|
239 |
+
# 7. Denoising loop
|
240 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
241 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
242 |
+
for i, t in enumerate(timesteps):
|
243 |
+
if self.noisy_cond_latents:
|
244 |
+
raise ValueError("Noisy condition latents is not recommended.")
|
245 |
+
else:
|
246 |
+
noisy_cond_latents = cond_latents
|
247 |
+
|
248 |
+
noisy_cond_latents = torch.cat([torch.zeros_like(noisy_cond_latents), noisy_cond_latents]) if do_classifier_free_guidance else noisy_cond_latents
|
249 |
+
# expand the latents if we are doing classifier free guidance
|
250 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
251 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
252 |
+
|
253 |
+
# predict the noise residual
|
254 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=noisy_cond_latents).sample
|
255 |
+
|
256 |
+
# perform guidance
|
257 |
+
if do_classifier_free_guidance:
|
258 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
259 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
260 |
+
|
261 |
+
# compute the previous noisy sample x_t -> x_t-1
|
262 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
263 |
+
|
264 |
+
# call the callback, if provided
|
265 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
266 |
+
progress_bar.update()
|
267 |
+
if callback is not None and i % callback_steps == 0:
|
268 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
269 |
+
callback(step_idx, t, latents)
|
270 |
+
|
271 |
+
self.maybe_free_model_hooks()
|
272 |
+
|
273 |
+
if self.latents_offset is not None:
|
274 |
+
latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
|
275 |
+
|
276 |
+
if not output_type == "latent":
|
277 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
278 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
|
279 |
+
else:
|
280 |
+
image = latents
|
281 |
+
has_nsfw_concept = None
|
282 |
+
|
283 |
+
if has_nsfw_concept is None:
|
284 |
+
do_denormalize = [True] * image.shape[0]
|
285 |
+
else:
|
286 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
287 |
+
|
288 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
289 |
+
|
290 |
+
self.maybe_free_model_hooks()
|
291 |
+
|
292 |
+
if not return_dict:
|
293 |
+
return (image, has_nsfw_concept)
|
294 |
+
|
295 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
296 |
+
|
297 |
+
if __name__ == "__main__":
|
298 |
+
pass
|
custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# modified by Wuvin
|
15 |
+
|
16 |
+
|
17 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
|
23 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
|
24 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
|
25 |
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
26 |
+
from PIL import Image
|
27 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
class StableDiffusionImage2MVCustomPipeline(
|
32 |
+
StableDiffusionImageVariationPipeline
|
33 |
+
):
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
vae: AutoencoderKL,
|
37 |
+
image_encoder: CLIPVisionModelWithProjection,
|
38 |
+
unet: UNet2DConditionModel,
|
39 |
+
scheduler: KarrasDiffusionSchedulers,
|
40 |
+
safety_checker: StableDiffusionSafetyChecker,
|
41 |
+
feature_extractor: CLIPImageProcessor,
|
42 |
+
requires_safety_checker: bool = True,
|
43 |
+
latents_offset=None,
|
44 |
+
noisy_cond_latents=False,
|
45 |
+
condition_offset=True,
|
46 |
+
):
|
47 |
+
super().__init__(
|
48 |
+
vae=vae,
|
49 |
+
image_encoder=image_encoder,
|
50 |
+
unet=unet,
|
51 |
+
scheduler=scheduler,
|
52 |
+
safety_checker=safety_checker,
|
53 |
+
feature_extractor=feature_extractor,
|
54 |
+
requires_safety_checker=requires_safety_checker
|
55 |
+
)
|
56 |
+
latents_offset = tuple(latents_offset) if latents_offset is not None else None
|
57 |
+
self.latents_offset = latents_offset
|
58 |
+
if latents_offset is not None:
|
59 |
+
self.register_to_config(latents_offset=latents_offset)
|
60 |
+
if noisy_cond_latents:
|
61 |
+
raise NotImplementedError("Noisy condition latents not supported Now.")
|
62 |
+
self.condition_offset = condition_offset
|
63 |
+
self.register_to_config(condition_offset=condition_offset)
|
64 |
+
|
65 |
+
def encode_latents(self, image: Image.Image, device, dtype, height, width):
|
66 |
+
images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype)
|
67 |
+
# NOTE: .mode() for condition
|
68 |
+
latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
|
69 |
+
if self.latents_offset is not None and self.condition_offset:
|
70 |
+
return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
|
71 |
+
else:
|
72 |
+
return latents
|
73 |
+
|
74 |
+
def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
|
75 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
76 |
+
|
77 |
+
if not isinstance(image, torch.Tensor):
|
78 |
+
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
|
79 |
+
|
80 |
+
image = image.to(device=device, dtype=dtype)
|
81 |
+
image_embeddings = self.image_encoder(image).image_embeds
|
82 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
83 |
+
|
84 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
85 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
86 |
+
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
|
87 |
+
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
88 |
+
|
89 |
+
if do_classifier_free_guidance:
|
90 |
+
# NOTE: the same as original code
|
91 |
+
negative_prompt_embeds = torch.zeros_like(image_embeddings)
|
92 |
+
# For classifier free guidance, we need to do two forward passes.
|
93 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
94 |
+
# to avoid doing two forward passes
|
95 |
+
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
|
96 |
+
|
97 |
+
return image_embeddings
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def __call__(
|
101 |
+
self,
|
102 |
+
image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
|
103 |
+
height: Optional[int] = 1024,
|
104 |
+
width: Optional[int] = 1024,
|
105 |
+
height_cond: Optional[int] = 512,
|
106 |
+
width_cond: Optional[int] = 512,
|
107 |
+
num_inference_steps: int = 50,
|
108 |
+
guidance_scale: float = 7.5,
|
109 |
+
num_images_per_prompt: Optional[int] = 1,
|
110 |
+
eta: float = 0.0,
|
111 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
112 |
+
latents: Optional[torch.FloatTensor] = None,
|
113 |
+
output_type: Optional[str] = "pil",
|
114 |
+
return_dict: bool = True,
|
115 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
116 |
+
callback_steps: int = 1,
|
117 |
+
):
|
118 |
+
r"""
|
119 |
+
The call function to the pipeline for generation.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`):
|
123 |
+
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
124 |
+
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
|
125 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
126 |
+
The height in pixels of the generated image.
|
127 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
128 |
+
The width in pixels of the generated image.
|
129 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
130 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
131 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
132 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
133 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
134 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
135 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
136 |
+
The number of images to generate per prompt.
|
137 |
+
eta (`float`, *optional*, defaults to 0.0):
|
138 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
139 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
140 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
141 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
142 |
+
generation deterministic.
|
143 |
+
latents (`torch.FloatTensor`, *optional*):
|
144 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
145 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
146 |
+
tensor is generated by sampling using the supplied random `generator`.
|
147 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
148 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
149 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
150 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
151 |
+
plain tuple.
|
152 |
+
callback (`Callable`, *optional*):
|
153 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
154 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
155 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
156 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
157 |
+
every step.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
161 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
162 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
163 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
164 |
+
"not-safe-for-work" (nsfw) content.
|
165 |
+
|
166 |
+
Examples:
|
167 |
+
|
168 |
+
```py
|
169 |
+
from diffusers import StableDiffusionImageVariationPipeline
|
170 |
+
from PIL import Image
|
171 |
+
from io import BytesIO
|
172 |
+
import requests
|
173 |
+
|
174 |
+
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
|
175 |
+
"lambdalabs/sd-image-variations-diffusers", revision="v2.0"
|
176 |
+
)
|
177 |
+
pipe = pipe.to("cuda")
|
178 |
+
|
179 |
+
url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
|
180 |
+
|
181 |
+
response = requests.get(url)
|
182 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
183 |
+
|
184 |
+
out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
|
185 |
+
out["images"][0].save("result.jpg")
|
186 |
+
```
|
187 |
+
"""
|
188 |
+
# 0. Default height and width to unet
|
189 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
190 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
191 |
+
|
192 |
+
# 1. Check inputs. Raise error if not correct
|
193 |
+
self.check_inputs(image, height, width, callback_steps)
|
194 |
+
|
195 |
+
# 2. Define call parameters
|
196 |
+
if isinstance(image, Image.Image):
|
197 |
+
batch_size = 1
|
198 |
+
elif len(image) == 1:
|
199 |
+
image = image[0]
|
200 |
+
batch_size = 1
|
201 |
+
else:
|
202 |
+
raise NotImplementedError()
|
203 |
+
# elif isinstance(image, list):
|
204 |
+
# batch_size = len(image)
|
205 |
+
# else:
|
206 |
+
# batch_size = image.shape[0]
|
207 |
+
device = self._execution_device
|
208 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
209 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
210 |
+
# corresponds to doing no classifier free guidance.
|
211 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
212 |
+
|
213 |
+
# 3. Encode input image
|
214 |
+
emb_image = image
|
215 |
+
|
216 |
+
image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance)
|
217 |
+
cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond)
|
218 |
+
cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents
|
219 |
+
image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
|
220 |
+
if do_classifier_free_guidance:
|
221 |
+
image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)
|
222 |
+
|
223 |
+
# 4. Prepare timesteps
|
224 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
225 |
+
timesteps = self.scheduler.timesteps
|
226 |
+
|
227 |
+
# 5. Prepare latent variables
|
228 |
+
num_channels_latents = self.unet.config.out_channels
|
229 |
+
latents = self.prepare_latents(
|
230 |
+
batch_size * num_images_per_prompt,
|
231 |
+
num_channels_latents,
|
232 |
+
height,
|
233 |
+
width,
|
234 |
+
image_embeddings.dtype,
|
235 |
+
device,
|
236 |
+
generator,
|
237 |
+
latents,
|
238 |
+
)
|
239 |
+
|
240 |
+
|
241 |
+
# 6. Prepare extra step kwargs.
|
242 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
243 |
+
# 7. Denoising loop
|
244 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
245 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
246 |
+
for i, t in enumerate(timesteps):
|
247 |
+
# expand the latents if we are doing classifier free guidance
|
248 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
249 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
250 |
+
|
251 |
+
# predict the noise residual
|
252 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=cond_latents, noisy_condition_input=False, cond_pixels_clip=image_pixels).sample
|
253 |
+
|
254 |
+
# perform guidance
|
255 |
+
if do_classifier_free_guidance:
|
256 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
257 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
258 |
+
|
259 |
+
# compute the previous noisy sample x_t -> x_t-1
|
260 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
261 |
+
|
262 |
+
# call the callback, if provided
|
263 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
264 |
+
progress_bar.update()
|
265 |
+
if callback is not None and i % callback_steps == 0:
|
266 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
267 |
+
callback(step_idx, t, latents)
|
268 |
+
|
269 |
+
self.maybe_free_model_hooks()
|
270 |
+
|
271 |
+
if self.latents_offset is not None:
|
272 |
+
latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None]
|
273 |
+
|
274 |
+
if not output_type == "latent":
|
275 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
276 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
|
277 |
+
else:
|
278 |
+
image = latents
|
279 |
+
has_nsfw_concept = None
|
280 |
+
|
281 |
+
if has_nsfw_concept is None:
|
282 |
+
do_denormalize = [True] * image.shape[0]
|
283 |
+
else:
|
284 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
285 |
+
|
286 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
287 |
+
|
288 |
+
self.maybe_free_model_hooks()
|
289 |
+
|
290 |
+
if not return_dict:
|
291 |
+
return (image, has_nsfw_concept)
|
292 |
+
|
293 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
294 |
+
|
295 |
+
if __name__ == "__main__":
|
296 |
+
pass
|
custum_3d_diffusion/modules.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__modules__ = {}
|
2 |
+
|
3 |
+
def register(name):
|
4 |
+
def decorator(cls):
|
5 |
+
__modules__[name] = cls
|
6 |
+
return cls
|
7 |
+
|
8 |
+
return decorator
|
9 |
+
|
10 |
+
|
11 |
+
def find(name):
|
12 |
+
return __modules__[name]
|
13 |
+
|
14 |
+
from custum_3d_diffusion.trainings import base, image2mvimage_trainer, image2image_trainer
|
custum_3d_diffusion/trainings/__init__.py
ADDED
File without changes
|
custum_3d_diffusion/trainings/base.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from accelerate import Accelerator
|
3 |
+
from accelerate.logging import MultiProcessAdapter
|
4 |
+
from dataclasses import dataclass, field
|
5 |
+
from typing import Optional, Union
|
6 |
+
from datasets import load_dataset
|
7 |
+
import json
|
8 |
+
import abc
|
9 |
+
from diffusers.utils import make_image_grid
|
10 |
+
import numpy as np
|
11 |
+
import wandb
|
12 |
+
|
13 |
+
from custum_3d_diffusion.trainings.utils import load_config
|
14 |
+
from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig
|
15 |
+
|
16 |
+
class BasicTrainer(torch.nn.Module, abc.ABC):
|
17 |
+
accelerator: Accelerator
|
18 |
+
logger: MultiProcessAdapter
|
19 |
+
unet: ConfigurableUNet2DConditionModel
|
20 |
+
train_dataloader: torch.utils.data.DataLoader
|
21 |
+
test_dataset: torch.utils.data.Dataset
|
22 |
+
attn_config: AttnConfig
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class TrainerConfig:
|
26 |
+
trainer_name: str = "basic"
|
27 |
+
pretrained_model_name_or_path: str = ""
|
28 |
+
|
29 |
+
attn_config: dict = field(default_factory=dict)
|
30 |
+
dataset_name: str = ""
|
31 |
+
dataset_config_name: Optional[str] = None
|
32 |
+
resolution: str = "1024"
|
33 |
+
dataloader_num_workers: int = 4
|
34 |
+
pair_sampler_group_size: int = 1
|
35 |
+
num_views: int = 4
|
36 |
+
|
37 |
+
max_train_steps: int = -1 # -1 means infinity, otherwise [0, max_train_steps)
|
38 |
+
training_step_interval: int = 1 # train on step i*interval, stop at max_train_steps
|
39 |
+
max_train_samples: Optional[int] = None
|
40 |
+
seed: Optional[int] = None # For dataset related operations and validation stuff
|
41 |
+
train_batch_size: int = 1
|
42 |
+
|
43 |
+
validation_interval: int = 5000
|
44 |
+
debug: bool = False
|
45 |
+
|
46 |
+
cfg: TrainerConfig # only enable_xxx is used
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
accelerator: Accelerator,
|
51 |
+
logger: MultiProcessAdapter,
|
52 |
+
unet: ConfigurableUNet2DConditionModel,
|
53 |
+
config: Union[dict, str],
|
54 |
+
weight_dtype: torch.dtype,
|
55 |
+
index: int,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.index = index # index in all trainers
|
59 |
+
self.accelerator = accelerator
|
60 |
+
self.logger = logger
|
61 |
+
self.unet = unet
|
62 |
+
self.weight_dtype = weight_dtype
|
63 |
+
self.ext_logs = {}
|
64 |
+
self.cfg = load_config(self.TrainerConfig, config)
|
65 |
+
self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
|
66 |
+
self.test_dataset = None
|
67 |
+
self.validate_trainer_config()
|
68 |
+
self.configure()
|
69 |
+
|
70 |
+
def get_HW(self):
|
71 |
+
resolution = json.loads(self.cfg.resolution)
|
72 |
+
if isinstance(resolution, int):
|
73 |
+
H = W = resolution
|
74 |
+
elif isinstance(resolution, list):
|
75 |
+
H, W = resolution
|
76 |
+
return H, W
|
77 |
+
|
78 |
+
def unet_update(self):
|
79 |
+
self.unet.update_config(self.attn_config)
|
80 |
+
|
81 |
+
def validate_trainer_config(self):
|
82 |
+
pass
|
83 |
+
|
84 |
+
def is_train_finished(self, current_step):
|
85 |
+
assert isinstance(self.cfg.max_train_steps, int)
|
86 |
+
return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps
|
87 |
+
|
88 |
+
def next_train_step(self, current_step):
|
89 |
+
if self.is_train_finished(current_step):
|
90 |
+
return None
|
91 |
+
return current_step + self.cfg.training_step_interval
|
92 |
+
|
93 |
+
@classmethod
|
94 |
+
def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
|
95 |
+
catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)]
|
96 |
+
return make_image_grid(catted, rows=1, cols=len(catted))
|
97 |
+
|
98 |
+
def configure(self) -> None:
|
99 |
+
pass
|
100 |
+
|
101 |
+
@abc.abstractmethod
|
102 |
+
def init_shared_modules(self, shared_modules: dict) -> dict:
|
103 |
+
pass
|
104 |
+
|
105 |
+
def load_dataset(self):
|
106 |
+
dataset = load_dataset(
|
107 |
+
self.cfg.dataset_name,
|
108 |
+
self.cfg.dataset_config_name,
|
109 |
+
trust_remote_code=True
|
110 |
+
)
|
111 |
+
return dataset
|
112 |
+
|
113 |
+
@abc.abstractmethod
|
114 |
+
def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
|
115 |
+
"""Both init train_dataloader and test_dataset, but returns train_dataloader only"""
|
116 |
+
pass
|
117 |
+
|
118 |
+
@abc.abstractmethod
|
119 |
+
def forward_step(
|
120 |
+
self,
|
121 |
+
*args,
|
122 |
+
**kwargs
|
123 |
+
) -> torch.Tensor:
|
124 |
+
"""
|
125 |
+
input a batch
|
126 |
+
return a loss
|
127 |
+
"""
|
128 |
+
self.unet_update()
|
129 |
+
pass
|
130 |
+
|
131 |
+
@abc.abstractmethod
|
132 |
+
def construct_pipeline(self, shared_modules, unet):
|
133 |
+
pass
|
134 |
+
|
135 |
+
@abc.abstractmethod
|
136 |
+
def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
|
137 |
+
"""
|
138 |
+
For inference time forward.
|
139 |
+
"""
|
140 |
+
pass
|
141 |
+
|
142 |
+
@abc.abstractmethod
|
143 |
+
def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
|
144 |
+
pass
|
145 |
+
|
146 |
+
def do_validation(
|
147 |
+
self,
|
148 |
+
shared_modules,
|
149 |
+
unet,
|
150 |
+
global_step,
|
151 |
+
):
|
152 |
+
self.unet_update()
|
153 |
+
self.logger.info("Running validation... ")
|
154 |
+
pipeline = self.construct_pipeline(shared_modules, unet)
|
155 |
+
pipeline.set_progress_bar_config(disable=True)
|
156 |
+
titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
|
157 |
+
for tracker in self.accelerator.trackers:
|
158 |
+
if tracker.name == "tensorboard":
|
159 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
160 |
+
tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
|
161 |
+
elif tracker.name == "wandb":
|
162 |
+
[image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title] # inplace operation
|
163 |
+
tracker.log({"validation": [
|
164 |
+
wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
|
165 |
+
for i, image in enumerate(images)]})
|
166 |
+
else:
|
167 |
+
self.logger.warn(f"image logging not implemented for {tracker.name}")
|
168 |
+
del pipeline
|
169 |
+
torch.cuda.empty_cache()
|
170 |
+
return images
|
171 |
+
|
172 |
+
|
173 |
+
@torch.no_grad()
|
174 |
+
def log_validation(
|
175 |
+
self,
|
176 |
+
shared_modules,
|
177 |
+
unet,
|
178 |
+
global_step,
|
179 |
+
force=False
|
180 |
+
):
|
181 |
+
if self.accelerator.is_main_process:
|
182 |
+
for tracker in self.accelerator.trackers:
|
183 |
+
if tracker.name == "wandb":
|
184 |
+
tracker.log(self.ext_logs)
|
185 |
+
self.ext_logs = {}
|
186 |
+
if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
|
187 |
+
self.unet_update()
|
188 |
+
if self.accelerator.is_main_process:
|
189 |
+
self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)
|
190 |
+
|
191 |
+
def save_model(self, unwrap_unet, shared_modules, save_dir):
|
192 |
+
if self.accelerator.is_main_process:
|
193 |
+
pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
|
194 |
+
pipeline.save_pretrained(save_dir)
|
195 |
+
self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")
|
196 |
+
|
197 |
+
def save_debug_info(self, save_name="debug", **kwargs):
|
198 |
+
if self.cfg.debug:
|
199 |
+
to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()}
|
200 |
+
import pickle
|
201 |
+
import os
|
202 |
+
if os.path.exists(f"{save_name}.pkl"):
|
203 |
+
for i in range(100):
|
204 |
+
if not os.path.exists(f"{save_name}_v{i}.pkl"):
|
205 |
+
save_name = f"{save_name}_v{i}"
|
206 |
+
break
|
207 |
+
with open(f"{save_name}.pkl", "wb") as f:
|
208 |
+
pickle.dump(to_saves, f)
|
custum_3d_diffusion/trainings/config_classes.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class TrainerSubConfig:
|
7 |
+
trainer_type: str = ""
|
8 |
+
trainer: dict = field(default_factory=dict)
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class ExprimentConfig:
|
13 |
+
trainers: List[dict] = field(default_factory=lambda: [])
|
14 |
+
init_config: dict = field(default_factory=dict)
|
15 |
+
pretrained_model_name_or_path: str = ""
|
16 |
+
pretrained_unet_state_dict_path: str = ""
|
17 |
+
# expriments related parameters
|
18 |
+
linear_beta_schedule: bool = False
|
19 |
+
zero_snr: bool = False
|
20 |
+
prediction_type: Optional[str] = None
|
21 |
+
seed: Optional[int] = None
|
22 |
+
max_train_steps: int = 1000000
|
23 |
+
gradient_accumulation_steps: int = 1
|
24 |
+
learning_rate: float = 1e-4
|
25 |
+
lr_scheduler: str = "constant"
|
26 |
+
lr_warmup_steps: int = 500
|
27 |
+
use_8bit_adam: bool = False
|
28 |
+
adam_beta1: float = 0.9
|
29 |
+
adam_beta2: float = 0.999
|
30 |
+
adam_weight_decay: float = 1e-2
|
31 |
+
adam_epsilon: float = 1e-08
|
32 |
+
max_grad_norm: float = 1.0
|
33 |
+
mixed_precision: Optional[str] = None # ["no", "fp16", "bf16", "fp8"]
|
34 |
+
skip_training: bool = False
|
35 |
+
debug: bool = False
|
custum_3d_diffusion/trainings/image2image_trainer.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from diffusers import EulerAncestralDiscreteScheduler, DDPMScheduler
|
4 |
+
from dataclasses import dataclass
|
5 |
+
|
6 |
+
from custum_3d_diffusion.modules import register
|
7 |
+
from custum_3d_diffusion.trainings.image2mvimage_trainer import Image2MVImageTrainer
|
8 |
+
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline
|
9 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
10 |
+
|
11 |
+
def get_HW(resolution):
|
12 |
+
if isinstance(resolution, str):
|
13 |
+
resolution = json.loads(resolution)
|
14 |
+
if isinstance(resolution, int):
|
15 |
+
H = W = resolution
|
16 |
+
elif isinstance(resolution, list):
|
17 |
+
H, W = resolution
|
18 |
+
return H, W
|
19 |
+
|
20 |
+
|
21 |
+
@register("image2image_trainer")
|
22 |
+
class Image2ImageTrainer(Image2MVImageTrainer):
|
23 |
+
"""
|
24 |
+
Trainer for simple image to multiview images.
|
25 |
+
"""
|
26 |
+
@dataclass
|
27 |
+
class TrainerConfig(Image2MVImageTrainer.TrainerConfig):
|
28 |
+
trainer_name: str = "image2image"
|
29 |
+
|
30 |
+
cfg: TrainerConfig
|
31 |
+
|
32 |
+
def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
|
33 |
+
raise NotImplementedError()
|
34 |
+
|
35 |
+
def construct_pipeline(self, shared_modules, unet, old_version=False):
|
36 |
+
MyPipeline = StableDiffusionImageCustomPipeline
|
37 |
+
pipeline = MyPipeline.from_pretrained(
|
38 |
+
self.cfg.pretrained_model_name_or_path,
|
39 |
+
vae=shared_modules['vae'],
|
40 |
+
image_encoder=shared_modules['image_encoder'],
|
41 |
+
feature_extractor=shared_modules['feature_extractor'],
|
42 |
+
unet=unet,
|
43 |
+
safety_checker=None,
|
44 |
+
torch_dtype=self.weight_dtype,
|
45 |
+
latents_offset=self.cfg.latents_offset,
|
46 |
+
noisy_cond_latents=self.cfg.noisy_condition_input,
|
47 |
+
)
|
48 |
+
pipeline.set_progress_bar_config(disable=True)
|
49 |
+
scheduler_dict = {}
|
50 |
+
if self.cfg.zero_snr:
|
51 |
+
scheduler_dict.update(rescale_betas_zero_snr=True)
|
52 |
+
if self.cfg.linear_beta_schedule:
|
53 |
+
scheduler_dict.update(beta_schedule='linear')
|
54 |
+
|
55 |
+
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
|
56 |
+
return pipeline
|
57 |
+
|
58 |
+
def get_forward_args(self):
|
59 |
+
if self.cfg.seed is None:
|
60 |
+
generator = None
|
61 |
+
else:
|
62 |
+
generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
|
63 |
+
|
64 |
+
H, W = get_HW(self.cfg.resolution)
|
65 |
+
H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
|
66 |
+
|
67 |
+
forward_args = dict(
|
68 |
+
num_images_per_prompt=1,
|
69 |
+
num_inference_steps=20,
|
70 |
+
height=H,
|
71 |
+
width=W,
|
72 |
+
height_cond=H_cond,
|
73 |
+
width_cond=W_cond,
|
74 |
+
generator=generator,
|
75 |
+
)
|
76 |
+
if self.cfg.zero_snr:
|
77 |
+
forward_args.update(guidance_rescale=0.7)
|
78 |
+
return forward_args
|
79 |
+
|
80 |
+
def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
|
81 |
+
forward_args = self.get_forward_args()
|
82 |
+
forward_args.update(pipeline_call_kwargs)
|
83 |
+
return pipeline(**forward_args)
|
84 |
+
|
85 |
+
def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
|
86 |
+
raise NotImplementedError()
|
custum_3d_diffusion/trainings/image2mvimage_trainer.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler
|
3 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature
|
4 |
+
|
5 |
+
import json
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import List, Optional
|
8 |
+
|
9 |
+
from custum_3d_diffusion.modules import register
|
10 |
+
from custum_3d_diffusion.trainings.base import BasicTrainer
|
11 |
+
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline
|
12 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
13 |
+
|
14 |
+
def get_HW(resolution):
|
15 |
+
if isinstance(resolution, str):
|
16 |
+
resolution = json.loads(resolution)
|
17 |
+
if isinstance(resolution, int):
|
18 |
+
H = W = resolution
|
19 |
+
elif isinstance(resolution, list):
|
20 |
+
H, W = resolution
|
21 |
+
return H, W
|
22 |
+
|
23 |
+
@register("image2mvimage_trainer")
|
24 |
+
class Image2MVImageTrainer(BasicTrainer):
|
25 |
+
"""
|
26 |
+
Trainer for simple image to multiview images.
|
27 |
+
"""
|
28 |
+
@dataclass
|
29 |
+
class TrainerConfig(BasicTrainer.TrainerConfig):
|
30 |
+
trainer_name: str = "image2mvimage"
|
31 |
+
condition_image_column_name: str = "conditioning_image"
|
32 |
+
image_column_name: str = "image"
|
33 |
+
condition_dropout: float = 0.
|
34 |
+
condition_image_resolution: str = "512"
|
35 |
+
validation_images: Optional[List[str]] = None
|
36 |
+
noise_offset: float = 0.1
|
37 |
+
max_loss_drop: float = 0.
|
38 |
+
snr_gamma: float = 5.0
|
39 |
+
log_distribution: bool = False
|
40 |
+
latents_offset: Optional[List[float]] = None
|
41 |
+
input_perturbation: float = 0.
|
42 |
+
noisy_condition_input: bool = False # whether to add noise for ref unet input
|
43 |
+
normal_cls_offset: int = 0
|
44 |
+
condition_offset: bool = True
|
45 |
+
zero_snr: bool = False
|
46 |
+
linear_beta_schedule: bool = False
|
47 |
+
|
48 |
+
cfg: TrainerConfig
|
49 |
+
|
50 |
+
def configure(self) -> None:
|
51 |
+
return super().configure()
|
52 |
+
|
53 |
+
def init_shared_modules(self, shared_modules: dict) -> dict:
|
54 |
+
if 'vae' not in shared_modules:
|
55 |
+
vae = AutoencoderKL.from_pretrained(
|
56 |
+
self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
|
57 |
+
)
|
58 |
+
vae.requires_grad_(False)
|
59 |
+
vae.to(self.accelerator.device, dtype=self.weight_dtype)
|
60 |
+
shared_modules['vae'] = vae
|
61 |
+
if 'image_encoder' not in shared_modules:
|
62 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
63 |
+
self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
|
64 |
+
)
|
65 |
+
image_encoder.requires_grad_(False)
|
66 |
+
image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
|
67 |
+
shared_modules['image_encoder'] = image_encoder
|
68 |
+
if 'feature_extractor' not in shared_modules:
|
69 |
+
feature_extractor = CLIPImageProcessor.from_pretrained(
|
70 |
+
self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
|
71 |
+
)
|
72 |
+
shared_modules['feature_extractor'] = feature_extractor
|
73 |
+
return shared_modules
|
74 |
+
|
75 |
+
def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
|
76 |
+
raise NotImplementedError()
|
77 |
+
|
78 |
+
def loss_rescale(self, loss, timesteps=None):
|
79 |
+
raise NotImplementedError()
|
80 |
+
|
81 |
+
def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
|
82 |
+
raise NotImplementedError()
|
83 |
+
|
84 |
+
def construct_pipeline(self, shared_modules, unet, old_version=False):
|
85 |
+
MyPipeline = StableDiffusionImage2MVCustomPipeline
|
86 |
+
pipeline = MyPipeline.from_pretrained(
|
87 |
+
self.cfg.pretrained_model_name_or_path,
|
88 |
+
vae=shared_modules['vae'],
|
89 |
+
image_encoder=shared_modules['image_encoder'],
|
90 |
+
feature_extractor=shared_modules['feature_extractor'],
|
91 |
+
unet=unet,
|
92 |
+
safety_checker=None,
|
93 |
+
torch_dtype=self.weight_dtype,
|
94 |
+
latents_offset=self.cfg.latents_offset,
|
95 |
+
noisy_cond_latents=self.cfg.noisy_condition_input,
|
96 |
+
condition_offset=self.cfg.condition_offset,
|
97 |
+
)
|
98 |
+
pipeline.set_progress_bar_config(disable=True)
|
99 |
+
scheduler_dict = {}
|
100 |
+
if self.cfg.zero_snr:
|
101 |
+
scheduler_dict.update(rescale_betas_zero_snr=True)
|
102 |
+
if self.cfg.linear_beta_schedule:
|
103 |
+
scheduler_dict.update(beta_schedule='linear')
|
104 |
+
|
105 |
+
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
|
106 |
+
return pipeline
|
107 |
+
|
108 |
+
def get_forward_args(self):
|
109 |
+
if self.cfg.seed is None:
|
110 |
+
generator = None
|
111 |
+
else:
|
112 |
+
generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
|
113 |
+
|
114 |
+
H, W = get_HW(self.cfg.resolution)
|
115 |
+
H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
|
116 |
+
|
117 |
+
sub_img_H = H // 2
|
118 |
+
num_imgs = H // sub_img_H * W // sub_img_H
|
119 |
+
|
120 |
+
forward_args = dict(
|
121 |
+
num_images_per_prompt=num_imgs,
|
122 |
+
num_inference_steps=50,
|
123 |
+
height=sub_img_H,
|
124 |
+
width=sub_img_H,
|
125 |
+
height_cond=H_cond,
|
126 |
+
width_cond=W_cond,
|
127 |
+
generator=generator,
|
128 |
+
)
|
129 |
+
if self.cfg.zero_snr:
|
130 |
+
forward_args.update(guidance_rescale=0.7)
|
131 |
+
return forward_args
|
132 |
+
|
133 |
+
def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
|
134 |
+
forward_args = self.get_forward_args()
|
135 |
+
forward_args.update(pipeline_call_kwargs)
|
136 |
+
return pipeline(**forward_args)
|
137 |
+
|
138 |
+
def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
|
139 |
+
raise NotImplementedError()
|
custum_3d_diffusion/trainings/utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from omegaconf import DictConfig, OmegaConf
|
2 |
+
|
3 |
+
|
4 |
+
def parse_structured(fields, cfg) -> DictConfig:
|
5 |
+
scfg = OmegaConf.structured(fields(**cfg))
|
6 |
+
return scfg
|
7 |
+
|
8 |
+
|
9 |
+
def load_config(fields, config, extras=None):
|
10 |
+
if extras is not None:
|
11 |
+
print("Warning! extra parameter in cli is not verified, may cause erros.")
|
12 |
+
if isinstance(config, str):
|
13 |
+
cfg = OmegaConf.load(config)
|
14 |
+
elif isinstance(config, dict):
|
15 |
+
cfg = OmegaConf.create(config)
|
16 |
+
elif isinstance(config, DictConfig):
|
17 |
+
cfg = config
|
18 |
+
else:
|
19 |
+
raise NotImplementedError(f"Unsupported config type {type(config)}")
|
20 |
+
if extras is not None:
|
21 |
+
cli_conf = OmegaConf.from_cli(extras)
|
22 |
+
cfg = OmegaConf.merge(cfg, cli_conf)
|
23 |
+
OmegaConf.resolve(cfg)
|
24 |
+
assert isinstance(cfg, DictConfig)
|
25 |
+
return parse_structured(fields, cfg)
|
docker/Dockerfile
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# get the development image from nvidia cuda 12.1
|
2 |
+
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
|
3 |
+
|
4 |
+
LABEL name="unique3d" maintainer="unique3d"
|
5 |
+
|
6 |
+
# create workspace folder and set it as working directory
|
7 |
+
RUN mkdir -p /workspace
|
8 |
+
WORKDIR /workspace
|
9 |
+
|
10 |
+
# update package lists and install git, wget, vim, libegl1-mesa-dev, and libglib2.0-0
|
11 |
+
RUN apt-get update && apt-get install -y build-essential git wget vim libegl1-mesa-dev libglib2.0-0 unzip git-lfs
|
12 |
+
|
13 |
+
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev cmake curl mesa-utils-extra
|
14 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
15 |
+
ENV PYTHONUNBUFFERED=1
|
16 |
+
ENV LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
|
17 |
+
ENV PYOPENGL_PLATFORM=egl
|
18 |
+
|
19 |
+
# install conda
|
20 |
+
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
|
21 |
+
chmod +x Miniconda3-latest-Linux-x86_64.sh && \
|
22 |
+
./Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 && \
|
23 |
+
rm Miniconda3-latest-Linux-x86_64.sh
|
24 |
+
|
25 |
+
# update PATH environment variable
|
26 |
+
ENV PATH="/workspace/miniconda3/bin:${PATH}"
|
27 |
+
|
28 |
+
# initialize conda
|
29 |
+
RUN conda init bash
|
30 |
+
|
31 |
+
# create and activate conda environment
|
32 |
+
RUN conda create -n unique3d python=3.10 && echo "source activate unique3d" > ~/.bashrc
|
33 |
+
ENV PATH /workspace/miniconda3/envs/unique3d/bin:$PATH
|
34 |
+
|
35 |
+
RUN conda install Ninja
|
36 |
+
RUN conda install cuda -c nvidia/label/cuda-12.1.0 -y
|
37 |
+
|
38 |
+
RUN pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 xformers triton --index-url https://download.pytorch.org/whl/cu121
|
39 |
+
RUN pip install diffusers==0.27.2
|
40 |
+
|
41 |
+
RUN git clone --depth 1 https://huggingface.co/spaces/Wuvin/Unique3D
|
42 |
+
|
43 |
+
# change the working directory to the repository
|
44 |
+
|
45 |
+
WORKDIR /workspace/Unique3D
|
46 |
+
# other dependencies
|
47 |
+
RUN pip install -r requirements.txt
|
48 |
+
|
49 |
+
RUN pip install nvidia-pyindex
|
50 |
+
|
51 |
+
RUN pip install --upgrade nvidia-tensorrt
|
52 |
+
|
53 |
+
RUN pip install spaces
|
54 |
+
|
docker/README.md
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Docker setup
|
2 |
+
|
3 |
+
This docker setup is tested on Windows 10.
|
4 |
+
|
5 |
+
make sure you are under this directory yourworkspace/Unique3D/docker
|
6 |
+
|
7 |
+
Build docker image:
|
8 |
+
|
9 |
+
```
|
10 |
+
docker build -t unique3d -f Dockerfile .
|
11 |
+
```
|
12 |
+
|
13 |
+
Run docker image at the first time:
|
14 |
+
|
15 |
+
```
|
16 |
+
docker run -it --name unique3d -p 7860:7860 --gpus all unique3d python app.py
|
17 |
+
```
|
18 |
+
|
19 |
+
After first time:
|
20 |
+
```
|
21 |
+
docker start unique3d
|
22 |
+
docker exec unique3d python app.py
|
23 |
+
```
|
24 |
+
|
25 |
+
Stop the container:
|
26 |
+
```
|
27 |
+
docker stop unique3d
|
28 |
+
```
|
29 |
+
|
30 |
+
You can find the demo link showing in terminal, such as `https://94fc1ba77a08526e17.gradio.live/` or something similar else (it will be changed after each time to restart the container) to use the demo.
|
31 |
+
|
32 |
+
Some notes:
|
33 |
+
1. this docker build is using https://huggingface.co/spaces/Wuvin/Unique3D rather than this repo to clone the source.
|
34 |
+
2. the total built time might take more than one hour.
|
35 |
+
3. the total size of the built image will be more than 70GB.
|
gradio_app.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
if __name__ == "__main__":
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
sys.path.append(os.curdir)
|
5 |
+
import torch
|
6 |
+
torch.set_float32_matmul_precision('medium')
|
7 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
8 |
+
torch.set_grad_enabled(False)
|
9 |
+
|
10 |
+
import fire
|
11 |
+
import gradio as gr
|
12 |
+
from app.gradio_3dgen import create_ui as create_3d_ui
|
13 |
+
from app.all_models import model_zoo
|
14 |
+
|
15 |
+
|
16 |
+
_TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image'''
|
17 |
+
_DESCRIPTION = '''
|
18 |
+
[Project page](https://wukailu.github.io/Unique3D/)
|
19 |
+
|
20 |
+
* High-fidelity and diverse textured meshes generated by Unique3D from single-view images.
|
21 |
+
|
22 |
+
* The demo is still under construction, and more features are expected to be implemented soon.
|
23 |
+
'''
|
24 |
+
|
25 |
+
def launch():
|
26 |
+
model_zoo.init_models()
|
27 |
+
|
28 |
+
with gr.Blocks(
|
29 |
+
title=_TITLE,
|
30 |
+
theme=gr.themes.Monochrome(),
|
31 |
+
) as demo:
|
32 |
+
with gr.Row():
|
33 |
+
with gr.Column(scale=1):
|
34 |
+
gr.Markdown('# ' + _TITLE)
|
35 |
+
gr.Markdown(_DESCRIPTION)
|
36 |
+
create_3d_ui("wkl")
|
37 |
+
|
38 |
+
demo.queue().launch(share=True)
|
39 |
+
|
40 |
+
if __name__ == '__main__':
|
41 |
+
fire.Fire(launch)
|