diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..afc079492cbd19cc91244e4341731002583d8d03 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,12 @@ +root = true + +[*.py] +charset = utf-8 +trim_trailing_whitespace = true +end_of_line = lf +insert_final_newline = true +indent_style = space +indent_size = 4 + +[*.md] +trim_trailing_whitespace = false diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..71d5908fbf45c7a675f1e975588159f136d0c0b7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +app/examples/bag.png filter=lfs diff=lfs merge=lfs -text +app/examples/ex1.png filter=lfs diff=lfs merge=lfs -text +assets/teaser_safe.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..77962a086ce5a18583c0de529117b6d4ab470166 --- /dev/null +++ b/.gitignore @@ -0,0 +1,217 @@ +# Created by https://www.toptal.com/developers/gitignore/api/python +# Edit at https://www.toptal.com/developers/gitignore?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +# End of https://www.toptal.com/developers/gitignore/api/python + +.vscode/ +.threestudio_cache/ +outputs +outputs/ +outputs-gradio +outputs-gradio/ +lightning_logs/ + +# pretrained model weights +*.ckpt +*.pt +*.pth +*.bin +*.param + +# wandb +wandb/ + +# obj results +*.obj +*.glb +*.ply + +# ckpts +ckpt/* +*.pth +*.pt + +# tensorrt +*.engine +*.profile + +# zipfiles +*.zip +*.tar +*.tar.gz + +# others +run_30.sh +ckpt \ No newline at end of file diff --git a/Installation.md b/Installation.md new file mode 100644 index 0000000000000000000000000000000000000000..78fb2cd42e98f55389c44e1b8c01dc164b6e1647 --- /dev/null +++ b/Installation.md @@ -0,0 +1,170 @@ +# 官方安装指南 + +* 在 requirements-detail.txt 里,我们提供了详细的各个库的版本,这个对应的环境是 `python3.10 + cuda12.2`。 +* 本项目依赖于几个重要的pypi包,这几个包安装起来会有一些困难。 + +### nvdiffrast 安装 + +* nvdiffrast 会在第一次运行时,编译对应的torch插件,这一步需要 ninja 及 cudatoolkit的支持。 +* 因此需要先确保正确安装了 ninja 以及 cudatoolkit 并正确配置了 CUDA_HOME 环境变量。 +* cudatoolkit 安装可以参考 [linux-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html), [windows-cuda-installation-guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) +* ninja 则可用直接 `pip install ninja` +* 然后设置 CUDA_HOME 变量为 cudatoolkit 的安装目录,如 `/usr/local/cuda`。 +* 最后 `pip install nvdiffrast` 即可。 +* 如果无法在目标服务器上安装 cudatoolkit (如权限不够),可用使用我修改的[预编译版本 nvdiffrast](https://github.com/wukailu/nvdiffrast-torch) 在另一台拥有 cudatoolkit 且环境相似(python, torch, cuda版本相同)的服务器上预编译后安装。 + +### onnxruntime-gpu 安装 + +* 注意,同时安装 `onnxruntime` 与 `onnxruntime-gpu` 可能导致最终程序无法运行在GPU,而运行在CPU,导致极慢的推理速度。 +* [onnxruntime 官方安装指南](https://onnxruntime.ai/docs/install/#python-installs) +* TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ +`. +* 进一步的,可用安装基于 tensorrt 的 onnxruntime,进一步加快推理速度。 +* 注意:如果没有安装基于 tensorrt 的 onnxruntime,建议将 `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` 中 `TensorrtExecutionProvider` 删除。 +* 对于 cuda12.x 可用使用如下命令快速安装带有tensorrt的onnxruntime (注意将 `/root/miniconda3/lib/python3.10/site-packages` 修改为你的python 对应路径,将 `/root/.bashrc` 改为你的用户下路径 `.bashrc` 路劲) +``` +pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/ +pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/ +pip install tensorrt==8.6.0 +echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc +``` + +### pytorch3d 安装 + +* 根据 [pytorch3d 官方的安装建议](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux),建议使用预编译版本 +``` +import sys +import torch +pyt_version_str=torch.__version__.split("+")[0].replace(".", "") +version_str="".join([ + f"py3{sys.version_info.minor}_cu", + torch.version.cuda.replace(".",""), + f"_pyt{pyt_version_str}" +]) +!pip install fvcore iopath +!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html +``` + +### torch_scatter 安装 + +* 在[torch_scatter 官方安装指南](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation) 使用预编译的安装包快速安装。 +* 或者直接编译安装 `pip install git+https://github.com/rusty1s/pytorch_scatter.git` + +### 其他安装 + +* 其他文件 `pip install -r requirements.txt` 即可。 + +----- + +# Detailed Installation Guide + +* In `requirements-detail.txt`, we provide detailed versions of all packages, which correspond to the environment of `python3.10 + cuda12.2`. +* This project relies on several important PyPI packages, which may be difficult to install. + +### Installation of nvdiffrast + +* nvdiffrast will compile the corresponding torch plugin the first time it runs, which requires support from ninja and cudatoolkit. +* Therefore, it is necessary to ensure that ninja and cudatoolkit are correctly installed and that the CUDA_HOME environment variable is properly configured. +* For the installation of cudatoolkit, you can refer to the [Linux CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) and [Windows CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html). +* Ninja can be directly installed with `pip install ninja`. +* Then set the CUDA_HOME variable to the installation directory of cudatoolkit, such as `/usr/local/cuda`. +* Finally, `pip install nvdiffrast`. +* If you cannot install cudatoolkit on the computer (e.g., insufficient permissions), you can use my modified [pre-compiled version of nvdiffrast](https://github.com/wukailu/nvdiffrast-torch) to pre-compile on another computer that has cudatoolkit and a similar environment (same versions of python, torch, cuda) and then install the `.whl`. + +### Installation of onnxruntime-gpu + +* Note that installing both `onnxruntime` and `onnxruntime-gpu` may result in not running on the GPU but on the CPU, leading to extremely slow inference speed. +* [Official ONNX Runtime Installation Guide](https://onnxruntime.ai/docs/install/#python-installs) +* TLDR: For cuda11.x, `pip install onnxruntime-gpu`. For cuda12.x, `pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`. +* Furthermore, you can install onnxruntime based on tensorrt to further increase the inference speed. +* Note: If you do not correctly installed onnxruntime based on tensorrt, it is recommended to remove `TensorrtExecutionProvider` from `https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4`. +* For cuda12.x, you can quickly install onnxruntime with tensorrt using the following commands (note to change the path `/root/miniconda3/lib/python3.10/site-packages` to the corresponding path of your python, and change `/root/.bashrc` to the path of `.bashrc` under your user directory): +``` +pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/ +pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/ +pip install tensorrt==8.6.0 +echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc +``` + +### Installation of pytorch3d + +* According to the [official installation recommendations of pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux), it is recommended to use the pre-compiled version: +``` +import sys +import torch +pyt_version_str=torch.__version__.split("+")[0].replace(".", "") +version_str="".join([ + f"py3{sys.version_info.minor}_cu", + torch.version.cuda.replace(".",""), + f"_pyt{pyt_version_str}" +]) +!pip install fvcore iopath +!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html +``` + +### Installation of torch_scatter + +* Use the pre-compiled installation package according to the [official installation guide of torch_scatter](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation) for a quick installation. +* Alternatively, you can directly compile and install with `pip install git+https://github.com/rusty1s/pytorch_scatter.git`. + +### Other Installations + +* For other packages, simply `pip install -r requirements.txt`. + +----- + +# 官方インストールガイド + +* `requirements-detail.txt` には、各ライブラリのバージョンが詳細に提供されており、これは Python 3.10 + CUDA 12.2 に対応する環境です。 +* このプロジェクトは、いくつかの重要な PyPI パッケージに依存しており、これらのパッケージのインストールにはいくつかの困難が伴います。 + +### nvdiffrast のインストール + +* nvdiffrast は、最初に実行するときに、torch プラグインの対応バージョンをコンパイルします。このステップには、ninja および cudatoolkit のサポートが必要です。 +* したがって、ninja および cudatoolkit の正確なインストールと、CUDA_HOME 環境変数の正確な設定を確保する必要があります。 +* cudatoolkit のインストールについては、[Linux CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)、[Windows CUDA インストールガイド](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) を参照してください。 +* ninja は、直接 `pip install ninja` でインストールできます。 +* 次に、CUDA_HOME 変数を cudatoolkit のインストールディレクトリに設定します。例えば、`/usr/local/cuda` のように。 +* 最後に、`pip install nvdiffrast` を実行します。 +* 目標サーバーで cudatoolkit をインストールできない場合(例えば、権限が不足している場合)、私の修正した[事前コンパイル済みバージョンの nvdiffrast](https://github.com/wukailu/nvdiffrast-torch)を使用できます。これは、cudatoolkit があり、環境が似ている(Python、torch、cudaのバージョンが同じ)別のサーバーで事前コンパイルしてからインストールすることができます。 + +### onnxruntime-gpu のインストール + +* 注意:`onnxruntime` と `onnxruntime-gpu` を同時にインストールすると、最終的なプログラムが GPU 上で実行されず、CPU 上で実行される可能性があり、推論速度が非常に遅くなることがあります。 +* [onnxruntime 公式インストールガイド](https://onnxruntime.ai/docs/install/#python-installs) +* TLDR: cuda11.x 用には、`pip install onnxruntime-gpu` を使用します。cuda12.x 用には、`pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/` を使用します。 +* さらに、TensorRT ベースの onnxruntime をインストールして、推論速度をさらに向上させることができます。 +* 注意:TensorRT ベースの onnxruntime がインストールされていない場合は、`https://github.com/AiuniAI/Unique3D/blob/4e1174c3896fee992ffc780d0ea813500401fae9/scripts/load_onnx.py#L4` の `TensorrtExecutionProvider` を削除することをお勧めします。 +* cuda12.x の場合、次のコマンドを使用して迅速に TensorRT を備えた onnxruntime をインストールできます(`/root/miniconda3/lib/python3.10/site-packages` をあなたの Python に対応するパスに、`/root/.bashrc` をあなたのユーザーのパスの下の `.bashrc` に変更してください)。 +```bash +pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/ +pip install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple/ +pip install tensorrt==8.6.0 +echo -e "export LD_LIBRARY_PATH=/usr/local/cuda/targets/x86_64-linux/lib/:/root/miniconda3/lib/python3.10/site-packages/tensorrt:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc +``` + +### pytorch3d のインストール + +* [pytorch3d 公式のインストール提案](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-wheels-for-linux)に従い、事前コンパイル済みバージョンを使用することをお勧めします。 +```python +import sys +import torch +pyt_version_str=torch.__version__.split("+")[0].replace(".", "") +version_str="".join([ + f"py3{sys.version_info.minor}_cu", + torch.version.cuda.replace(".",""), + f"_pyt{pyt_version_str}" +]) +!pip install fvcore iopath +!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html +``` + +### torch_scatter のインストール + +* [torch_scatter 公式インストールガイド](https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation)に従い、事前コンパイル済みのインストールパッケージを使用して迅速インストールします。 +* または、直接コンパイルしてインストールする `pip install git+https://github.com/rusty1s/pytorch_scatter.git` も可能です。 + +### その他のインストール + +* その他のファイルについては、`pip install -r requirements.txt` を実行するだけです。 + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d7deb96a97942f0d293b05a5d7e950aeb9ddf31d --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 AiuniAI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 9a792b60f19cc389edd3fa669c19458007b7cbc2..8a5e7ab80cad51d8059e97af07ac40c3fe3b52a7 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,137 @@ --- -title: 3D Genesis -emoji: 🏆 -colorFrom: indigo -colorTo: yellow +title: 3D-Genesis +app_file: gradio_app.py sdk: gradio sdk_version: 5.5.0 -app_file: app.py -pinned: false --- +**[中文版本](README_zh.md)** -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +**[日本語版](README_jp.md)** + +# Unique3D +Official implementation of Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image. + +[Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/) + +## [Paper](https://arxiv.org/abs/2405.20343) | [Project page](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradio Demo](http://unique3d.demo.avar.cn/) | [Online Demo](https://www.aiuni.ai/) + +* Demo inference speed: Gradio Demo > Huggingface Demo > Huggingface Demo2 > Online Demo + +**If the Gradio Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (get the registration invitation code Join Discord: https://discord.gg/aiuni). However, the Online Demo is slightly different from the Gradio Demo, in that the inference speed is slower, but the generation is much more stable.** + +

+ +

+ +High-fidelity and diverse textured meshes generated by Unique3D from single-view wild images in 30 seconds. + +## More features + +The repo is still being under construction, thanks for your patience. +- [x] Upload weights. +- [x] Local gradio demo. +- [x] Detailed tutorial. +- [x] Huggingface demo. +- [ ] Detailed local demo. +- [x] Comfyui support. +- [x] Windows support. +- [x] Docker support. +- [ ] More stable reconstruction with normal. +- [ ] Training code release. + +## Preparation for inference + +* [Detailed linux installation guide](Installation.md). + +### Linux System Setup. + +Adapted for Ubuntu 22.04.4 LTS and CUDA 12.1. +```angular2html +conda create -n unique3d python=3.11 +conda activate unique3d + +pip install ninja +pip install diffusers==0.27.2 + +pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html + +pip install -r requirements.txt +``` + +[oak-barry](https://github.com/oak-barry) provide another setup script for torch210+cu121 at [here](https://github.com/oak-barry/Unique3D). + +### Windows Setup. + +* Thank you very much `jtydhr88` for the windows installation method! See [issues/15](https://github.com/AiuniAI/Unique3D/issues/15). + +According to [issues/15](https://github.com/AiuniAI/Unique3D/issues/15), implemented a bat script to run the commands, so you can: +1. Might still require Visual Studio Build Tools, you can find it from [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools). +2. Create conda env and activate it + 1. `conda create -n unique3d-py311 python=3.11` + 2. `conda activate unique3d-py311` +3. download [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl) for py311, and put it into this project. +4. run **install_windows_win_py311_cu121.bat** +5. answer y while asking you uninstall onnxruntime and onnxruntime-gpu +6. create the output folder **tmp\gradio** under the driver root, such as F:\tmp\gradio for me. +7. python app/gradio_local.py --port 7860 + +More details prefer to [issues/15](https://github.com/AiuniAI/Unique3D/issues/15). + +### Interactive inference: run your local gradio demo. + +1. Download the weights from [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) or [Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/), and extract it to `ckpt/*`. +``` +Unique3D + ├──ckpt + ├── controlnet-tile/ + ├── image2normal/ + ├── img2mvimg/ + ├── realesrgan-x4.onnx + └── v1-inference.yaml +``` + +2. Run the interactive inference locally. +```bash +python app/gradio_local.py --port 7860 +``` + +## ComfyUI Support + +Thanks for the [ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D) implementation from [jtydhr88](https://github.com/jtydhr88)! + +## Tips to get better results + +**Important: Because the mesh is normalized by the longest edge of xyz during training, it is desirable that the input image needs to contain the longest edge of the object during inference, or else you may get erroneously squashed results.** +1. Unique3D is sensitive to the facing direction of input images. Due to the distribution of the training data, orthographic front-facing images with a rest pose always lead to good reconstructions. +2. Images with occlusions will cause worse reconstructions, since four views cannot cover the complete object. Images with fewer occlusions lead to better results. +3. Pass an image with as high a resolution as possible to the input when resolution is a factor. + +## Acknowledgement + +We have intensively borrowed code from the following repositories. Many thanks to the authors for sharing their code. +- [Stable Diffusion](https://github.com/CompVis/stable-diffusion) +- [Wonder3d](https://github.com/xxlong0/Wonder3D) +- [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus) +- [Continues Remeshing](https://github.com/Profactor/continuous-remeshing) +- [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals) + +## Collaborations +Our mission is to create a 4D generative model with 3D concepts. This is just our first step, and the road ahead is still long, but we are confident. We warmly invite you to join the discussion and explore potential collaborations in any capacity. **If you're interested in connecting or partnering with us, please don't hesitate to reach out via email (wkl22@mails.tsinghua.edu.cn)**. + +- Follow us on twitter for the latest updates: https://x.com/aiuni_ai +- Join AIGC 3D/4D generation community on discord: https://discord.gg/aiuni +- Research collaboration, please contact: ai@aiuni.ai + +## Citation + +If you found Unique3D helpful, please cite our report: +```bibtex +@misc{wu2024unique3d, + title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image}, + author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma}, + year={2024}, + eprint={2405.20343}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` diff --git a/README_jp.md b/README_jp.md new file mode 100644 index 0000000000000000000000000000000000000000..44d3fb850dae8bcafe25892a16b70897beebda7d --- /dev/null +++ b/README_jp.md @@ -0,0 +1,126 @@ +**他の言語のバージョン [英語](README.md) [中国語](README_zh.md)** + +# Unique3D +Unique3D: 単一画像からの高品質かつ効率的な3Dメッシュ生成の公式実装。 + +[Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/) + +## [論文](https://arxiv.org/abs/2405.20343) | [プロジェクトページ](https://wukailu.github.io/Unique3D/) | [Huggingfaceデモ](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradioデモ](http://unique3d.demo.avar.cn/) | [オンラインデモ](https://www.aiuni.ai/) + +* デモ推論速度: Gradioデモ > Huggingfaceデモ > Huggingfaceデモ2 > オンラインデモ + +**Gradioデモが残念ながらハングアップしたり、非常に混雑している場合は、[aiuni.ai](https://www.aiuni.ai/)のオンラインデモを使用できます。これは無料で試すことができます(登録招待コードを取得するには、Discordに参加してください: https://discord.gg/aiuni)。ただし、オンラインデモはGradioデモとは少し異なり、推論速度が遅く、生成結果が安定していない可能性がありますが、素材の品質は良いです。** + +

+ +

+ +Unique3Dは、野生の単一画像から高忠実度および多様なテクスチャメッシュを30秒で生成します。 + +## より多くの機能 + +リポジトリはまだ構築中です。ご理解いただきありがとうございます。 +- [x] 重みのアップロード。 +- [x] ローカルGradioデモ。 +- [ ] 詳細なチュートリアル。 +- [x] Huggingfaceデモ。 +- [ ] 詳細なローカルデモ。 +- [x] Comfyuiサポート。 +- [x] Windowsサポート。 +- [ ] Dockerサポート。 +- [ ] ノーマルでより安定した再構築。 +- [ ] トレーニングコードのリリース。 + +## 推論の準備 + +### Linuxシステムセットアップ + +Ubuntu 22.04.4 LTSおよびCUDA 12.1に適応。 +```angular2html +conda create -n unique3d python=3.11 +conda activate unique3d + +pip install ninja +pip install diffusers==0.27.2 + +pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3.1/index.html + +pip install -r requirements.txt +``` + +[oak-barry](https://github.com/oak-barry)は、[こちら](https://github.com/oak-barry/Unique3D)でtorch210+cu121の別のセットアップスクリプトを提供しています。 + +### Windowsセットアップ + +* `jtydhr88`によるWindowsインストール方法に非常に感謝します![issues/15](https://github.com/AiuniAI/Unique3D/issues/15)を参照してください。 + +[issues/15](https://github.com/AiuniAI/Unique3D/issues/15)によると、コマンドを実行するバッチスクリプトを実装したので、以下の手順に従ってください。 +1. [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/?q=build+tools)からVisual Studio Build Toolsが必要になる場合があります。 +2. conda envを作成し、アクティブにします。 + 1. `conda create -n unique3d-py311 python=3.11` + 2. `conda activate unique3d-py311` +3. [triton whl](https://huggingface.co/madbuda/triton-windows-builds/resolve/main/triton-2.1.0-cp311-cp311-win_amd64.whl)をダウンロードし、このプロジェクトに配置します。 +4. **install_windows_win_py311_cu121.bat**を実行します。 +5. onnxruntimeおよびonnxruntime-gpuのアンインストールを求められた場合は、yと回答します。 +6. ドライバールートの下に**tmp\gradio**フォルダを作成します(例:F:\tmp\gradio)。 +7. python app/gradio_local.py --port 7860 + +詳細は[issues/15](https://github.com/AiuniAI/Unique3D/issues/15)を参照してください。 + +### インタラクティブ推論:ローカルGradioデモを実行する + +1. [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt)または[Tsinghua Cloud Drive](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)から重みをダウンロードし、`ckpt/*`に抽出します。 +``` +Unique3D + ├──ckpt + ├── controlnet-tile/ + ├── image2normal/ + ├── img2mvimg/ + ├── realesrgan-x4.onnx + └── v1-inference.yaml +``` + +2. インタラクティブ推論をローカルで実行します。 +```bash +python app/gradio_local.py --port 7860 +``` + +## ComfyUIサポート + +[jtydhr88](https://github.com/jtydhr88)からの[ComfyUI-Unique3D](https://github.com/jtydhr88/ComfyUI-Unique3D)の実装に感謝します! + +## より良い結果を得るためのヒント + +1. Unique3Dは入力画像の向きに敏感です。トレーニングデータの分布により、正面を向いた直交画像は常に良い再構築につながります。 +2. 遮蔽のある画像は、4つのビューがオブジェクトを完全にカバーできないため、再構築が悪化します。遮蔽の少ない画像は、より良い結果につながります。 +3. 可能な限り高解像度の画像を入力として使用してください。 + +## 謝辞 + +以下のリポジトリからコードを大量に借用しました。コードを共有してくれた著者に感謝します。 +- [Stable Diffusion](https://github.com/CompVis/stable-diffusion) +- [Wonder3d](https://github.com/xxlong0/Wonder3D) +- [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus) +- [Continues Remeshing](https://github.com/Profactor/continuous-remeshing) +- [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals) + +## コラボレーション +私たちの使命は、3Dの概念を持つ4D生成モデルを作成することです。これは私たちの最初のステップであり、前途はまだ長いですが、私たちは自信を持っています。あらゆる形態の潜在的なコラボレーションを探求し、議論に参加することを心から歓迎します。**私たちと連絡を取りたい、またはパートナーシップを結びたい方は、メールでお気軽にお問い合わせください (wkl22@mails.tsinghua.edu.cn)**。 + +- 最新情報を入手するには、Twitterをフォローしてください: https://x.com/aiuni_ai +- DiscordでAIGC 3D/4D生成コミュニティに参加してください: https://discord.gg/aiuni +- 研究協力については、ai@aiuni.aiまでご連絡ください。 + +## 引用 + +Unique3Dが役立つと思われる場合は、私たちのレポートを引用してください: +```bibtex +@misc{wu2024unique3d, + title={Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image}, + author={Kailu Wu and Fangfu Liu and Zhihan Cai and Runjie Yan and Hanyang Wang and Yating Hu and Yueqi Duan and Kaisheng Ma}, + year={2024}, + eprint={2405.20343}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` diff --git a/README_zh.md b/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..cab6233ce15a3cc5630b85eb5d9590dc9389431c --- /dev/null +++ b/README_zh.md @@ -0,0 +1,62 @@ +**其他语言版本 [English](README.md)** + +# Unique3D +High-Quality and Efficient 3D Mesh Generation from a Single Image + +[Kailu Wu](https://scholar.google.com/citations?user=VTU0gysAAAAJ&hl=zh-CN&oi=ao), [Fangfu Liu](https://liuff19.github.io/), Zhihan Cai, Runjie Yan, Hanyang Wang, Yating Hu, [Yueqi Duan](https://duanyueqi.github.io/), [Kaisheng Ma](https://group.iiis.tsinghua.edu.cn/~maks/) + +## [论文](https://arxiv.org/abs/2405.20343) | [项目页面](https://wukailu.github.io/Unique3D/) | [Huggingface Demo](https://huggingface.co/spaces/Wuvin/Unique3D) | [Gradio Demo](http://unique3d.demo.avar.cn/) | [在线演示](https://www.aiuni.ai/) + + + +

+ +

+ +Unique3D从单视图图像生成高保真度和多样化纹理的网格,在4090上大约需要30秒。 + +### 推理准备 + +#### Linux系统设置 +```angular2html +conda create -n unique3d +conda activate unique3d +pip install -r requirements.txt +``` + +#### 交互式推理:运行您的本地gradio演示 + +1. 从 [huggingface spaces](https://huggingface.co/spaces/Wuvin/Unique3D/tree/main/ckpt) 下载或者从[清华云盘](https://cloud.tsinghua.edu.cn/d/319762ec478d46c8bdf7/)下载权重,并将其解压到`ckpt/*`。 +``` +Unique3D + ├──ckpt + ├── controlnet-tile/ + ├── image2normal/ + ├── img2mvimg/ + ├── realesrgan-x4.onnx + └── v1-inference.yaml +``` + +2. 在本地运行交互式推理。 +```bash +python app/gradio_local.py --port 7860 +``` + +## 获取更好结果的提示 + +1. Unique3D对输入图像的朝向非常敏感。由于训练数据的分布,**正交正视图像**通常总是能带来良好的重建。对于人物而言,最好是 A-pose 或者 T-pose,因为目前训练数据很少含有其他类型姿态。 +2. 有遮挡的图像会导致更差的重建,因为4个视图无法覆盖完整的对象。遮挡较少的图像会带来更好的结果。 +3. 尽可能将高分辨率的图像用作输入。 + +## 致谢 + +我们借用了以下代码库的代码。非常感谢作者们分享他们的代码。 +- [Stable Diffusion](https://github.com/CompVis/stable-diffusion) +- [Wonder3d](https://github.com/xxlong0/Wonder3D) +- [Zero123Plus](https://github.com/SUDO-AI-3D/zero123plus) +- [Continues Remeshing](https://github.com/Profactor/continuous-remeshing) +- [Depth from Normals](https://github.com/YertleTurtleGit/depth-from-normals) + +## 合作 + +我们使命是创建一个具有3D概念的4D生成模型。这只是我们的第一步,前方的道路仍然很长,但我们有信心。我们热情邀请您加入讨论,并探索任何形式的潜在合作。**如果您有兴趣联系或与我们合作,欢迎通过电子邮件(wkl22@mails.tsinghua.edu.cn)与我们联系**。 diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/all_models.py b/app/all_models.py new file mode 100644 index 0000000000000000000000000000000000000000..d7df963350c704c864b73d1e403f880f86bdfd5d --- /dev/null +++ b/app/all_models.py @@ -0,0 +1,22 @@ +import torch +from scripts.sd_model_zoo import load_common_sd15_pipe +from diffusers import StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline + + +class MyModelZoo: + _pipe_disney_controlnet_lineart_ipadapter_i2i: StableDiffusionControlNetImg2ImgPipeline = None + + base_model = "runwayml/stable-diffusion-v1-5" + + def __init__(self, base_model=None) -> None: + if base_model is not None: + self.base_model = base_model + + @property + def pipe_disney_controlnet_tile_ipadapter_i2i(self): + return self._pipe_disney_controlnet_lineart_ipadapter_i2i + + def init_models(self): + self._pipe_disney_controlnet_lineart_ipadapter_i2i = load_common_sd15_pipe(base_model=self.base_model, ip_adapter=True, plus_model=False, controlnet="./ckpt/controlnet-tile", pipeline_class=StableDiffusionControlNetImg2ImgPipeline) + +model_zoo = MyModelZoo() diff --git a/app/custom_models/image2mvimage.yaml b/app/custom_models/image2mvimage.yaml new file mode 100644 index 0000000000000000000000000000000000000000..07ad06caeae33c598a929d8f3c4595bd403da32d --- /dev/null +++ b/app/custom_models/image2mvimage.yaml @@ -0,0 +1,63 @@ +pretrained_model_name_or_path: "./ckpt/img2mvimg" +mixed_precision: "bf16" + +init_config: + # enable controls + enable_cross_attn_lora: False + enable_cross_attn_ip: False + enable_self_attn_lora: False + enable_self_attn_ref: False + enable_multiview_attn: True + + # for cross attention + init_cross_attn_lora: False + init_cross_attn_ip: False + cross_attn_lora_rank: 256 # 0 for not enabled + cross_attn_lora_only_kv: False + ipadapter_pretrained_name: "h94/IP-Adapter" + ipadapter_subfolder_name: "models" + ipadapter_weight_name: "ip-adapter_sd15.safetensors" + ipadapter_effect_on: "all" # all, first + + # for self attention + init_self_attn_lora: False + self_attn_lora_rank: 256 + self_attn_lora_only_kv: False + + # for self attention ref + init_self_attn_ref: False + self_attn_ref_position: "attn1" + self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers" + self_attn_ref_pixel_wise_crosspond: False + self_attn_ref_effect_on: "all" + + # for multiview attention + init_multiview_attn: True + multiview_attn_position: "attn1" + use_mv_joint_attn: True + num_modalities: 1 + + # for unet + init_unet_path: "${pretrained_model_name_or_path}" + cat_condition: True # cat condition to input + + # for cls embedding + init_num_cls_label: 8 # for initialize + cls_labels: [0, 1, 2, 3] # for current task + +trainers: + - trainer_type: "image2mvimage_trainer" + trainer: + pretrained_model_name_or_path: "${pretrained_model_name_or_path}" + attn_config: + cls_labels: [0, 1, 2, 3] # for current task + enable_cross_attn_lora: False + enable_cross_attn_ip: False + enable_self_attn_lora: False + enable_self_attn_ref: False + enable_multiview_attn: True + resolution: "256" + condition_image_resolution: "256" + normal_cls_offset: 4 + condition_image_column_name: "conditioning_image" + image_column_name: "image" \ No newline at end of file diff --git a/app/custom_models/image2normal.yaml b/app/custom_models/image2normal.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de5e9e871d77eaa20ffe0d1488f91c82e39a8542 --- /dev/null +++ b/app/custom_models/image2normal.yaml @@ -0,0 +1,61 @@ +pretrained_model_name_or_path: "lambdalabs/sd-image-variations-diffusers" +mixed_precision: "bf16" + +init_config: + # enable controls + enable_cross_attn_lora: False + enable_cross_attn_ip: False + enable_self_attn_lora: False + enable_self_attn_ref: True + enable_multiview_attn: False + + # for cross attention + init_cross_attn_lora: False + init_cross_attn_ip: False + cross_attn_lora_rank: 512 # 0 for not enabled + cross_attn_lora_only_kv: False + ipadapter_pretrained_name: "h94/IP-Adapter" + ipadapter_subfolder_name: "models" + ipadapter_weight_name: "ip-adapter_sd15.safetensors" + ipadapter_effect_on: "all" # all, first + + # for self attention + init_self_attn_lora: False + self_attn_lora_rank: 512 + self_attn_lora_only_kv: False + + # for self attention ref + init_self_attn_ref: True + self_attn_ref_position: "attn1" + self_attn_ref_other_model_name: "lambdalabs/sd-image-variations-diffusers" + self_attn_ref_pixel_wise_crosspond: True + self_attn_ref_effect_on: "all" + + # for multiview attention + init_multiview_attn: False + multiview_attn_position: "attn1" + num_modalities: 1 + + # for unet + init_unet_path: "${pretrained_model_name_or_path}" + init_num_cls_label: 0 # for initialize + cls_labels: [] # for current task + +trainers: + - trainer_type: "image2image_trainer" + trainer: + pretrained_model_name_or_path: "${pretrained_model_name_or_path}" + attn_config: + cls_labels: [] # for current task + enable_cross_attn_lora: False + enable_cross_attn_ip: False + enable_self_attn_lora: False + enable_self_attn_ref: True + enable_multiview_attn: False + resolution: "512" + condition_image_resolution: "512" + condition_image_column_name: "conditioning_image" + image_column_name: "image" + + + diff --git a/app/custom_models/mvimg_prediction.py b/app/custom_models/mvimg_prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..5c81bdf4546eda1154f1904fafdf8962f38c438a --- /dev/null +++ b/app/custom_models/mvimg_prediction.py @@ -0,0 +1,57 @@ +import sys +import torch +import gradio as gr +from PIL import Image +import numpy as np +from rembg import remove +from app.utils import change_rgba_bg, rgba_to_rgb +from app.custom_models.utils import load_pipeline +from scripts.all_typing import * +from scripts.utils import session, simple_preprocess + +training_config = "app/custom_models/image2mvimage.yaml" +checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth" +trainer, pipeline = load_pipeline(training_config, checkpoint_path) +# pipeline.enable_model_cpu_offload() + +def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs): + if isinstance(img_list, Image.Image): + img_list = [img_list] + img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list] + ret = [] + for img in img_list: + images = trainer.pipeline_forward( + pipeline=pipeline, + image=img, + guidance_scale=guidance_scale, + **kwargs + ).images + ret.extend(images) + return ret + + +def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145): + if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.: + # still do remove using rembg, since simple_preprocess requires RGBA image + print("RGB image not RGBA! still remove bg!") + remove_bg = True + + if remove_bg: + input_image = remove(input_image, session=session) + + # make front_pil RGBA with white bg + input_image = change_rgba_bg(input_image, "white") + single_image = simple_preprocess(input_image) + + generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed >= 0 else None + + rgb_pils = predict( + single_image, + generator=generator, + guidance_scale=guidance_scale, + width=256, + height=256, + num_inference_steps=30, + ) + + return rgb_pils, single_image diff --git a/app/custom_models/normal_prediction.py b/app/custom_models/normal_prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..32568715b743739e34db3ed6a221d8d95112800a --- /dev/null +++ b/app/custom_models/normal_prediction.py @@ -0,0 +1,26 @@ +import sys +from PIL import Image +from app.utils import rgba_to_rgb, simple_remove +from app.custom_models.utils import load_pipeline +from scripts.utils import rotate_normals_torch +from scripts.all_typing import * + +training_config = "app/custom_models/image2normal.yaml" +checkpoint_path = "ckpt/image2normal/unet_state_dict.pth" +trainer, pipeline = load_pipeline(training_config, checkpoint_path) +# pipeline.enable_model_cpu_offload() + +def predict_normals(image: List[Image.Image], guidance_scale=2., do_rotate=True, num_inference_steps=30, **kwargs): + img_list = image if isinstance(image, list) else [image] + img_list = [rgba_to_rgb(i) if i.mode == 'RGBA' else i for i in img_list] + images = trainer.pipeline_forward( + pipeline=pipeline, + image=img_list, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + **kwargs + ).images + images = simple_remove(images) + if do_rotate and len(images) > 1: + images = rotate_normals_torch(images, return_types='pil') + return images \ No newline at end of file diff --git a/app/custom_models/utils.py b/app/custom_models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2dff2824ec7fe5926fa4dfc0f8b73fab8abf995e --- /dev/null +++ b/app/custom_models/utils.py @@ -0,0 +1,75 @@ +import torch +from typing import List +from dataclasses import dataclass +from app.utils import rgba_to_rgb +from custum_3d_diffusion.trainings.config_classes import ExprimentConfig, TrainerSubConfig +from custum_3d_diffusion import modules +from custum_3d_diffusion.custum_modules.unifield_processor import AttnConfig, ConfigurableUNet2DConditionModel +from custum_3d_diffusion.trainings.base import BasicTrainer +from custum_3d_diffusion.trainings.utils import load_config + + +@dataclass +class FakeAccelerator: + device: torch.device = torch.device("cuda") + + +def init_trainers(cfg_path: str, weight_dtype: torch.dtype, extras: dict): + accelerator = FakeAccelerator() + cfg: ExprimentConfig = load_config(ExprimentConfig, cfg_path, extras) + init_config: AttnConfig = load_config(AttnConfig, cfg.init_config) + configurable_unet = ConfigurableUNet2DConditionModel(init_config, weight_dtype) + configurable_unet.enable_xformers_memory_efficient_attention() + trainer_cfgs: List[TrainerSubConfig] = [load_config(TrainerSubConfig, trainer) for trainer in cfg.trainers] + trainers: List[BasicTrainer] = [modules.find(trainer.trainer_type)(accelerator, None, configurable_unet, trainer.trainer, weight_dtype, i) for i, trainer in enumerate(trainer_cfgs)] + return trainers, configurable_unet + +from app.utils import make_image_grid, split_image +def process_image(function, img, guidance_scale=2., merged_image=False, remove_bg=True): + from rembg import remove + if remove_bg: + img = remove(img) + img = rgba_to_rgb(img) + if merged_image: + img = split_image(img, rows=2) + images = function( + image=img, + guidance_scale=guidance_scale, + ) + if len(images) > 1: + return make_image_grid(images, rows=2) + else: + return images[0] + + +def process_text(trainer, pipeline, img, guidance_scale=2.): + pipeline.cfg.validation_prompts = [img] + titles, images = trainer.batched_validation_forward(pipeline, guidance_scale=[guidance_scale]) + return images[0] + + +def load_pipeline(config_path, ckpt_path, pipeline_filter=lambda x: True, weight_dtype = torch.bfloat16): + training_config = config_path + load_from_checkpoint = ckpt_path + extras = [] + device = "cuda" + trainers, configurable_unet = init_trainers(training_config, weight_dtype, extras) + shared_modules = dict() + for trainer in trainers: + shared_modules = trainer.init_shared_modules(shared_modules) + + if load_from_checkpoint is not None: + state_dict = torch.load(load_from_checkpoint) + configurable_unet.unet.load_state_dict(state_dict, strict=False) + # Move unet, vae and text_encoder to device and cast to weight_dtype + configurable_unet.unet.to(device, dtype=weight_dtype) + + pipeline = None + trainer_out = None + for trainer in trainers: + if pipeline_filter(trainer.cfg.trainer_name): + pipeline = trainer.construct_pipeline(shared_modules, configurable_unet.unet) + pipeline.set_progress_bar_config(disable=False) + trainer_out = trainer + pipeline = pipeline.to(device) + return trainer_out, pipeline \ No newline at end of file diff --git a/app/examples/Groot.png b/app/examples/Groot.png new file mode 100644 index 0000000000000000000000000000000000000000..1e44e5969b89b608412b07decb5212e27822fd21 Binary files /dev/null and b/app/examples/Groot.png differ diff --git a/app/examples/aaa.png b/app/examples/aaa.png new file mode 100644 index 0000000000000000000000000000000000000000..8a1b3106a4232e293aa19c7dd1e958aec03767dd Binary files /dev/null and b/app/examples/aaa.png differ diff --git a/app/examples/abma.png b/app/examples/abma.png new file mode 100644 index 0000000000000000000000000000000000000000..d2b85fbc8496af631d10446005f5cba3895ef5ef Binary files /dev/null and b/app/examples/abma.png differ diff --git a/app/examples/akun.png b/app/examples/akun.png new file mode 100644 index 0000000000000000000000000000000000000000..ab6790ff297a4bbe828e6536a079d65528bb5acc Binary files /dev/null and b/app/examples/akun.png differ diff --git a/app/examples/anya.png b/app/examples/anya.png new file mode 100644 index 0000000000000000000000000000000000000000..43061b59488f33b37b8c92d4b3ebb4a9a20d90f5 Binary files /dev/null and b/app/examples/anya.png differ diff --git a/app/examples/bag.png b/app/examples/bag.png new file mode 100644 index 0000000000000000000000000000000000000000..e91e10cc662404fdfddc5d8b5df7f68c7028e31c --- /dev/null +++ b/app/examples/bag.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac798ea1f112091c04f5bdfa47c490806fb433a02fe17758aa1f8c55cd64b66e +size 1544762 diff --git a/app/examples/ex1.png b/app/examples/ex1.png new file mode 100644 index 0000000000000000000000000000000000000000..88e80fc4d1d727fbedca1daa248a127fecf7876b --- /dev/null +++ b/app/examples/ex1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d49ccccd40fe0317c2886b0d36a11667003d17a49cc49d9244208d250de9fe31 +size 1169069 diff --git a/app/examples/ex2.png b/app/examples/ex2.png new file mode 100644 index 0000000000000000000000000000000000000000..49531897982d3360bd849c6bb2acadb118eeeeb8 Binary files /dev/null and b/app/examples/ex2.png differ diff --git a/app/examples/ex3.jpg b/app/examples/ex3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dbccc807d7881353d57c60f03f9d559be77e4464 Binary files /dev/null and b/app/examples/ex3.jpg differ diff --git a/app/examples/ex4.png b/app/examples/ex4.png new file mode 100644 index 0000000000000000000000000000000000000000..156bcd9e518f5dca50b58f27519ccc72d397a720 Binary files /dev/null and b/app/examples/ex4.png differ diff --git a/app/examples/generated_1715761545_frame0.png b/app/examples/generated_1715761545_frame0.png new file mode 100644 index 0000000000000000000000000000000000000000..8158529cf66a1f96efaaf348d66f47a392976661 Binary files /dev/null and b/app/examples/generated_1715761545_frame0.png differ diff --git a/app/examples/generated_1715762357_frame0.png b/app/examples/generated_1715762357_frame0.png new file mode 100644 index 0000000000000000000000000000000000000000..6371b4f241e709526c3e7ae224044605e19abbdc Binary files /dev/null and b/app/examples/generated_1715762357_frame0.png differ diff --git a/app/examples/generated_1715763329_frame0.png b/app/examples/generated_1715763329_frame0.png new file mode 100644 index 0000000000000000000000000000000000000000..15355d7c47456f6f5ee901a21d69fc898d7e9ecf Binary files /dev/null and b/app/examples/generated_1715763329_frame0.png differ diff --git a/app/examples/hatsune_miku.png b/app/examples/hatsune_miku.png new file mode 100644 index 0000000000000000000000000000000000000000..2fecf005fdd56a396c4894256fbb98fcc1c4dd8f Binary files /dev/null and b/app/examples/hatsune_miku.png differ diff --git a/app/examples/princess-large.png b/app/examples/princess-large.png new file mode 100644 index 0000000000000000000000000000000000000000..e5a4aeec023a771b7aba9bcaafef70c98757bfbe Binary files /dev/null and b/app/examples/princess-large.png differ diff --git a/app/gradio_3dgen.py b/app/gradio_3dgen.py new file mode 100644 index 0000000000000000000000000000000000000000..c2113a773e36884e185550fa036f5ba4c72611c3 --- /dev/null +++ b/app/gradio_3dgen.py @@ -0,0 +1,71 @@ +import os +import gradio as gr +from PIL import Image +from pytorch3d.structures import Meshes +from app.utils import clean_up +from app.custom_models.mvimg_prediction import run_mvprediction +from app.custom_models.normal_prediction import predict_normals +from scripts.refine_lr_to_sr import run_sr_fast +from scripts.utils import save_glb_and_video +from scripts.multiview_inference import geo_reconstruct + +def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"): + if preview_img is None: + raise gr.Error("preview_img is none") + if isinstance(preview_img, str): + preview_img = Image.open(preview_img) + + if preview_img.size[0] <= 512: + preview_img = run_sr_fast([preview_img])[0] + rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s + new_meshes = geo_reconstruct(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type) + vertices = new_meshes.verts_packed() + vertices = vertices / 2 * 1.35 + vertices[..., [0, 2]] = - vertices[..., [0, 2]] + new_meshes = Meshes(verts=[vertices], faces=new_meshes.faces_list(), textures=new_meshes.textures) + + ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video) + return ret_mesh, video + +####################################### +def create_ui(concurrency_id="wkl"): + with gr.Row(): + with gr.Column(scale=2): + input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview') + + example_folder = os.path.join(os.path.dirname(__file__), "./examples") + example_fns = sorted([os.path.join(example_folder, example) for example in os.listdir(example_folder)]) + gr.Examples( + examples=example_fns, + inputs=[input_image], + cache_examples=False, + label='Examples (click one of the images below to start)', + examples_per_page=12 + ) + + + with gr.Column(scale=3): + # export mesh display + output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320) + output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False) + + input_processing = gr.Checkbox( + value=True, + label='Remove Background', + visible=True, + ) + do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False) + expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False) + init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False) + setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed") + render_video = gr.Checkbox(value=False, visible=False, label="generate video") + fullrunv2_btn = gr.Button('Generate 3D', interactive=True) + + fullrunv2_btn.click( + fn = generate3dv2, + inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type], + outputs=[output_mesh, output_video], + concurrency_id=concurrency_id, + api_name="generate3dv2", + ).success(clean_up, api_name=False) + return input_image diff --git a/app/gradio_3dgen_steps.py b/app/gradio_3dgen_steps.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc91ada685a3b8b789770323b2bdfd7ca4fd3ac --- /dev/null +++ b/app/gradio_3dgen_steps.py @@ -0,0 +1,87 @@ +import gradio as gr +from PIL import Image + +from app.custom_models.mvimg_prediction import run_mvprediction +from app.utils import make_image_grid, split_image +from scripts.utils import save_glb_and_video + +def concept_to_multiview(preview_img, input_processing, seed, guidance=1.): + seed = int(seed) + if preview_img is None: + raise gr.Error("preview_img is none.") + if isinstance(preview_img, str): + preview_img = Image.open(preview_img) + + rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=seed, guidance_scale=guidance) + rgb_pil = make_image_grid(rgb_pils, rows=2) + return rgb_pil, front_pil + +def concept_to_multiview_ui(concurrency_id="wkl"): + with gr.Row(): + with gr.Column(scale=2): + preview_img = gr.Image(type='pil', image_mode='RGBA', label='Frontview') + input_processing = gr.Checkbox( + value=True, + label='Remove Background', + ) + seed = gr.Slider(minimum=-1, maximum=1000000000, value=-1, step=1.0, label="seed") + guidance = gr.Slider(minimum=1.0, maximum=5.0, value=1.0, label="Guidance Scale", step=0.5) + run_btn = gr.Button('Generate Multiview', interactive=True) + with gr.Column(scale=3): + # export mesh display + output_rgb = gr.Image(type='pil', label="RGB", show_label=True) + output_front = gr.Image(type='pil', image_mode='RGBA', label="Frontview", show_label=True) + run_btn.click( + fn = concept_to_multiview, + inputs=[preview_img, input_processing, seed, guidance], + outputs=[output_rgb, output_front], + concurrency_id=concurrency_id, + api_name=False, + ) + return output_rgb, output_front + +from app.custom_models.normal_prediction import predict_normals +from scripts.multiview_inference import geo_reconstruct +def multiview_to_mesh_v2(rgb_pil, normal_pil, front_pil, do_refine=False, expansion_weight=0.1, init_type="std"): + rgb_pils = split_image(rgb_pil, rows=2) + if normal_pil is not None: + normal_pil = split_image(normal_pil, rows=2) + if front_pil is None: + front_pil = rgb_pils[0] + new_meshes = geo_reconstruct(rgb_pils, normal_pil, front_pil, do_refine=do_refine, predict_normal=normal_pil is None, expansion_weight=expansion_weight, init_type=init_type) + ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=False) + return ret_mesh + +def new_multiview_to_mesh_ui(concurrency_id="wkl"): + with gr.Row(): + with gr.Column(scale=2): + rgb_pil = gr.Image(type='pil', image_mode='RGB', label='RGB') + front_pil = gr.Image(type='pil', image_mode='RGBA', label='Frontview(Optinal)') + normal_pil = gr.Image(type='pil', image_mode='RGBA', label='Normal(Optinal)') + do_refine = gr.Checkbox( + value=False, + label='Refine rgb', + visible=False, + ) + expansion_weight = gr.Slider(minimum=-1.0, maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False) + init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh initialization", value="std", visible=False) + run_btn = gr.Button('Generate 3D', interactive=True) + with gr.Column(scale=3): + # export mesh display + output_mesh = gr.Model3D(value=None, label="mesh model", show_label=True) + run_btn.click( + fn = multiview_to_mesh_v2, + inputs=[rgb_pil, normal_pil, front_pil, do_refine, expansion_weight, init_type], + outputs=[output_mesh], + concurrency_id=concurrency_id, + api_name="multiview_to_mesh", + ) + return rgb_pil, front_pil, output_mesh + + +####################################### +def create_step_ui(concurrency_id="wkl"): + with gr.Tab(label="3D:concept_to_multiview"): + concept_to_multiview_ui(concurrency_id) + with gr.Tab(label="3D:new_multiview_to_mesh"): + new_multiview_to_mesh_ui(concurrency_id) diff --git a/app/gradio_local.py b/app/gradio_local.py new file mode 100644 index 0000000000000000000000000000000000000000..8de11e403214cb94d777b034ed8135ffd1355671 --- /dev/null +++ b/app/gradio_local.py @@ -0,0 +1,76 @@ +if __name__ == "__main__": + import os + import sys + sys.path.append(os.curdir) + if 'CUDA_VISIBLE_DEVICES' not in os.environ: + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + os.environ['TRANSFORMERS_OFFLINE']='0' + os.environ['DIFFUSERS_OFFLINE']='0' + os.environ['HF_HUB_OFFLINE']='0' + os.environ['GRADIO_ANALYTICS_ENABLED']='False' + os.environ['HF_ENDPOINT']='https://hf-mirror.com' + import torch + torch.set_float32_matmul_precision('medium') + torch.backends.cuda.matmul.allow_tf32 = True + torch.set_grad_enabled(False) + +import gradio as gr +import argparse + +from app.gradio_3dgen import create_ui as create_3d_ui +# from app.gradio_3dgen_steps import create_step_ui +from app.all_models import model_zoo + + +_TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image''' +_DESCRIPTION = ''' +[Project page](https://wukailu.github.io/Unique3D/) + +* High-fidelity and diverse textured meshes generated by Unique3D from single-view images. + +**If the Gradio Demo is overcrowded or fails to produce stable results, you can use the Online Demo [aiuni.ai](https://www.aiuni.ai/), which is free to try (get the registration invitation code Join Discord: https://discord.gg/aiuni). However, the Online Demo is slightly different from the Gradio Demo, in that the inference speed is slower, but the generation is much more stable.** +''' + +def launch( + port, + listen=False, + share=False, + gradio_root="", +): + model_zoo.init_models() + + with gr.Blocks( + title=_TITLE, + theme=gr.themes.Monochrome(), + ) as demo: + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown('# ' + _TITLE) + gr.Markdown(_DESCRIPTION) + create_3d_ui("wkl") + + launch_args = {} + if listen: + launch_args["server_name"] = "0.0.0.0" + + demo.queue(default_concurrency_limit=1).launch( + server_port=None if port == 0 else port, + share=share, + root_path=gradio_root if gradio_root != "" else None, # "/myapp" + **launch_args, + ) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + args, extra = parser.parse_known_args() + parser.add_argument("--listen", action="store_true") + parser.add_argument("--port", type=int, default=0) + parser.add_argument("--share", action="store_true") + parser.add_argument("--gradio_root", default="") + args = parser.parse_args() + launch( + args.port, + listen=args.listen, + share=args.share, + gradio_root=args.gradio_root, + ) \ No newline at end of file diff --git a/app/utils.py b/app/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5385bfb256015dd892682ca6b8936c6632a453bf --- /dev/null +++ b/app/utils.py @@ -0,0 +1,112 @@ +import torch +import numpy as np +from PIL import Image +import gc +import numpy as np +import numpy as np +from PIL import Image +from scripts.refine_lr_to_sr import run_sr_fast + +GRADIO_CACHE = "/tmp/gradio/" + +def clean_up(): + torch.cuda.empty_cache() + gc.collect() + +def remove_color(arr): + if arr.shape[-1] == 4: + arr = arr[..., :3] + # calc diffs + base = arr[0, 0] + diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1) + alpha = (diffs <= 80) + + arr[alpha] = 255 + alpha = ~alpha + arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1) + return arr + +def simple_remove(imgs, run_sr=True): + """Only works for normal""" + if not isinstance(imgs, list): + imgs = [imgs] + single_input = True + else: + single_input = False + if run_sr: + imgs = run_sr_fast(imgs) + rets = [] + for img in imgs: + arr = np.array(img) + arr = remove_color(arr) + rets.append(Image.fromarray(arr.astype(np.uint8))) + if single_input: + return rets[0] + return rets + +def rgba_to_rgb(rgba: Image.Image, bkgd="WHITE"): + new_image = Image.new("RGBA", rgba.size, bkgd) + new_image.paste(rgba, (0, 0), rgba) + new_image = new_image.convert('RGB') + return new_image + +def change_rgba_bg(rgba: Image.Image, bkgd="WHITE"): + rgb_white = rgba_to_rgb(rgba, bkgd) + new_rgba = Image.fromarray(np.concatenate([np.array(rgb_white), np.array(rgba)[:, :, 3:4]], axis=-1)) + return new_rgba + +def split_image(image, rows=None, cols=None): + """ + inverse function of make_image_grid + """ + # image is in square + if rows is None and cols is None: + # image.size [W, H] + rows = 1 + cols = image.size[0] // image.size[1] + assert cols * image.size[1] == image.size[0] + subimg_size = image.size[1] + elif rows is None: + subimg_size = image.size[0] // cols + rows = image.size[1] // subimg_size + assert rows * subimg_size == image.size[1] + elif cols is None: + subimg_size = image.size[1] // rows + cols = image.size[0] // subimg_size + assert cols * subimg_size == image.size[0] + else: + subimg_size = image.size[1] // rows + assert cols * subimg_size == image.size[0] + subimgs = [] + for i in range(rows): + for j in range(cols): + subimg = image.crop((j*subimg_size, i*subimg_size, (j+1)*subimg_size, (i+1)*subimg_size)) + subimgs.append(subimg) + return subimgs + +def make_image_grid(images, rows=None, cols=None, resize=None): + if rows is None and cols is None: + rows = 1 + cols = len(images) + if rows is None: + rows = len(images) // cols + if len(images) % cols != 0: + rows += 1 + if cols is None: + cols = len(images) // rows + if len(images) % rows != 0: + cols += 1 + total_imgs = rows * cols + if total_imgs > len(images): + images += [Image.new(images[0].mode, images[0].size) for _ in range(total_imgs - len(images))] + + if resize is not None: + images = [img.resize((resize, resize)) for img in images] + + w, h = images[0].size + grid = Image.new(images[0].mode, size=(cols * w, rows * h)) + + for i, img in enumerate(images): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + diff --git a/assets/teaser.jpg b/assets/teaser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..53a97e29b8fbd5719199a07bbca041f1faebb89e Binary files /dev/null and b/assets/teaser.jpg differ diff --git a/assets/teaser_safe.jpg b/assets/teaser_safe.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ee70f0232d908bcb98220e244d4d839c0c43a9dc --- /dev/null +++ b/assets/teaser_safe.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5eb9060bc45c1d334f988e8053f1de40cf60df907750dfef89d81cdbe86ffc79 +size 2819752 diff --git a/custum_3d_diffusion/custum_modules/attention_processors.py b/custum_3d_diffusion/custum_modules/attention_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..9b01cc45f3e6fbd7fa23f5409dd00fa7222b7762 --- /dev/null +++ b/custum_3d_diffusion/custum_modules/attention_processors.py @@ -0,0 +1,385 @@ +from typing import Any, Dict, Optional +import torch +from diffusers.models.attention_processor import Attention + +def construct_pix2pix_attention(hidden_states_dim, norm_type="none"): + if norm_type == "layernorm": + norm = torch.nn.LayerNorm(hidden_states_dim) + else: + norm = torch.nn.Identity() + attention = Attention( + query_dim=hidden_states_dim, + heads=8, + dim_head=hidden_states_dim // 8, + bias=True, + ) + # NOTE: xformers 0.22 does not support batchsize >= 4096 + attention.xformers_not_supported = True # hacky solution + return norm, attention + +class ExtraAttnProc(torch.nn.Module): + def __init__( + self, + chained_proc, + enabled=False, + name=None, + mode='extract', + with_proj_in=False, + proj_in_dim=768, + target_dim=None, + pixel_wise_crosspond=False, + norm_type="none", # none or layernorm + crosspond_effect_on="all", # all or first + crosspond_chain_pos="parralle", # before or parralle or after + simple_3d=False, + views=4, + ) -> None: + super().__init__() + self.enabled = enabled + self.chained_proc = chained_proc + self.name = name + self.mode = mode + self.with_proj_in=with_proj_in + self.proj_in_dim = proj_in_dim + self.target_dim = target_dim or proj_in_dim + self.hidden_states_dim = self.target_dim + self.pixel_wise_crosspond = pixel_wise_crosspond + self.crosspond_effect_on = crosspond_effect_on + self.crosspond_chain_pos = crosspond_chain_pos + self.views = views + self.simple_3d = simple_3d + if self.with_proj_in and self.enabled: + self.in_linear = torch.nn.Linear(self.proj_in_dim, self.target_dim, bias=False) + if self.target_dim == self.proj_in_dim: + self.in_linear.weight.data = torch.eye(proj_in_dim) + else: + self.in_linear = None + if self.pixel_wise_crosspond and self.enabled: + self.crosspond_norm, self.crosspond_attention = construct_pix2pix_attention(self.hidden_states_dim, norm_type=norm_type) + + def do_crosspond_attention(self, hidden_states: torch.FloatTensor, other_states: torch.FloatTensor): + hidden_states = self.crosspond_norm(hidden_states) + + batch, L, D = hidden_states.shape + assert hidden_states.shape == other_states.shape, f"got {hidden_states.shape} and {other_states.shape}" + # to -> batch * L, 1, D + hidden_states = hidden_states.reshape(batch * L, 1, D) + other_states = other_states.reshape(batch * L, 1, D) + hidden_states_catted = other_states + hidden_states = self.crosspond_attention( + hidden_states, + encoder_hidden_states=hidden_states_catted, + ) + return hidden_states.reshape(batch, L, D) + + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, + ref_dict: dict = None, mode=None, **kwargs + ) -> Any: + if not self.enabled: + return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + assert ref_dict is not None + if (mode or self.mode) == 'extract': + ref_dict[self.name] = hidden_states + hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) + if self.pixel_wise_crosspond and self.crosspond_chain_pos == "after": + ref_dict[self.name] = hidden_states1 + return hidden_states1 + elif (mode or self.mode) == 'inject': + ref_state = ref_dict.pop(self.name) + if self.with_proj_in: + ref_state = self.in_linear(ref_state) + + B, L, D = ref_state.shape + if hidden_states.shape[0] == B: + modalities = 1 + views = 1 + else: + modalities = hidden_states.shape[0] // B // self.views + views = self.views + if self.pixel_wise_crosspond: + if self.crosspond_effect_on == "all": + ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, *ref_state.shape[-2:]) + + if self.crosspond_chain_pos == "before": + hidden_states = hidden_states + self.do_crosspond_attention(hidden_states, ref_state) + + hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) + + if self.crosspond_chain_pos == "parralle": + hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states, ref_state) + + if self.crosspond_chain_pos == "after": + hidden_states1 = hidden_states1 + self.do_crosspond_attention(hidden_states1, ref_state) + return hidden_states1 + else: + assert self.crosspond_effect_on == "first" + # hidden_states [B * modalities * views, L, D] + # ref_state [B, L, D] + ref_state = ref_state[:, None].expand(-1, modalities, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1]) # [B * modalities, L, D] + + def do_paritial_crosspond(hidden_states, ref_state): + first_view_hidden_states = hidden_states.view(-1, views, hidden_states.shape[1], hidden_states.shape[2])[:, 0] # [B * modalities, L, D] + hidden_states2 = self.do_crosspond_attention(first_view_hidden_states, ref_state) # [B * modalities, L, D] + hidden_states2_padded = torch.zeros_like(hidden_states).reshape(-1, views, hidden_states.shape[1], hidden_states.shape[2]) + hidden_states2_padded[:, 0] = hidden_states2 + hidden_states2_padded = hidden_states2_padded.reshape(-1, hidden_states.shape[1], hidden_states.shape[2]) + return hidden_states2_padded + + if self.crosspond_chain_pos == "before": + hidden_states = hidden_states + do_paritial_crosspond(hidden_states, ref_state) + + hidden_states1 = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) # [B * modalities * views, L, D] + if self.crosspond_chain_pos == "parralle": + hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states, ref_state) + if self.crosspond_chain_pos == "after": + hidden_states1 = hidden_states1 + do_paritial_crosspond(hidden_states1, ref_state) + return hidden_states1 + elif self.simple_3d: + B, L, C = encoder_hidden_states.shape + mv = self.views + encoder_hidden_states = encoder_hidden_states.reshape(B // mv, mv, L, C) + ref_state = ref_state[:, None] + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1) + encoder_hidden_states = encoder_hidden_states.reshape(B // mv, 1, (mv+1) * L, C) + encoder_hidden_states = encoder_hidden_states.repeat(1, mv, 1, 1).reshape(-1, (mv+1) * L, C) + return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) + else: + ref_state = ref_state[:, None].expand(-1, modalities * views, -1, -1).reshape(-1, ref_state.shape[-2], ref_state.shape[-1]) + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_state], dim=1) + return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) + else: + raise NotImplementedError("mode or self.mode is required to be 'extract' or 'inject'") + +def add_extra_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs): + return_dict = torch.nn.ModuleDict() + proj_in_dim = kwargs.get('proj_in_dim', False) + kwargs.pop('proj_in_dim', None) + + def recursive_add_processors(name: str, module: torch.nn.Module): + for sub_name, child in module.named_children(): + if "ref_unet" not in (sub_name + name): + recursive_add_processors(f"{name}.{sub_name}", child) + + if isinstance(module, Attention): + new_processor = ExtraAttnProc( + chained_proc=module.get_processor(), + enabled=enable_filter(f"{name}.processor"), + name=f"{name}.processor", + proj_in_dim=proj_in_dim if proj_in_dim else module.cross_attention_dim, + target_dim=module.cross_attention_dim, + **kwargs + ) + module.set_processor(new_processor) + return_dict[f"{name}.processor".replace(".", "__")] = new_processor + + for name, module in model.named_children(): + recursive_add_processors(name, module) + return return_dict + +def switch_extra_processor(model, enable_filter=lambda x:True): + def recursive_add_processors(name: str, module: torch.nn.Module): + for sub_name, child in module.named_children(): + recursive_add_processors(f"{name}.{sub_name}", child) + + if isinstance(module, ExtraAttnProc): + module.enabled = enable_filter(name) + + for name, module in model.named_children(): + recursive_add_processors(name, module) + +class multiviewAttnProc(torch.nn.Module): + def __init__( + self, + chained_proc, + enabled=False, + name=None, + hidden_states_dim=None, + chain_pos="parralle", # before or parralle or after + num_modalities=1, + views=4, + base_img_size=64, + ) -> None: + super().__init__() + self.enabled = enabled + self.chained_proc = chained_proc + self.name = name + self.hidden_states_dim = hidden_states_dim + self.num_modalities = num_modalities + self.views = views + self.base_img_size = base_img_size + self.chain_pos = chain_pos + self.diff_joint_attn = True + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + **kwargs + ) -> torch.Tensor: + if not self.enabled: + return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) + + B, L, C = hidden_states.shape + mv = self.views + hidden_states = hidden_states.reshape(B // mv, mv, L, C).reshape(-1, mv * L, C) + hidden_states = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs) + return hidden_states.reshape(B // mv, mv, L, C).reshape(-1, L, C) + +def add_multiview_processor(model: torch.nn.Module, enable_filter=lambda x:True, **kwargs): + return_dict = torch.nn.ModuleDict() + def recursive_add_processors(name: str, module: torch.nn.Module): + for sub_name, child in module.named_children(): + if "ref_unet" not in (sub_name + name): + recursive_add_processors(f"{name}.{sub_name}", child) + + if isinstance(module, Attention): + new_processor = multiviewAttnProc( + chained_proc=module.get_processor(), + enabled=enable_filter(f"{name}.processor"), + name=f"{name}.processor", + hidden_states_dim=module.inner_dim, + **kwargs + ) + module.set_processor(new_processor) + return_dict[f"{name}.processor".replace(".", "__")] = new_processor + + for name, module in model.named_children(): + recursive_add_processors(name, module) + + return return_dict + +def switch_multiview_processor(model, enable_filter=lambda x:True): + def recursive_add_processors(name: str, module: torch.nn.Module): + for sub_name, child in module.named_children(): + recursive_add_processors(f"{name}.{sub_name}", child) + + if isinstance(module, Attention): + processor = module.get_processor() + if isinstance(processor, multiviewAttnProc): + processor.enabled = enable_filter(f"{name}.processor") + + for name, module in model.named_children(): + recursive_add_processors(name, module) + +class NNModuleWrapper(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + +class AttnProcessorSwitch(torch.nn.Module): + def __init__( + self, + proc_dict: dict, + enabled_proc="default", + name=None, + switch_name="default_switch", + ): + super().__init__() + self.proc_dict = torch.nn.ModuleDict({k: (v if isinstance(v, torch.nn.Module) else NNModuleWrapper(v)) for k, v in proc_dict.items()}) + self.enabled_proc = enabled_proc + self.name = name + self.switch_name = switch_name + self.choose_module(enabled_proc) + + def choose_module(self, enabled_proc): + self.enabled_proc = enabled_proc + assert enabled_proc in self.proc_dict.keys() + + def __call__( + self, + *args, + **kwargs + ) -> torch.FloatTensor: + used_proc = self.proc_dict[self.enabled_proc] + return used_proc(*args, **kwargs) + +def add_switch(model: torch.nn.Module, module_filter=lambda x:True, switch_dict_fn=lambda x: {"default": x}, switch_name="default_switch", enabled_proc="default"): + return_dict = torch.nn.ModuleDict() + def recursive_add_processors(name: str, module: torch.nn.Module): + for sub_name, child in module.named_children(): + if "ref_unet" not in (sub_name + name): + recursive_add_processors(f"{name}.{sub_name}", child) + + if isinstance(module, Attention): + processor = module.get_processor() + if module_filter(processor): + proc_dict = switch_dict_fn(processor) + new_processor = AttnProcessorSwitch( + proc_dict=proc_dict, + enabled_proc=enabled_proc, + name=f"{name}.processor", + switch_name=switch_name, + ) + module.set_processor(new_processor) + return_dict[f"{name}.processor".replace(".", "__")] = new_processor + + for name, module in model.named_children(): + recursive_add_processors(name, module) + + return return_dict + +def change_switch(model: torch.nn.Module, switch_name="default_switch", enabled_proc="default"): + def recursive_change_processors(name: str, module: torch.nn.Module): + for sub_name, child in module.named_children(): + recursive_change_processors(f"{name}.{sub_name}", child) + + if isinstance(module, Attention): + processor = module.get_processor() + if isinstance(processor, AttnProcessorSwitch) and processor.switch_name == switch_name: + processor.choose_module(enabled_proc) + + for name, module in model.named_children(): + recursive_change_processors(name, module) + +########## Hack: Attention fix ############# +from diffusers.models.attention import Attention + +def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + **cross_attention_kwargs, +) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + +Attention.forward = forward \ No newline at end of file diff --git a/custum_3d_diffusion/custum_modules/unifield_processor.py b/custum_3d_diffusion/custum_modules/unifield_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..aa47c58fdadae5a828dbc7621d436f3f27f2c902 --- /dev/null +++ b/custum_3d_diffusion/custum_modules/unifield_processor.py @@ -0,0 +1,460 @@ +from types import FunctionType +from typing import Any, Dict, List +from diffusers import UNet2DConditionModel +import torch +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, ImageProjection +from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor +from dataclasses import dataclass, field +from diffusers.loaders import IPAdapterMixin +from custum_3d_diffusion.custum_modules.attention_processors import add_extra_processor, switch_extra_processor, add_multiview_processor, switch_multiview_processor, add_switch, change_switch + +@dataclass +class AttnConfig: + """ + * CrossAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), IPAdapter module (achieves conceptual control). + * SelfAttention: Attention module (inherits knowledge), LoRA module (achieves fine-tuning), Reference Attention module (achieves pixel-level control). + * Multiview Attention module: Multiview Attention module (achieves multi-view consistency). + * Cross Modality Attention module: Cross Modality Attention module (achieves multi-modality consistency). + + For setups: + train_xxx_lr is implemented in the U-Net architecture. + enable_xxx_lora is implemented in the U-Net architecture. + enable_xxx_ip is implemented in the processor and U-Net architecture. + enable_xxx_ref_proj_in is implemented in the processor. + """ + latent_size: int = 64 + + train_lr: float = 0 + # for cross attention + # 0 learning rate for not training + train_cross_attn_lr: float = 0 + train_cross_attn_lora_lr: float = 0 + train_cross_attn_ip_lr: float = 0 # 0 for not trained + init_cross_attn_lora: bool = False + enable_cross_attn_lora: bool = False + init_cross_attn_ip: bool = False + enable_cross_attn_ip: bool = False + cross_attn_lora_rank: int = 64 # 0 for not enabled + cross_attn_lora_only_kv: bool = False + ipadapter_pretrained_name: str = "h94/IP-Adapter" + ipadapter_subfolder_name: str = "models" + ipadapter_weight_name: str = "ip-adapter-plus_sd15.safetensors" + ipadapter_effect_on: str = "all" # all, first + + # for self attention + train_self_attn_lr: float = 0 + train_self_attn_lora_lr: float = 0 + init_self_attn_lora: bool = False + enable_self_attn_lora: bool = False + self_attn_lora_rank: int = 64 + self_attn_lora_only_kv: bool = False + + train_self_attn_ref_lr: float = 0 + train_ref_unet_lr: float = 0 + init_self_attn_ref: bool = False + enable_self_attn_ref: bool = False + self_attn_ref_other_model_name: str = "" + self_attn_ref_position: str = "attn1" + self_attn_ref_pixel_wise_crosspond: bool = False # enable pixel_wise_crosspond in refattn + self_attn_ref_chain_pos: str = "parralle" # before or parralle or after + self_attn_ref_effect_on: str = "all" # all or first, for _crosspond attn + self_attn_ref_zero_init: bool = True + use_simple3d_attn: bool = False + + # for multiview attention + init_multiview_attn: bool = False + enable_multiview_attn: bool = False + multiview_attn_position: str = "attn1" + multiview_chain_pose: str = "parralle" # before or parralle or after + num_modalities: int = 1 + use_mv_joint_attn: bool = False + + # for unet + init_unet_path: str = "runwayml/stable-diffusion-v1-5" + init_num_cls_label: int = 0 # for initialize + cls_labels: List[int] = field(default_factory=lambda: []) + cls_label_type: str = "embedding" + cat_condition: bool = False # cat condition to input + +class Configurable: + attn_config: AttnConfig + + def set_config(self, attn_config: AttnConfig): + raise NotImplementedError() + + def update_config(self, attn_config: AttnConfig): + self.attn_config = attn_config + + def do_set_config(self, attn_config: AttnConfig): + self.set_config(attn_config) + for name, module in self.named_modules(): + if isinstance(module, Configurable): + if hasattr(module, "do_set_config"): + module.do_set_config(attn_config) + else: + print(f"Warning: {name} has no attribute do_set_config, but is an instance of Configurable") + module.attn_config = attn_config + + def do_update_config(self, attn_config: AttnConfig): + self.update_config(attn_config) + for name, module in self.named_modules(): + if isinstance(module, Configurable): + if hasattr(module, "do_update_config"): + module.do_update_config(attn_config) + else: + print(f"Warning: {name} has no attribute do_update_config, but is an instance of Configurable") + module.attn_config = attn_config + +from diffusers import ModelMixin # Must import ModelMixin for CompiledUNet +class UnifieldWrappedUNet(UNet2DConditionModel): + forward_hook: FunctionType + + def forward(self, *args, **kwargs): + if hasattr(self, 'forward_hook'): + return self.forward_hook(super().forward, *args, **kwargs) + return super().forward(*args, **kwargs) + + +class ConfigurableUNet2DConditionModel(Configurable, IPAdapterMixin): + unet: UNet2DConditionModel + + cls_embedding_param_dict = {} + cross_attn_lora_param_dict = {} + self_attn_lora_param_dict = {} + cross_attn_param_dict = {} + self_attn_param_dict = {} + ipadapter_param_dict = {} + ref_attn_param_dict = {} + ref_unet_param_dict = {} + multiview_attn_param_dict = {} + other_param_dict = {} + + rev_param_name_mapping = {} + + class_labels = [] + def set_class_labels(self, class_labels: torch.Tensor): + if self.attn_config.init_num_cls_label != 0: + self.class_labels = class_labels.to(self.unet.device).long() + + def __init__(self, init_config: AttnConfig, weight_dtype) -> None: + super().__init__() + self.weight_dtype = weight_dtype + self.set_config(init_config) + + def enable_xformers_memory_efficient_attention(self): + self.unet.enable_xformers_memory_efficient_attention + def recursive_add_processors(name: str, module: torch.nn.Module): + for sub_name, child in module.named_children(): + recursive_add_processors(f"{name}.{sub_name}", child) + + if isinstance(module, Attention): + if hasattr(module, 'xformers_not_supported'): + return + old_processor = module.get_processor() + if isinstance(old_processor, (AttnProcessor, AttnProcessor2_0)): + module.set_use_memory_efficient_attention_xformers(True) + + for name, module in self.unet.named_children(): + recursive_add_processors(name, module) + + def __getattr__(self, name: str) -> Any: + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + # --- for IPAdapterMixin + + def register_modules(self, **kwargs): + for name, module in kwargs.items(): + # set models + setattr(self, name, module) + + def register_to_config(self, **kwargs): + pass + + def unload_ip_adapter(self): + raise NotImplementedError() + + # --- for Configurable + + def get_refunet(self): + if self.attn_config.self_attn_ref_other_model_name == "self": + return self.unet + else: + return self.unet.ref_unet + + def set_config(self, attn_config: AttnConfig): + self.attn_config = attn_config + + unet_type = UnifieldWrappedUNet + # class_embed_type = "projection" for 'camera' + # class_embed_type = None for 'embedding' + unet_kwargs = {} + if attn_config.init_num_cls_label > 0: + if attn_config.cls_label_type == "embedding": + unet_kwargs = { + "num_class_embeds": attn_config.init_num_cls_label, + "device_map": None, + "low_cpu_mem_usage": False, + "class_embed_type": None, + } + else: + raise ValueError(f"cls_label_type {attn_config.cls_label_type} is not supported") + + self.unet: UnifieldWrappedUNet = unet_type.from_pretrained( + attn_config.init_unet_path, subfolder="unet", torch_dtype=self.weight_dtype, + ignore_mismatched_sizes=True, # Added this line + **unet_kwargs + ) + assert isinstance(self.unet, UnifieldWrappedUNet) + self.unet.forward_hook = self.unet_forward_hook + + if self.attn_config.cat_condition: + # double in_channels + if self.unet.config.in_channels != 8: + self.unet.register_to_config(in_channels=self.unet.config.in_channels * 2) + # repeate unet.conv_in weight twice + doubled_conv_in = torch.nn.Conv2d(self.unet.conv_in.in_channels * 2, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding) + doubled_conv_in.weight.data = torch.cat([self.unet.conv_in.weight.data, torch.zeros_like(self.unet.conv_in.weight.data)], dim=1) + doubled_conv_in.bias.data = self.unet.conv_in.bias.data + self.unet.conv_in = doubled_conv_in + + used_param_ids = set() + + if attn_config.init_cross_attn_lora: + # setup lora + from peft import LoraConfig + from peft.utils import get_peft_model_state_dict + if attn_config.cross_attn_lora_only_kv: + target_modules=["attn2.to_k", "attn2.to_v"] + else: + target_modules=["attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0"] + lora_config: LoraConfig = LoraConfig( + r=attn_config.cross_attn_lora_rank, + lora_alpha=attn_config.cross_attn_lora_rank, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + adapter_name="cross_attn_lora" + self.unet.add_adapter(lora_config, adapter_name=adapter_name) + # update cross_attn_lora_param_dict + self.cross_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids} + used_param_ids.update(self.cross_attn_lora_param_dict.keys()) + + if attn_config.init_self_attn_lora: + # setup lora + from peft import LoraConfig + if attn_config.self_attn_lora_only_kv: + target_modules=["attn1.to_k", "attn1.to_v"] + else: + target_modules=["attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0"] + lora_config: LoraConfig = LoraConfig( + r=attn_config.self_attn_lora_rank, + lora_alpha=attn_config.self_attn_lora_rank, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + adapter_name="self_attn_lora" + self.unet.add_adapter(lora_config, adapter_name=adapter_name) + # update cross_self_lora_param_dict + self.self_attn_lora_param_dict = {id(param): param for name, param in self.unet.named_parameters() if adapter_name in name and id(param) not in used_param_ids} + used_param_ids.update(self.self_attn_lora_param_dict.keys()) + + if attn_config.init_num_cls_label != 0: + self.cls_embedding_param_dict = {id(param): param for param in self.unet.class_embedding.parameters()} + used_param_ids.update(self.cls_embedding_param_dict.keys()) + self.set_class_labels(torch.tensor(attn_config.cls_labels).long()) + + if attn_config.init_cross_attn_ip: + self.image_encoder = None + # setup ipadapter + self.load_ip_adapter( + attn_config.ipadapter_pretrained_name, + subfolder=attn_config.ipadapter_subfolder_name, + weight_name=attn_config.ipadapter_weight_name + ) + # warp ip_adapter_attn_proc with switch + from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0 + add_switch(self.unet, module_filter=lambda x: isinstance(x, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)), switch_dict_fn=lambda x: {"ipadapter": x, "default": XFormersAttnProcessor()}, switch_name="ipadapter_switch", enabled_proc="ipadapter") + # update ipadapter_param_dict + # weights are in attention processors and unet.encoder_hid_proj + self.ipadapter_param_dict = {id(param): param for param in self.unet.encoder_hid_proj.parameters() if id(param) not in used_param_ids} + used_param_ids.update(self.ipadapter_param_dict.keys()) + print("DEBUG: ipadapter_param_dict len in encoder_hid_proj", len(self.ipadapter_param_dict)) + for name, processor in self.unet.attn_processors.items(): + if hasattr(processor, "to_k_ip"): + self.ipadapter_param_dict.update({id(param): param for param in processor.parameters()}) + print(f"DEBUG: ipadapter_param_dict len in all", len(self.ipadapter_param_dict)) + + ref_unet = None + if attn_config.init_self_attn_ref: + # setup reference attention processor + if attn_config.self_attn_ref_other_model_name == "self": + raise NotImplementedError("self reference is not fully implemented") + else: + ref_unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained( + attn_config.self_attn_ref_other_model_name, subfolder="unet", torch_dtype=self.unet.dtype + ) + ref_unet.to(self.unet.device) + if self.attn_config.train_ref_unet_lr == 0: + ref_unet.eval() + ref_unet.requires_grad_(False) + else: + ref_unet.train() + + add_extra_processor( + model=ref_unet, + enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"), + mode='extract', + with_proj_in=False, + pixel_wise_crosspond=False, + ) + # NOTE: Here require cross_attention_dim in two unet's self attention should be the same + processor_dict = add_extra_processor( + model=self.unet, + enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor"), + mode='inject', + with_proj_in=False, + pixel_wise_crosspond=attn_config.self_attn_ref_pixel_wise_crosspond, + crosspond_effect_on=attn_config.self_attn_ref_effect_on, + crosspond_chain_pos=attn_config.self_attn_ref_chain_pos, + simple_3d=attn_config.use_simple3d_attn, + ) + self.ref_unet_param_dict = {id(param): param for name, param in ref_unet.named_parameters() if id(param) not in used_param_ids and (attn_config.self_attn_ref_position in name)} + if attn_config.self_attn_ref_chain_pos != "after": + # pop untrainable paramters + for name, param in ref_unet.named_parameters(): + if id(param) in self.ref_unet_param_dict and ('up_blocks.3.attentions.2.transformer_blocks.0.' in name): + self.ref_unet_param_dict.pop(id(param)) + used_param_ids.update(self.ref_unet_param_dict.keys()) + # update ref_attn_param_dict + self.ref_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids} + used_param_ids.update(self.ref_attn_param_dict.keys()) + + if attn_config.init_multiview_attn: + processor_dict = add_multiview_processor( + model = self.unet, + enable_filter = lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor"), + num_modalities = attn_config.num_modalities, + base_img_size = attn_config.latent_size, + chain_pos = attn_config.multiview_chain_pose, + ) + # update multiview_attn_param_dict + self.multiview_attn_param_dict = {id(param): param for name, param in processor_dict.named_parameters() if id(param) not in used_param_ids} + used_param_ids.update(self.multiview_attn_param_dict.keys()) + + # initialize cross_attn_param_dict parameters + self.cross_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn2" in name and id(param) not in used_param_ids} + used_param_ids.update(self.cross_attn_param_dict.keys()) + + # initialize self_attn_param_dict parameters + self.self_attn_param_dict = {id(param): param for name, param in self.unet.named_parameters() if "attn1" in name and id(param) not in used_param_ids} + used_param_ids.update(self.self_attn_param_dict.keys()) + + # initialize other_param_dict parameters + self.other_param_dict = {id(param): param for name, param in self.unet.named_parameters() if id(param) not in used_param_ids} + + if ref_unet is not None: + self.unet.ref_unet = ref_unet + + self.rev_param_name_mapping = {id(param): name for name, param in self.unet.named_parameters()} + + self.update_config(attn_config, force_update=True) + return self.unet + + _attn_keys_to_update = ["enable_cross_attn_lora", "enable_cross_attn_ip", "enable_self_attn_lora", "enable_self_attn_ref", "enable_multiview_attn", "cls_labels"] + + def update_config(self, attn_config: AttnConfig, force_update=False): + assert isinstance(self.unet, UNet2DConditionModel), "unet must be an instance of UNet2DConditionModel" + + need_to_update = False + # update cls_labels + for key in self._attn_keys_to_update: + if getattr(self.attn_config, key) != getattr(attn_config, key): + need_to_update = True + break + if not force_update and not need_to_update: + return + + self.set_class_labels(torch.tensor(attn_config.cls_labels).long()) + + # setup loras + if self.attn_config.init_cross_attn_lora or self.attn_config.init_self_attn_lora: + if attn_config.enable_cross_attn_lora or attn_config.enable_self_attn_lora: + cross_attn_lora_weight = 1. if attn_config.enable_cross_attn_lora > 0 else 0 + self_attn_lora_weight = 1. if attn_config.enable_self_attn_lora > 0 else 0 + self.unet.set_adapters(["cross_attn_lora", "self_attn_lora"], weights=[cross_attn_lora_weight, self_attn_lora_weight]) + else: + self.unet.disable_adapters() + + # setup ipadapter + if self.attn_config.init_cross_attn_ip: + if attn_config.enable_cross_attn_ip: + change_switch(self.unet, "ipadapter_switch", "ipadapter") + else: + change_switch(self.unet, "ipadapter_switch", "default") + + # setup reference attention processor + if self.attn_config.init_self_attn_ref: + if attn_config.enable_self_attn_ref: + switch_extra_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.self_attn_ref_position}.processor")) + else: + switch_extra_processor(self.unet, enable_filter=lambda name: False) + + # setup multiview attention processor + if self.attn_config.init_multiview_attn: + if attn_config.enable_multiview_attn: + switch_multiview_processor(self.unet, enable_filter=lambda name: name.endswith(f"{attn_config.multiview_attn_position}.processor")) + else: + switch_multiview_processor(self.unet, enable_filter=lambda name: False) + + # update cls_labels + for key in self._attn_keys_to_update: + setattr(self.attn_config, key, getattr(attn_config, key)) + + def unet_forward_hook(self, raw_forward, sample: torch.FloatTensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, cross_attention_kwargs=None, condition_latents=None, class_labels=None, noisy_condition_input=False, cond_pixels_clip=None, **kwargs): + if class_labels is None and len(self.class_labels) > 0: + class_labels = self.class_labels.repeat(sample.shape[0] // self.class_labels.shape[0]).to(sample.device) + elif self.attn_config.init_num_cls_label != 0: + assert class_labels is not None, "class_labels should be passed if self.class_labels is empty and self.attn_config.init_num_cls_label is not 0" + if class_labels is not None: + if self.attn_config.cls_label_type == "embedding": + pass + else: + raise ValueError(f"cls_label_type {self.attn_config.cls_label_type} is not supported") + if self.attn_config.init_self_attn_ref and self.attn_config.enable_self_attn_ref: + # NOTE: extra step, extract condition + ref_dict = {} + ref_unet = self.get_refunet().to(sample.device) + assert condition_latents is not None + if self.attn_config.self_attn_ref_other_model_name == "self": + raise NotImplementedError() + else: + with torch.no_grad(): + cond_encoder_hidden_states = encoder_hidden_states.reshape(condition_latents.shape[0], -1, *encoder_hidden_states.shape[1:])[:, 0] + if timestep.dim() == 0: + cond_timestep = timestep + else: + cond_timestep = timestep.reshape(condition_latents.shape[0], -1)[:, 0] + ref_unet(condition_latents, cond_timestep, cond_encoder_hidden_states, cross_attention_kwargs=dict(ref_dict=ref_dict)) + # NOTE: extra step, inject condition + # Predict the noise residual and compute loss + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + cross_attention_kwargs.update(ref_dict=ref_dict, mode='inject') + elif condition_latents is not None: + if not hasattr(self, 'condition_latents_raised'): + print("Warning! condition_latents is not None, but self_attn_ref is not enabled! This warning will only be raised once.") + self.condition_latents_raised = True + + if self.attn_config.init_cross_attn_ip: + raise NotImplementedError() + + if self.attn_config.cat_condition: + assert condition_latents is not None + B = condition_latents.shape[0] + cat_latents = condition_latents.reshape(B, 1, *condition_latents.shape[1:]).repeat(1, sample.shape[0] // B, 1, 1, 1).reshape(*sample.shape) + sample = torch.cat([sample, cat_latents], dim=1) + + return raw_forward(sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, **kwargs) diff --git a/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py b/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6ef50a3ddd5c2a7813916bbbabb31c5886f01f --- /dev/null +++ b/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2img.py @@ -0,0 +1,298 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# modified by Wuvin + + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from PIL import Image +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + + + +class StableDiffusionImageCustomPipeline( + StableDiffusionImageVariationPipeline +): + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + latents_offset=None, + noisy_cond_latents=False, + ): + super().__init__( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker + ) + latents_offset = tuple(latents_offset) if latents_offset is not None else None + self.latents_offset = latents_offset + if latents_offset is not None: + self.register_to_config(latents_offset=latents_offset) + self.noisy_cond_latents = noisy_cond_latents + self.register_to_config(noisy_cond_latents=noisy_cond_latents) + + def encode_latents(self, image, device, dtype, height, width): + # support batchsize > 1 + if isinstance(image, Image.Image): + image = [image] + image = [img.convert("RGB") for img in image] + images = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype) + latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor + if self.latents_offset is not None: + return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None] + else: + return latents + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # NOTE: the same as original code + negative_prompt_embeds = torch.zeros_like(image_embeddings) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + image: Union[Image.Image, List[Image.Image], torch.FloatTensor], + height: Optional[int] = 1024, + width: Optional[int] = 1024, + height_cond: Optional[int] = 512, + width_cond: Optional[int] = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + upper_left_feature: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + + Examples: + + ```py + from diffusers import StableDiffusionImageVariationPipeline + from PIL import Image + from io import BytesIO + import requests + + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + "lambdalabs/sd-image-variations-diffusers", revision="v2.0" + ) + pipe = pipe.to("cuda") + + url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200" + + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert("RGB") + + out = pipe(image, num_images_per_prompt=3, guidance_scale=15) + out["images"][0].save("result.jpg") + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + if isinstance(image, Image.Image) and upper_left_feature: + # only use the first one of four images + emb_image = image.crop((0, 0, image.size[0] // 2, image.size[1] // 2)) + else: + emb_image = image + + image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance) + cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.out_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.noisy_cond_latents: + raise ValueError("Noisy condition latents is not recommended.") + else: + noisy_cond_latents = cond_latents + + noisy_cond_latents = torch.cat([torch.zeros_like(noisy_cond_latents), noisy_cond_latents]) if do_classifier_free_guidance else noisy_cond_latents + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=noisy_cond_latents).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + self.maybe_free_model_hooks() + + if self.latents_offset is not None: + latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None] + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + +if __name__ == "__main__": + pass diff --git a/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py b/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py new file mode 100644 index 0000000000000000000000000000000000000000..de342d1b9767b6d1cea138bb24d2d2fff34229fc --- /dev/null +++ b/custum_3d_diffusion/custum_pipeline/unifield_pipeline_img2mvimg.py @@ -0,0 +1,296 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# modified by Wuvin + + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline +from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from PIL import Image +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + + + +class StableDiffusionImage2MVCustomPipeline( + StableDiffusionImageVariationPipeline +): + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + latents_offset=None, + noisy_cond_latents=False, + condition_offset=True, + ): + super().__init__( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker + ) + latents_offset = tuple(latents_offset) if latents_offset is not None else None + self.latents_offset = latents_offset + if latents_offset is not None: + self.register_to_config(latents_offset=latents_offset) + if noisy_cond_latents: + raise NotImplementedError("Noisy condition latents not supported Now.") + self.condition_offset = condition_offset + self.register_to_config(condition_offset=condition_offset) + + def encode_latents(self, image: Image.Image, device, dtype, height, width): + images = self.image_processor.preprocess(image.convert("RGB"), height=height, width=width).to(device, dtype=dtype) + # NOTE: .mode() for condition + latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor + if self.latents_offset is not None and self.condition_offset: + return latents - torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None] + else: + return latents + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # NOTE: the same as original code + negative_prompt_embeds = torch.zeros_like(image_embeddings) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + image: Union[Image.Image, List[Image.Image], torch.FloatTensor], + height: Optional[int] = 1024, + width: Optional[int] = 1024, + height_cond: Optional[int] = 512, + width_cond: Optional[int] = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`Image.Image` or `List[Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + + Examples: + + ```py + from diffusers import StableDiffusionImageVariationPipeline + from PIL import Image + from io import BytesIO + import requests + + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + "lambdalabs/sd-image-variations-diffusers", revision="v2.0" + ) + pipe = pipe.to("cuda") + + url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200" + + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert("RGB") + + out = pipe(image, num_images_per_prompt=3, guidance_scale=15) + out["images"][0].save("result.jpg") + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, Image.Image): + batch_size = 1 + elif len(image) == 1: + image = image[0] + batch_size = 1 + else: + raise NotImplementedError() + # elif isinstance(image, list): + # batch_size = len(image) + # else: + # batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + emb_image = image + + image_embeddings = self._encode_image(emb_image, device, num_images_per_prompt, do_classifier_free_guidance) + cond_latents = self.encode_latents(image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond) + cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents]) if do_classifier_free_guidance else cond_latents + image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values + if do_classifier_free_guidance: + image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.out_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, condition_latents=cond_latents, noisy_condition_input=False, cond_pixels_clip=image_pixels).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + self.maybe_free_model_hooks() + + if self.latents_offset is not None: + latents = latents + torch.tensor(self.latents_offset).to(latents.device)[None, :, None, None] + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + +if __name__ == "__main__": + pass diff --git a/custum_3d_diffusion/modules.py b/custum_3d_diffusion/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..85a0af49c8dff30e64f9f5c1ba94ff697702523a --- /dev/null +++ b/custum_3d_diffusion/modules.py @@ -0,0 +1,14 @@ +__modules__ = {} + +def register(name): + def decorator(cls): + __modules__[name] = cls + return cls + + return decorator + + +def find(name): + return __modules__[name] + +from custum_3d_diffusion.trainings import base, image2mvimage_trainer, image2image_trainer diff --git a/custum_3d_diffusion/trainings/__init__.py b/custum_3d_diffusion/trainings/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custum_3d_diffusion/trainings/base.py b/custum_3d_diffusion/trainings/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a4ce969190b788b34f9c777b30413b5d6d79a349 --- /dev/null +++ b/custum_3d_diffusion/trainings/base.py @@ -0,0 +1,208 @@ +import torch +from accelerate import Accelerator +from accelerate.logging import MultiProcessAdapter +from dataclasses import dataclass, field +from typing import Optional, Union +from datasets import load_dataset +import json +import abc +from diffusers.utils import make_image_grid +import numpy as np +import wandb + +from custum_3d_diffusion.trainings.utils import load_config +from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig + +class BasicTrainer(torch.nn.Module, abc.ABC): + accelerator: Accelerator + logger: MultiProcessAdapter + unet: ConfigurableUNet2DConditionModel + train_dataloader: torch.utils.data.DataLoader + test_dataset: torch.utils.data.Dataset + attn_config: AttnConfig + + @dataclass + class TrainerConfig: + trainer_name: str = "basic" + pretrained_model_name_or_path: str = "" + + attn_config: dict = field(default_factory=dict) + dataset_name: str = "" + dataset_config_name: Optional[str] = None + resolution: str = "1024" + dataloader_num_workers: int = 4 + pair_sampler_group_size: int = 1 + num_views: int = 4 + + max_train_steps: int = -1 # -1 means infinity, otherwise [0, max_train_steps) + training_step_interval: int = 1 # train on step i*interval, stop at max_train_steps + max_train_samples: Optional[int] = None + seed: Optional[int] = None # For dataset related operations and validation stuff + train_batch_size: int = 1 + + validation_interval: int = 5000 + debug: bool = False + + cfg: TrainerConfig # only enable_xxx is used + + def __init__( + self, + accelerator: Accelerator, + logger: MultiProcessAdapter, + unet: ConfigurableUNet2DConditionModel, + config: Union[dict, str], + weight_dtype: torch.dtype, + index: int, + ): + super().__init__() + self.index = index # index in all trainers + self.accelerator = accelerator + self.logger = logger + self.unet = unet + self.weight_dtype = weight_dtype + self.ext_logs = {} + self.cfg = load_config(self.TrainerConfig, config) + self.attn_config = load_config(AttnConfig, self.cfg.attn_config) + self.test_dataset = None + self.validate_trainer_config() + self.configure() + + def get_HW(self): + resolution = json.loads(self.cfg.resolution) + if isinstance(resolution, int): + H = W = resolution + elif isinstance(resolution, list): + H, W = resolution + return H, W + + def unet_update(self): + self.unet.update_config(self.attn_config) + + def validate_trainer_config(self): + pass + + def is_train_finished(self, current_step): + assert isinstance(self.cfg.max_train_steps, int) + return self.cfg.max_train_steps != -1 and current_step >= self.cfg.max_train_steps + + def next_train_step(self, current_step): + if self.is_train_finished(current_step): + return None + return current_step + self.cfg.training_step_interval + + @classmethod + def make_image_into_grid(cls, all_imgs, rows=2, columns=2): + catted = [make_image_grid(all_imgs[i:i+rows * columns], rows=rows, cols=columns) for i in range(0, len(all_imgs), rows * columns)] + return make_image_grid(catted, rows=1, cols=len(catted)) + + def configure(self) -> None: + pass + + @abc.abstractmethod + def init_shared_modules(self, shared_modules: dict) -> dict: + pass + + def load_dataset(self): + dataset = load_dataset( + self.cfg.dataset_name, + self.cfg.dataset_config_name, + trust_remote_code=True + ) + return dataset + + @abc.abstractmethod + def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader: + """Both init train_dataloader and test_dataset, but returns train_dataloader only""" + pass + + @abc.abstractmethod + def forward_step( + self, + *args, + **kwargs + ) -> torch.Tensor: + """ + input a batch + return a loss + """ + self.unet_update() + pass + + @abc.abstractmethod + def construct_pipeline(self, shared_modules, unet): + pass + + @abc.abstractmethod + def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: + """ + For inference time forward. + """ + pass + + @abc.abstractmethod + def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: + pass + + def do_validation( + self, + shared_modules, + unet, + global_step, + ): + self.unet_update() + self.logger.info("Running validation... ") + pipeline = self.construct_pipeline(shared_modules, unet) + pipeline.set_progress_bar_config(disable=True) + titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.]) + for tracker in self.accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC") + elif tracker.name == "wandb": + [image.thumbnail((512, 512)) for image, title in zip(images, titles) if 'noresize' not in title] # inplace operation + tracker.log({"validation": [ + wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg") + for i, image in enumerate(images)]}) + else: + self.logger.warn(f"image logging not implemented for {tracker.name}") + del pipeline + torch.cuda.empty_cache() + return images + + + @torch.no_grad() + def log_validation( + self, + shared_modules, + unet, + global_step, + force=False + ): + if self.accelerator.is_main_process: + for tracker in self.accelerator.trackers: + if tracker.name == "wandb": + tracker.log(self.ext_logs) + self.ext_logs = {} + if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force: + self.unet_update() + if self.accelerator.is_main_process: + self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step) + + def save_model(self, unwrap_unet, shared_modules, save_dir): + if self.accelerator.is_main_process: + pipeline = self.construct_pipeline(shared_modules, unwrap_unet) + pipeline.save_pretrained(save_dir) + self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}") + + def save_debug_info(self, save_name="debug", **kwargs): + if self.cfg.debug: + to_saves = {key: value.detach().cpu() if isinstance(value, torch.Tensor) else value for key, value in kwargs.items()} + import pickle + import os + if os.path.exists(f"{save_name}.pkl"): + for i in range(100): + if not os.path.exists(f"{save_name}_v{i}.pkl"): + save_name = f"{save_name}_v{i}" + break + with open(f"{save_name}.pkl", "wb") as f: + pickle.dump(to_saves, f) \ No newline at end of file diff --git a/custum_3d_diffusion/trainings/config_classes.py b/custum_3d_diffusion/trainings/config_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e54921d1c648a1ac88c109873419b27a43e015 --- /dev/null +++ b/custum_3d_diffusion/trainings/config_classes.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass, field +from typing import List, Optional + + +@dataclass +class TrainerSubConfig: + trainer_type: str = "" + trainer: dict = field(default_factory=dict) + + +@dataclass +class ExprimentConfig: + trainers: List[dict] = field(default_factory=lambda: []) + init_config: dict = field(default_factory=dict) + pretrained_model_name_or_path: str = "" + pretrained_unet_state_dict_path: str = "" + # expriments related parameters + linear_beta_schedule: bool = False + zero_snr: bool = False + prediction_type: Optional[str] = None + seed: Optional[int] = None + max_train_steps: int = 1000000 + gradient_accumulation_steps: int = 1 + learning_rate: float = 1e-4 + lr_scheduler: str = "constant" + lr_warmup_steps: int = 500 + use_8bit_adam: bool = False + adam_beta1: float = 0.9 + adam_beta2: float = 0.999 + adam_weight_decay: float = 1e-2 + adam_epsilon: float = 1e-08 + max_grad_norm: float = 1.0 + mixed_precision: Optional[str] = None # ["no", "fp16", "bf16", "fp8"] + skip_training: bool = False + debug: bool = False \ No newline at end of file diff --git a/custum_3d_diffusion/trainings/image2image_trainer.py b/custum_3d_diffusion/trainings/image2image_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5f22c8efefb0f8b8c58cf4a301567919f241222c --- /dev/null +++ b/custum_3d_diffusion/trainings/image2image_trainer.py @@ -0,0 +1,86 @@ +import json +import torch +from diffusers import EulerAncestralDiscreteScheduler, DDPMScheduler +from dataclasses import dataclass + +from custum_3d_diffusion.modules import register +from custum_3d_diffusion.trainings.image2mvimage_trainer import Image2MVImageTrainer +from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2img import StableDiffusionImageCustomPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +def get_HW(resolution): + if isinstance(resolution, str): + resolution = json.loads(resolution) + if isinstance(resolution, int): + H = W = resolution + elif isinstance(resolution, list): + H, W = resolution + return H, W + + +@register("image2image_trainer") +class Image2ImageTrainer(Image2MVImageTrainer): + """ + Trainer for simple image to multiview images. + """ + @dataclass + class TrainerConfig(Image2MVImageTrainer.TrainerConfig): + trainer_name: str = "image2image" + + cfg: TrainerConfig + + def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor: + raise NotImplementedError() + + def construct_pipeline(self, shared_modules, unet, old_version=False): + MyPipeline = StableDiffusionImageCustomPipeline + pipeline = MyPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, + vae=shared_modules['vae'], + image_encoder=shared_modules['image_encoder'], + feature_extractor=shared_modules['feature_extractor'], + unet=unet, + safety_checker=None, + torch_dtype=self.weight_dtype, + latents_offset=self.cfg.latents_offset, + noisy_cond_latents=self.cfg.noisy_condition_input, + ) + pipeline.set_progress_bar_config(disable=True) + scheduler_dict = {} + if self.cfg.zero_snr: + scheduler_dict.update(rescale_betas_zero_snr=True) + if self.cfg.linear_beta_schedule: + scheduler_dict.update(beta_schedule='linear') + + pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict) + return pipeline + + def get_forward_args(self): + if self.cfg.seed is None: + generator = None + else: + generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed) + + H, W = get_HW(self.cfg.resolution) + H_cond, W_cond = get_HW(self.cfg.condition_image_resolution) + + forward_args = dict( + num_images_per_prompt=1, + num_inference_steps=20, + height=H, + width=W, + height_cond=H_cond, + width_cond=W_cond, + generator=generator, + ) + if self.cfg.zero_snr: + forward_args.update(guidance_rescale=0.7) + return forward_args + + def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput: + forward_args = self.get_forward_args() + forward_args.update(pipeline_call_kwargs) + return pipeline(**forward_args) + + def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: + raise NotImplementedError() \ No newline at end of file diff --git a/custum_3d_diffusion/trainings/image2mvimage_trainer.py b/custum_3d_diffusion/trainings/image2mvimage_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f1e6117eed017495e663d9bf6f289238a35c5b88 --- /dev/null +++ b/custum_3d_diffusion/trainings/image2mvimage_trainer.py @@ -0,0 +1,139 @@ +import torch +from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature + +import json +from dataclasses import dataclass +from typing import List, Optional + +from custum_3d_diffusion.modules import register +from custum_3d_diffusion.trainings.base import BasicTrainer +from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +def get_HW(resolution): + if isinstance(resolution, str): + resolution = json.loads(resolution) + if isinstance(resolution, int): + H = W = resolution + elif isinstance(resolution, list): + H, W = resolution + return H, W + +@register("image2mvimage_trainer") +class Image2MVImageTrainer(BasicTrainer): + """ + Trainer for simple image to multiview images. + """ + @dataclass + class TrainerConfig(BasicTrainer.TrainerConfig): + trainer_name: str = "image2mvimage" + condition_image_column_name: str = "conditioning_image" + image_column_name: str = "image" + condition_dropout: float = 0. + condition_image_resolution: str = "512" + validation_images: Optional[List[str]] = None + noise_offset: float = 0.1 + max_loss_drop: float = 0. + snr_gamma: float = 5.0 + log_distribution: bool = False + latents_offset: Optional[List[float]] = None + input_perturbation: float = 0. + noisy_condition_input: bool = False # whether to add noise for ref unet input + normal_cls_offset: int = 0 + condition_offset: bool = True + zero_snr: bool = False + linear_beta_schedule: bool = False + + cfg: TrainerConfig + + def configure(self) -> None: + return super().configure() + + def init_shared_modules(self, shared_modules: dict) -> dict: + if 'vae' not in shared_modules: + vae = AutoencoderKL.from_pretrained( + self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype + ) + vae.requires_grad_(False) + vae.to(self.accelerator.device, dtype=self.weight_dtype) + shared_modules['vae'] = vae + if 'image_encoder' not in shared_modules: + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + self.cfg.pretrained_model_name_or_path, subfolder="image_encoder" + ) + image_encoder.requires_grad_(False) + image_encoder.to(self.accelerator.device, dtype=self.weight_dtype) + shared_modules['image_encoder'] = image_encoder + if 'feature_extractor' not in shared_modules: + feature_extractor = CLIPImageProcessor.from_pretrained( + self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor" + ) + shared_modules['feature_extractor'] = feature_extractor + return shared_modules + + def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader: + raise NotImplementedError() + + def loss_rescale(self, loss, timesteps=None): + raise NotImplementedError() + + def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor: + raise NotImplementedError() + + def construct_pipeline(self, shared_modules, unet, old_version=False): + MyPipeline = StableDiffusionImage2MVCustomPipeline + pipeline = MyPipeline.from_pretrained( + self.cfg.pretrained_model_name_or_path, + vae=shared_modules['vae'], + image_encoder=shared_modules['image_encoder'], + feature_extractor=shared_modules['feature_extractor'], + unet=unet, + safety_checker=None, + torch_dtype=self.weight_dtype, + latents_offset=self.cfg.latents_offset, + noisy_cond_latents=self.cfg.noisy_condition_input, + condition_offset=self.cfg.condition_offset, + ) + pipeline.set_progress_bar_config(disable=True) + scheduler_dict = {} + if self.cfg.zero_snr: + scheduler_dict.update(rescale_betas_zero_snr=True) + if self.cfg.linear_beta_schedule: + scheduler_dict.update(beta_schedule='linear') + + pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict) + return pipeline + + def get_forward_args(self): + if self.cfg.seed is None: + generator = None + else: + generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed) + + H, W = get_HW(self.cfg.resolution) + H_cond, W_cond = get_HW(self.cfg.condition_image_resolution) + + sub_img_H = H // 2 + num_imgs = H // sub_img_H * W // sub_img_H + + forward_args = dict( + num_images_per_prompt=num_imgs, + num_inference_steps=50, + height=sub_img_H, + width=sub_img_H, + height_cond=H_cond, + width_cond=W_cond, + generator=generator, + ) + if self.cfg.zero_snr: + forward_args.update(guidance_rescale=0.7) + return forward_args + + def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput: + forward_args = self.get_forward_args() + forward_args.update(pipeline_call_kwargs) + return pipeline(**forward_args) + + def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple: + raise NotImplementedError() \ No newline at end of file diff --git a/custum_3d_diffusion/trainings/utils.py b/custum_3d_diffusion/trainings/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58637e155304f2914d4efca97b792976ac9c86f4 --- /dev/null +++ b/custum_3d_diffusion/trainings/utils.py @@ -0,0 +1,25 @@ +from omegaconf import DictConfig, OmegaConf + + +def parse_structured(fields, cfg) -> DictConfig: + scfg = OmegaConf.structured(fields(**cfg)) + return scfg + + +def load_config(fields, config, extras=None): + if extras is not None: + print("Warning! extra parameter in cli is not verified, may cause erros.") + if isinstance(config, str): + cfg = OmegaConf.load(config) + elif isinstance(config, dict): + cfg = OmegaConf.create(config) + elif isinstance(config, DictConfig): + cfg = config + else: + raise NotImplementedError(f"Unsupported config type {type(config)}") + if extras is not None: + cli_conf = OmegaConf.from_cli(extras) + cfg = OmegaConf.merge(cfg, cli_conf) + OmegaConf.resolve(cfg) + assert isinstance(cfg, DictConfig) + return parse_structured(fields, cfg) \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..ca99dd598120d2ab0e7122e39c0aaf1a4acb6ded --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,54 @@ +# get the development image from nvidia cuda 12.1 +FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 + +LABEL name="unique3d" maintainer="unique3d" + +# create workspace folder and set it as working directory +RUN mkdir -p /workspace +WORKDIR /workspace + +# update package lists and install git, wget, vim, libegl1-mesa-dev, and libglib2.0-0 +RUN apt-get update && apt-get install -y build-essential git wget vim libegl1-mesa-dev libglib2.0-0 unzip git-lfs + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends pkg-config libglvnd0 libgl1 libglx0 libegl1 libgles2 libglvnd-dev libgl1-mesa-dev libegl1-mesa-dev libgles2-mesa-dev cmake curl mesa-utils-extra +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 +ENV LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH +ENV PYOPENGL_PLATFORM=egl + +# install conda +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x Miniconda3-latest-Linux-x86_64.sh && \ + ./Miniconda3-latest-Linux-x86_64.sh -b -p /workspace/miniconda3 && \ + rm Miniconda3-latest-Linux-x86_64.sh + +# update PATH environment variable +ENV PATH="/workspace/miniconda3/bin:${PATH}" + +# initialize conda +RUN conda init bash + +# create and activate conda environment +RUN conda create -n unique3d python=3.10 && echo "source activate unique3d" > ~/.bashrc +ENV PATH /workspace/miniconda3/envs/unique3d/bin:$PATH + +RUN conda install Ninja +RUN conda install cuda -c nvidia/label/cuda-12.1.0 -y + +RUN pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 xformers triton --index-url https://download.pytorch.org/whl/cu121 +RUN pip install diffusers==0.27.2 + +RUN git clone --depth 1 https://huggingface.co/spaces/Wuvin/Unique3D + +# change the working directory to the repository + +WORKDIR /workspace/Unique3D +# other dependencies +RUN pip install -r requirements.txt + +RUN pip install nvidia-pyindex + +RUN pip install --upgrade nvidia-tensorrt + +RUN pip install spaces + diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d534b30d22c2e0988042486b28c31e06645981e2 --- /dev/null +++ b/docker/README.md @@ -0,0 +1,35 @@ +# Docker setup + +This docker setup is tested on Windows 10. + +make sure you are under this directory yourworkspace/Unique3D/docker + +Build docker image: + +``` +docker build -t unique3d -f Dockerfile . +``` + +Run docker image at the first time: + +``` +docker run -it --name unique3d -p 7860:7860 --gpus all unique3d python app.py +``` + +After first time: +``` +docker start unique3d +docker exec unique3d python app.py +``` + +Stop the container: +``` +docker stop unique3d +``` + +You can find the demo link showing in terminal, such as `https://94fc1ba77a08526e17.gradio.live/` or something similar else (it will be changed after each time to restart the container) to use the demo. + +Some notes: +1. this docker build is using https://huggingface.co/spaces/Wuvin/Unique3D rather than this repo to clone the source. +2. the total built time might take more than one hour. +3. the total size of the built image will be more than 70GB. \ No newline at end of file diff --git a/gradio_app.py b/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3fa6c529e330d59f457f5791633f2bd8a91139 --- /dev/null +++ b/gradio_app.py @@ -0,0 +1,41 @@ +if __name__ == "__main__": + import os + import sys + sys.path.append(os.curdir) + import torch + torch.set_float32_matmul_precision('medium') + torch.backends.cuda.matmul.allow_tf32 = True + torch.set_grad_enabled(False) + +import fire +import gradio as gr +from app.gradio_3dgen import create_ui as create_3d_ui +from app.all_models import model_zoo + + +_TITLE = '''Unique3D: High-Quality and Efficient 3D Mesh Generation from a Single Image''' +_DESCRIPTION = ''' +[Project page](https://wukailu.github.io/Unique3D/) + +* High-fidelity and diverse textured meshes generated by Unique3D from single-view images. + +* The demo is still under construction, and more features are expected to be implemented soon. +''' + +def launch(): + model_zoo.init_models() + + with gr.Blocks( + title=_TITLE, + theme=gr.themes.Monochrome(), + ) as demo: + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown('# ' + _TITLE) + gr.Markdown(_DESCRIPTION) + create_3d_ui("wkl") + + demo.queue().launch(share=True) + +if __name__ == '__main__': + fire.Fire(launch) diff --git a/install_windows_win_py311_cu121.bat b/install_windows_win_py311_cu121.bat new file mode 100644 index 0000000000000000000000000000000000000000..65cd649e05a32973020371b86c51dd8d0664ed57 --- /dev/null +++ b/install_windows_win_py311_cu121.bat @@ -0,0 +1,34 @@ +@echo off + +set "triton_whl=%~dp0\triton-2.1.0-cp311-cp311-win_amd64.whl" + +echo Starting to install Unique3D... + +echo Installing torch, xformers, etc + +pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121 + +echo Installing triton + +pip install "%triton_whl%" + +pip install Ninja + +pip install diffusers==0.27.2 + +pip install grpcio werkzeug tensorboard-data-server + +pip install -r requirements-win-py311-cu121.txt + +echo Removing default onnxruntime and onnxruntime-gpu + +pip uninstall onnxruntime +pip uninstall onnxruntime-gpu + +echo Installing correct version onnxruntime-gpu for cuda 12.1 + +pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ + +echo Install Finished. Press any key to continue... + +pause \ No newline at end of file diff --git a/mesh_reconstruction/func.py b/mesh_reconstruction/func.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ba322b996bd3e78f0f56cc0b1367b9e780dbff --- /dev/null +++ b/mesh_reconstruction/func.py @@ -0,0 +1,133 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import torch +import numpy as np +import trimesh +from typing import Tuple + +def to_numpy(*args): + def convert(a): + if isinstance(a,torch.Tensor): + return a.detach().cpu().numpy() + assert a is None or isinstance(a,np.ndarray) + return a + + return convert(args[0]) if len(args)==1 else tuple(convert(a) for a in args) + +def laplacian( + num_verts:int, + edges: torch.Tensor #E,2 + ) -> torch.Tensor: #sparse V,V + """create sparse Laplacian matrix""" + V = num_verts + E = edges.shape[0] + + #adjacency matrix, + idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T # (2, 2*E) + ones = torch.ones(2*E, dtype=torch.float32, device=edges.device) + A = torch.sparse.FloatTensor(idx, ones, (V, V)) + + #degree matrix + deg = torch.sparse.sum(A, dim=1).to_dense() + idx = torch.arange(V, device=edges.device) + idx = torch.stack([idx, idx], dim=0) + D = torch.sparse.FloatTensor(idx, deg, (V, V)) + + return D - A + +def _translation(x, y, z, device): + return torch.tensor([[1., 0, 0, x], + [0, 1, 0, y], + [0, 0, 1, z], + [0, 0, 0, 1]],device=device) #4,4 + +def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True): + """ + see https://blog.csdn.net/wodownload2/article/details/85069240/ + """ + if l is None: + l = -r + if t is None: + t = r + if b is None: + b = -t + p = torch.zeros([4,4],device=device) + p[0,0] = 2*n/(r-l) + p[0,2] = (r+l)/(r-l) + p[1,1] = 2*n/(t-b) * (-1 if flip_y else 1) + p[1,2] = (t+b)/(t-b) + p[2,2] = -(f+n)/(f-n) + p[2,3] = -(2*f*n)/(f-n) + p[3,2] = -1 + return p #4,4 + +def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True): + if l is None: + l = -r + if t is None: + t = r + if b is None: + b = -t + o = torch.zeros([4,4],device=device) + o[0,0] = 2/(r-l) + o[0,3] = -(r+l)/(r-l) + o[1,1] = 2/(t-b) * (-1 if flip_y else 1) + o[1,3] = -(t+b)/(t-b) + o[2,2] = -2/(f-n) + o[2,3] = -(f+n)/(f-n) + o[3,3] = 1 + return o #4,4 + +def make_star_cameras(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'): + if r is None: + r = 1/distance + A = az_count + P = pol_count + C = A * P + + phi = torch.arange(0,A) * (2*torch.pi/A) + phi_rot = torch.eye(3,device=device)[None,None].expand(A,1,3,3).clone() + phi_rot[:,0,2,2] = phi.cos() + phi_rot[:,0,2,0] = -phi.sin() + phi_rot[:,0,0,2] = phi.sin() + phi_rot[:,0,0,0] = phi.cos() + + theta = torch.arange(1,P+1) * (torch.pi/(P+1)) - torch.pi/2 + theta_rot = torch.eye(3,device=device)[None,None].expand(1,P,3,3).clone() + theta_rot[0,:,1,1] = theta.cos() + theta_rot[0,:,1,2] = -theta.sin() + theta_rot[0,:,2,1] = theta.sin() + theta_rot[0,:,2,2] = theta.cos() + + mv = torch.empty((C,4,4), device=device) + mv[:] = torch.eye(4, device=device) + mv[:,:3,:3] = (theta_rot @ phi_rot).reshape(C,3,3) + mv = _translation(0, 0, -distance, device) @ mv + + return mv, _projection(r,device) + +def make_star_cameras_orthographic(az_count,pol_count,distance:float=10.,r=None,image_size=[512,512],device='cuda'): + mv, _ = make_star_cameras(az_count,pol_count,distance,r,image_size,device) + if r is None: + r = 1 + return mv, _orthographic(r,device) + +def make_sphere(level:int=2,radius=1.,device='cuda') -> Tuple[torch.Tensor,torch.Tensor]: + sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None) + vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius + faces = torch.tensor(sphere.faces, device=device, dtype=torch.long) + return vertices,faces + +from pytorch3d.renderer import ( + FoVOrthographicCameras, + look_at_view_transform, +) + +def get_camera(R, T, focal_length=1 / (2**0.5)): + focal_length = 1 / focal_length + camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length) + return camera + +def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1): + R, T = look_at_view_transform(dist, 0, azim_list) + focal_length = 1 / focal + return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device) diff --git a/mesh_reconstruction/opt.py b/mesh_reconstruction/opt.py new file mode 100644 index 0000000000000000000000000000000000000000..6a51ef20b2cae93f14b1781ba4b91b48ed1ae8d0 --- /dev/null +++ b/mesh_reconstruction/opt.py @@ -0,0 +1,190 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import time +import torch +import torch_scatter +from typing import Tuple +from mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges + +@torch.no_grad() +def remesh( + vertices_etc:torch.Tensor, #V,D + faces:torch.Tensor, #F,3 long + min_edgelen:torch.Tensor, #V + max_edgelen:torch.Tensor, #V + flip:bool, + max_vertices=1e6 + ): + + # dummies + vertices_etc,faces = prepend_dummies(vertices_etc,faces) + vertices = vertices_etc[:,:3] #V,3 + nan_tensor = torch.tensor([torch.nan],device=min_edgelen.device) + min_edgelen = torch.concat((nan_tensor,min_edgelen)) + max_edgelen = torch.concat((nan_tensor,max_edgelen)) + + # collapse + edges,face_to_edge = calc_edges(faces) #E,2 F,3 + edge_length = calc_edge_length(vertices,edges) #E + face_normals = calc_face_normals(vertices,faces,normalize=False) #F,3 + vertex_normals = calc_vertex_normals(vertices,faces,face_normals) #V,3 + face_collapse = calc_face_collapses(vertices,faces,edges,face_to_edge,edge_length,face_normals,vertex_normals,min_edgelen,area_ratio=0.5) + shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0) #e[0,1] 0...ok, 1...edgelen=0 + priority = face_collapse.float() + shortness + vertices_etc,faces = collapse_edges(vertices_etc,faces,edges,priority) + + # split + if vertices.shape[0] max_edgelen[edges].mean(dim=-1) + vertices_etc,faces = split_edges(vertices_etc,faces,edges,face_to_edge,splits,pack_faces=False) + + vertices_etc,faces = pack(vertices_etc,faces) + vertices = vertices_etc[:,:3] + + if flip: + edges,_,edge_to_face = calc_edges(faces,with_edge_to_face=True) #E,2 F,3 + flip_edges(vertices,faces,edges,edge_to_face,with_border=False) + + return remove_dummies(vertices_etc,faces) + +def lerp_unbiased(a:torch.Tensor,b:torch.Tensor,weight:float,step:int): + """lerp with adam's bias correction""" + c_prev = 1-weight**(step-1) + c = 1-weight**step + a_weight = weight*c_prev/c + b_weight = (1-weight)/c + a.mul_(a_weight).add_(b, alpha=b_weight) + + +class MeshOptimizer: + """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh().""" + + def __init__(self, + vertices:torch.Tensor, #V,3 + faces:torch.Tensor, #F,3 + lr=0.3, #learning rate + betas=(0.8,0.8,0), #betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu + gammas=(0,0,0), #optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing) + nu_ref=0.3, #reference velocity for edge length controller + edge_len_lims=(.01,.15), #smallest and largest allowed reference edge length + edge_len_tol=.5, #edge length tolerance for split and collapse + gain=.2, #gain value for edge length controller + laplacian_weight=.02, #for laplacian smoothing/regularization + ramp=1, #learning rate ramp, actual ramp width is ramp/(1-betas[0]) + grad_lim=10., #gradients are clipped to m1.abs()*grad_lim + remesh_interval=1, #larger intervals are faster but with worse mesh quality + local_edgelen=True, #set to False to use a global scalar reference edge length instead + ): + self._vertices = vertices + self._faces = faces + self._lr = lr + self._betas = betas + self._gammas = gammas + self._nu_ref = nu_ref + self._edge_len_lims = edge_len_lims + self._edge_len_tol = edge_len_tol + self._gain = gain + self._laplacian_weight = laplacian_weight + self._ramp = ramp + self._grad_lim = grad_lim + self._remesh_interval = remesh_interval + self._local_edgelen = local_edgelen + self._step = 0 + + V = self._vertices.shape[0] + # prepare continuous tensor for all vertex-based data + self._vertices_etc = torch.zeros([V,9],device=vertices.device) + self._split_vertices_etc() + self.vertices.copy_(vertices) #initialize vertices + self._vertices.requires_grad_() + self._ref_len.fill_(edge_len_lims[1]) + + @property + def vertices(self): + return self._vertices + + @property + def faces(self): + return self._faces + + def _split_vertices_etc(self): + self._vertices = self._vertices_etc[:,:3] + self._m2 = self._vertices_etc[:,3] + self._nu = self._vertices_etc[:,4] + self._m1 = self._vertices_etc[:,5:8] + self._ref_len = self._vertices_etc[:,8] + + with_gammas = any(g!=0 for g in self._gammas) + self._smooth = self._vertices_etc[:,:8] if with_gammas else self._vertices_etc[:,:3] + + def zero_grad(self): + self._vertices.grad = None + + @torch.no_grad() + def step(self): + + eps = 1e-8 + + self._step += 1 + + # spatial smoothing + edges,_ = calc_edges(self._faces) #E,2 + E = edges.shape[0] + edge_smooth = self._smooth[edges] #E,2,S + neighbor_smooth = torch.zeros_like(self._smooth) #V,S + torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2,-1),index=edges.reshape(E*2,1),dim=0,out=neighbor_smooth) + + #apply optional smoothing of m1,m2,nu + if self._gammas[0]: + self._m1.lerp_(neighbor_smooth[:,5:8],self._gammas[0]) + if self._gammas[1]: + self._m2.lerp_(neighbor_smooth[:,3],self._gammas[1]) + if self._gammas[2]: + self._nu.lerp_(neighbor_smooth[:,4],self._gammas[2]) + + #add laplace smoothing to gradients + laplace = self._vertices - neighbor_smooth[:,:3] + grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:,None], value=self._laplacian_weight) + + #gradient clipping + if self._step>1: + grad_lim = self._m1.abs().mul_(self._grad_lim) + grad.clamp_(min=-grad_lim,max=grad_lim) + + # moment updates + lerp_unbiased(self._m1, grad, self._betas[0], self._step) + lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step) + + velocity = self._m1 / self._m2[:,None].sqrt().add_(eps) #V,3 + speed = velocity.norm(dim=-1) #V + + if self._betas[2]: + lerp_unbiased(self._nu,speed,self._betas[2],self._step) #V + else: + self._nu.copy_(speed) #V + + # update vertices + ramped_lr = self._lr * min(1,self._step * (1-self._betas[0]) / self._ramp) + self._vertices.add_(velocity * self._ref_len[:,None], alpha=-ramped_lr) + + # update target edge length + if self._step % self._remesh_interval == 0: + if self._local_edgelen: + len_change = (1 + (self._nu - self._nu_ref) * self._gain) + else: + len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain) + self._ref_len *= len_change + self._ref_len.clamp_(*self._edge_len_lims) + + def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]: + min_edge_len = self._ref_len * (1 - self._edge_len_tol) + max_edge_len = self._ref_len * (1 + self._edge_len_tol) + + self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e6) + + self._split_vertices_etc() + self._vertices.requires_grad_() + + return self._vertices, self._faces diff --git a/mesh_reconstruction/recon.py b/mesh_reconstruction/recon.py new file mode 100644 index 0000000000000000000000000000000000000000..a47ae92068a717334995e96e5d20bb513a9090e5 --- /dev/null +++ b/mesh_reconstruction/recon.py @@ -0,0 +1,59 @@ +from tqdm import tqdm +from PIL import Image +import numpy as np +import torch +from typing import List +from mesh_reconstruction.remesh import calc_vertex_normals +from mesh_reconstruction.opt import MeshOptimizer +from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d +from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer +from scripts.utils import to_py3d_mesh, init_target + +def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1): + vertices, faces = vertices.to("cuda"), faces.to("cuda") + assert len(pils) == 4 + mv,proj = make_star_cameras_orthographic(4, 1) + renderer = NormalsRenderer(mv,proj,list(pils[0].size)) + # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0) + # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda") + + target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s + # 1. no rotate + target_images = target_images[[0, 3, 2, 1]] + + # 2. init from coarse mesh + opt = MeshOptimizer(vertices,faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len)) + + vertices = opt.vertices + + mask = target_images[..., -1] < 0.5 + + for i in tqdm(range(steps)): + opt.zero_grad() + opt._lr *= decay + normals = calc_vertex_normals(vertices,faces) + images = renderer.render(vertices,normals,faces) + + loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean() + + t_mask = images[..., -1] > 0.5 + loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean() + loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean() + + loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight + + # out of box + loss_oob = (vertices.abs() > 0.99).float().mean() * 10 + loss = loss + loss_oob + + loss.backward() + opt.step() + + vertices,faces = opt.remesh(poisson=False) + + vertices, faces = vertices.detach(), faces.detach() + + if return_mesh: + return to_py3d_mesh(vertices, faces) + else: + return vertices, faces diff --git a/mesh_reconstruction/refine.py b/mesh_reconstruction/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f87df43cee2b5bdedf2ccc846f2deeaab1e1ca --- /dev/null +++ b/mesh_reconstruction/refine.py @@ -0,0 +1,79 @@ +from tqdm import tqdm +from PIL import Image +import torch +from typing import List +from mesh_reconstruction.remesh import calc_vertex_normals +from mesh_reconstruction.opt import MeshOptimizer +from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d +from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer +from scripts.project_mesh import multiview_color_projection, get_cameras_list +from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target + +def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True): + if process_inputs: + vertices = vertices * 2 / 1.35 + vertices[..., [0, 2]] = - vertices[..., [0, 2]] + + poission_steps = [] + + assert len(pils) == 4 + mv,proj = make_star_cameras_orthographic(4, 1) + renderer = NormalsRenderer(mv,proj,list(pils[0].size)) + # cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0) + # renderer = Pytorch3DNormalsRenderer(cameras, list(pils[0].size), device="cuda") + + target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s + # 1. no rotate + target_images = target_images[[0, 3, 2, 1]] + + # 2. init from coarse mesh + opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02) + + vertices = opt.vertices + alpha_init = None + + mask = target_images[..., -1] < 0.5 + + for i in tqdm(range(steps)): + opt.zero_grad() + opt._lr *= decay + normals = calc_vertex_normals(vertices,faces) + images = renderer.render(vertices,normals,faces) + if alpha_init is None: + alpha_init = images.detach() + + if i < update_warmup or i % update_normal_interval == 0: + with torch.no_grad(): + py3d_mesh = to_py3d_mesh(vertices, faces, normals) + cameras = get_cameras_list(azim_list = [0, 90, 180, 270], device=vertices.device, focal=1.) + _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=[2.0, 0.8, 1.0, 0.8], confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear')) + target_normal = target_normal * 2 - 1 + target_normal = torch.nn.functional.normalize(target_normal, dim=-1) + debug_images = renderer.render(vertices,target_normal,faces) + + d_mask = images[..., -1] > 0.5 + loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean() + + loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean() + + loss = loss_debug_l2 + loss_alpha_target_mask_l2 + + # out of box + loss_oob = (vertices.abs() > 0.99).float().mean() * 10 + loss = loss + loss_oob + + loss.backward() + opt.step() + + vertices,faces = opt.remesh(poisson=(i in poission_steps)) + + vertices, faces = vertices.detach(), faces.detach() + + if process_outputs: + vertices = vertices / 2 * 1.35 + vertices[..., [0, 2]] = - vertices[..., [0, 2]] + + if return_mesh: + return to_py3d_mesh(vertices, faces) + else: + return vertices, faces diff --git a/mesh_reconstruction/remesh.py b/mesh_reconstruction/remesh.py new file mode 100644 index 0000000000000000000000000000000000000000..2e2faa83cf5afc98b571254c6cee3893a2396a30 --- /dev/null +++ b/mesh_reconstruction/remesh.py @@ -0,0 +1,361 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import torch +import torch.nn.functional as tfunc +import torch_scatter +from typing import Tuple + +def prepend_dummies( + vertices:torch.Tensor, #V,D + faces:torch.Tensor, #F,3 long + )->Tuple[torch.Tensor,torch.Tensor]: + """prepend dummy elements to vertices and faces to enable "masked" scatter operations""" + V,D = vertices.shape + vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0) + faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0) + return vertices,faces + +def remove_dummies( + vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced + faces:torch.Tensor, #F,3 long - first face all zeros + )->Tuple[torch.Tensor,torch.Tensor]: + """remove dummy elements added with prepend_dummies()""" + return vertices[1:],faces[1:]-1 + + +def calc_edges( + faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros + with_edge_to_face: bool = False + ) -> Tuple[torch.Tensor, ...]: + """ + returns Tuple of + - edges E,2 long, 0 for unused, lower vertex index first + - face_to_edge F,3 long + - (optional) edge_to_face shape=E,[left,right],[face,side] + + o-<-----e1 e0,e1...edge, e0-o + """ + + F = faces.shape[0] + + # make full edges, lower vertex index first + face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2 + full_edges = face_edges.reshape(F*3,2) + sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2 + + # make unique edges + edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3) + E = edges.shape[0] + face_to_edge = full_to_unique.reshape(F,3) #F,3 + + if not with_edge_to_face: + return edges, face_to_edge + + is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3 + edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2 + scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2 + edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2 + edge_to_face[0] = 0 + return edges, face_to_edge, edge_to_face + +def calc_edge_length( + vertices:torch.Tensor, #V,3 first may be dummy + edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused + )->torch.Tensor: #E + + full_vertices = vertices[edges] #E,2,3 + a,b = full_vertices.unbind(dim=1) #E,3 + return torch.norm(a-b,p=2,dim=-1) + +def calc_face_normals( + vertices:torch.Tensor, #V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + normalize:bool=False, + )->torch.Tensor: #F,3 + """ + n + | + c0 corners ordered counterclockwise when + / \ looking onto surface (in neg normal direction) + c1---c2 + """ + full_vertices = vertices[faces] #F,C=3,3 + v0,v1,v2 = full_vertices.unbind(dim=1) #F,3 + face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3 + if normalize: + face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1) + return face_normals #F,3 + +def calc_vertex_normals( + vertices:torch.Tensor, #V,3 first vertex may be unreferenced + faces:torch.Tensor, #F,3 long, first face may be all zero + face_normals:torch.Tensor=None, #F,3, not normalized + )->torch.Tensor: #F,3 + + F = faces.shape[0] + + if face_normals is None: + face_normals = calc_face_normals(vertices,faces) + + vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3 + vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3)) + vertex_normals = vertex_normals.sum(dim=1) #V,3 + return tfunc.normalize(vertex_normals, eps=1e-6, dim=1) + +def calc_face_ref_normals( + faces:torch.Tensor, #F,3 long, 0 for unused + vertex_normals:torch.Tensor, #V,3 first unused + normalize:bool=False, + )->torch.Tensor: #F,3 + """calculate reference normals for face flip detection""" + full_normals = vertex_normals[faces] #F,C=3,3 + ref_normals = full_normals.sum(dim=1) #F,3 + if normalize: + ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1) + return ref_normals + +def pack( + vertices:torch.Tensor, #V,3 first unused and nan + faces:torch.Tensor, #F,3 long, 0 for unused + )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused + """removes unused elements in vertices and faces""" + V = vertices.shape[0] + + # remove unused faces + used_faces = faces[:,0]!=0 + used_faces[0] = True + faces = faces[used_faces] #sync + + # remove unused vertices + used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device) + used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add') + used_vertices = used_vertices.any(dim=1) + used_vertices[0] = True + vertices = vertices[used_vertices] #sync + + # update used faces + ind = torch.zeros(V,dtype=torch.long,device=vertices.device) + V1 = used_vertices.sum() + ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync + faces = ind[faces] + + return vertices,faces + +def split_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + face_to_edge:torch.Tensor, #F,3 long 0 for unused + splits, #E bool + pack_faces:bool=True, + )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) + + # c2 c2 c...corners = faces + # . . . . s...side_vert, 0 means no split + # . . .N2 . S...shrunk_face + # . . . . Ni...new_faces + # s2 s1 s2|c2...s1|c1 + # . . . . . + # . . . S . . + # . . . . N1 . + # c0...(s0=0)....c1 s0|c0...........c1 + # + # pseudo-code: + # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2] + # split = side_vert!=0 example:[False,True,True] + # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0] + # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0] + # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1] + + V = vertices.shape[0] + F = faces.shape[0] + S = splits.sum().item() #sync + + if S==0: + return vertices,faces + + edge_vert = torch.zeros_like(splits, dtype=torch.long) #E + edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync + side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split + split_edges = edges[splits] #S sync + + #vertices + split_vertices = vertices[split_edges].mean(dim=1) #S,3 + vertices = torch.concat((vertices,split_vertices),dim=0) + + #faces + side_split = side_vert!=0 #F,3 + shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split + new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3 + faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3 + if pack_faces: + mask = faces[:,0]!=0 + mask[0] = True + faces = faces[mask] #F',3 sync + + return vertices,faces + +def collapse_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + priorities:torch.Tensor, #E float + stable:bool=False, #only for unit testing + )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces) + + V = vertices.shape[0] + + # check spacing + _,order = priorities.sort(stable=stable) #E + rank = torch.zeros_like(order) + rank[order] = torch.arange(0,len(rank),device=rank.device) + vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V + edge_rank = rank #E + for i in range(3): + torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank) + edge_rank,_ = vert_rank[edges].max(dim=-1) #E + candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2 + + # check connectivity + vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V + vert_connections[candidates[:,0]] = 1 #start + edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start + vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start + vert_connections[candidates] = 0 #clear start and end + edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start + vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start + collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end + + # mean vertices + vertices[collapses[:,0]] = vertices[collapses].mean(dim=1) + + # update faces + dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V + dest[collapses[:,1]] = dest[collapses[:,0]] + faces = dest[faces] #F,3 + c0,c1,c2 = faces.unbind(dim=-1) + collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2) + faces[collapsed] = 0 + + return vertices,faces + +def calc_face_collapses( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, 0 for unused + edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first + face_to_edge:torch.Tensor, #F,3 long 0 for unused + edge_length:torch.Tensor, #E + face_normals:torch.Tensor, #F,3 + vertex_normals:torch.Tensor, #V,3 first unused + min_edge_length:torch.Tensor=None, #V + area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio + shortest_probability = 0.8 + )->torch.Tensor: #E edges to collapse + + E = edges.shape[0] + F = faces.shape[0] + + # face flips + ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3 + face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F + + # small faces + if min_edge_length is not None: + min_face_length = min_edge_length[faces].mean(dim=-1) #F + min_area = min_face_length**2 * area_ratio #F + face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F + face_collapses[0] = False + + # faces to edges + face_length = edge_length[face_to_edge] #F,3 + + if shortest_probability<1: + #select shortest edge with shortest_probability chance + randlim = round(2/(1-shortest_probability)) + rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face + sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3 + local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None]) + else: + local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face + + edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index + edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device) + edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long()) + + return edge_collapses.bool() + +def flip_edges( + vertices:torch.Tensor, #V,3 first unused + faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused + edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first + edge_to_face:torch.Tensor, #E,[left,right],[face,side] + with_border:bool=True, #handle border edges (D=4 instead of D=6) + with_normal_check:bool=True, #check face normal flips + stable:bool=False, #only for unit testing + ): + V = vertices.shape[0] + E = edges.shape[0] + device=vertices.device + vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long + vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add') + neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner + neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2 + edge_is_inside = neighbors.all(dim=-1) #E + + if with_border: + # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices + # need to use float for masks in order to use scatter(reduce='multiply') + vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float + src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float + vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply') + vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long + vertex_degree -= 2 * vertex_is_inside #V long + + neighbor_degrees = vertex_degree[neighbors] #E,LR=2 + edge_degrees = vertex_degree[edges] #E,2 + # + # loss = Sum_over_affected_vertices((new_degree-6)**2) + # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2) + # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2) + # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree)) + # + loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E + candidates = torch.logical_and(loss_change<0, edge_is_inside) #E + loss_change = loss_change[candidates] #E' + if loss_change.shape[0]==0: + return + + edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4 + _,order = loss_change.sort(descending=True, stable=stable) #E' + rank = torch.zeros_like(order) + rank[order] = torch.arange(0,len(rank),device=rank.device) + vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4 + torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank) + vertex_rank,_ = vertex_rank.max(dim=-1) #V + neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E' + flip = rank==neighborhood_rank #E' + + if with_normal_check: + # cl-<-----e1 e0,e1...edge, e0-cr + v = vertices[edges_neighbors] #E",4,3 + v = v - v[:,0:1] #make relative to e0 + e1 = v[:,1] + cl = v[:,2] + cr = v[:,3] + n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors + flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face + flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face + + flip_edges_neighbors = edges_neighbors[flip] #E",4 + flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2 + flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3 + faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3)) diff --git a/mesh_reconstruction/render.py b/mesh_reconstruction/render.py new file mode 100644 index 0000000000000000000000000000000000000000..23707d8260f5ecc0bf7aa274e4196f8860f48b74 --- /dev/null +++ b/mesh_reconstruction/render.py @@ -0,0 +1,163 @@ +# modified from https://github.com/Profactor/continuous-remeshing +import nvdiffrast.torch as dr +import torch +from typing import Tuple + +def _warmup(glctx, device=None): + device = 'cuda' if device is None else device + #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59 + def tensor(*args, **kwargs): + return torch.tensor(*args, device=device, **kwargs) + pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32) + tri = tensor([[0, 1, 2]], dtype=torch.int32) + dr.rasterize(glctx, pos, tri, resolution=[256, 256]) + +glctx = dr.RasterizeGLContext(output_db=False, device="cuda") + +class NormalsRenderer: + + _glctx:dr.RasterizeGLContext = None + + def __init__( + self, + mv: torch.Tensor, #C,4,4 + proj: torch.Tensor, #C,4,4 + image_size: Tuple[int,int], + mvp = None, + device=None, + ): + if mvp is None: + self._mvp = proj @ mv #C,4,4 + else: + self._mvp = mvp + self._image_size = image_size + self._glctx = glctx + _warmup(self._glctx, device) + + def render(self, + vertices: torch.Tensor, #V,3 float + normals: torch.Tensor, #V,3 float in [-1, 1] + faces: torch.Tensor, #F,3 long + ) ->torch.Tensor: #C,H,W,4 + + V = vertices.shape[0] + faces = faces.type(torch.int32) + vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4 + vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4 + rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4 + vert_col = (normals+1)/2 #V,3 + col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3 + alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1 + col = torch.concat((col,alpha),dim=-1) #C,H,W,4 + col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4 + return col #C,H,W,4 + + + +from pytorch3d.structures import Meshes +from pytorch3d.renderer.mesh.shader import ShaderBase +from pytorch3d.renderer import ( + RasterizationSettings, + MeshRendererWithFragments, + TexturesVertex, + MeshRasterizer, + BlendParams, + FoVOrthographicCameras, + look_at_view_transform, + hard_rgb_blend, +) + +class VertexColorShader(ShaderBase): + def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + blend_params = kwargs.get("blend_params", self.blend_params) + texels = meshes.sample_textures(fragments) + return hard_rgb_blend(texels, fragments, blend_params) + +def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"): + if len(mesh) != len(cameras): + if len(cameras) % len(mesh) == 0: + mesh = mesh.extend(len(cameras)) + else: + raise NotImplementedError() + + # render requires everything in float16 or float32 + input_dtype = dtype + blend_params = BlendParams(1e-4, 1e-4, bkgd) + + # Define the settings for rasterization and shading + raster_settings = RasterizationSettings( + image_size=(H, W), + blur_radius=blur_radius, + faces_per_pixel=faces_per_pixel, + clip_barycentric_coords=True, + bin_size=None, + max_faces_per_bin=None, + ) + + # Create a renderer by composing a rasterizer and a shader + # We simply render vertex colors through the custom VertexColorShader (no lighting, materials are used) + renderer = MeshRendererWithFragments( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=raster_settings + ), + shader=VertexColorShader( + device=device, + cameras=cameras, + blend_params=blend_params + ) + ) + + # render RGB and depth, get mask + with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type): + images, _ = renderer(mesh) + return images # BHW4 + +class Pytorch3DNormalsRenderer: # 100 times slower!!! + def __init__(self, cameras, image_size, device): + self.cameras = cameras.to(device) + self._image_size = image_size + self.device = device + + def render(self, + vertices: torch.Tensor, #V,3 float + normals: torch.Tensor, #V,3 float in [-1, 1] + faces: torch.Tensor, #F,3 long + ) ->torch.Tensor: #C,H,W,4 + mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device) + return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device) + +def save_tensor_to_img(tensor, save_dir): + from PIL import Image + import numpy as np + for idx, img in enumerate(tensor): + img = img[..., :3].cpu().numpy() + img = (img * 255).astype(np.uint8) + img = Image.fromarray(img) + img.save(save_dir + f"{idx}.png") + +if __name__ == "__main__": + import sys + import os + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d + cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0) + mv,proj = make_star_cameras_orthographic(4, 1) + resolution = 1024 + renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda") + renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda") + vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32) + normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32) + faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long) + + import time + t0 = time.time() + r1 = renderer1.render(vertices, normals, faces) + print("time r1:", time.time() - t0) + + t0 = time.time() + r2 = renderer2.render(vertices, normals, faces) + print("time r2:", time.time() - t0) + + for i in range(4): + print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean()) \ No newline at end of file diff --git a/requirements-detail.txt b/requirements-detail.txt new file mode 100644 index 0000000000000000000000000000000000000000..cd0bc9a8425cadb2ad81cd178778484f59fcefda --- /dev/null +++ b/requirements-detail.txt @@ -0,0 +1,27 @@ +accelerate==0.29.2 +datasets==2.18.0 +diffusers==0.27.2 +fire==0.6.0 +gradio==4.32.0 +jaxtyping==0.2.29 +numba==0.59.1 +numpy==1.26.4 +nvdiffrast==0.3.1 +omegaconf==2.3.0 +onnxruntime_gpu==1.17.0 +opencv_python==4.9.0.80 +opencv_python_headless==4.9.0.80 +ort_nightly_gpu==1.17.0.dev20240118002 +peft==0.10.0 +Pillow==10.3.0 +pygltflib==1.16.2 +pymeshlab==2023.12.post1 +pytorch3d==0.7.5 +rembg==2.0.56 +torch==2.1.0+cu121 +torch_scatter==2.1.2 +tqdm==4.64.1 +transformers==4.39.3 +trimesh==4.3.0 +typeguard==2.13.3 +wandb==0.16.6 diff --git a/requirements-win-py311-cu121.txt b/requirements-win-py311-cu121.txt new file mode 100644 index 0000000000000000000000000000000000000000..79628f848571caaae2e962c43fb3cb6b11e3978a --- /dev/null +++ b/requirements-win-py311-cu121.txt @@ -0,0 +1,24 @@ +accelerate +datasets +fire +gradio +jaxtyping +numba +numpy +git+https://github.com/NVlabs/nvdiffrast.git +omegaconf>=2.3.0 +opencv_python +opencv_python_headless +ort_nightly_gpu +peft +Pillow +pygltflib +pymeshlab>=2023.12 +git+https://github.com/facebookresearch/pytorch3d.git@stable +rembg[gpu] +torch_scatter +tqdm +transformers +trimesh +typeguard +wandb diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1cf5a2e79cee45c1daeb22e0b5ffaf347e01d641 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +accelerate +datasets +diffusers>=0.26.3 +fire +gradio +jaxtyping +numba +numpy +git+https://github.com/NVlabs/nvdiffrast.git +omegaconf>=2.3.0 +onnxruntime_gpu +opencv_python +opencv_python_headless +ort_nightly_gpu +peft +Pillow +pygltflib +pymeshlab>=2023.12 +git+https://github.com/facebookresearch/pytorch3d.git@stable +rembg[gpu] +#torch>=2.0.1 +torch_scatter +tqdm +transformers +trimesh +typeguard +wandb +xformers diff --git a/scripts/all_typing.py b/scripts/all_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..87b19aaefff18dc04a8a7c185cd9086a27e91c62 --- /dev/null +++ b/scripts/all_typing.py @@ -0,0 +1,42 @@ +# code from https://github.com/threestudio-project + +""" +This module contains type annotations for the project, using +1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects +2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors + +Two types of typing checking can be used: +1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) +2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) +""" + +# Basic types +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + NamedTuple, + NewType, + Optional, + Sized, + Tuple, + Type, + TypeVar, + Union, +) + +# Tensor dtype +# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md +from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt + +# Config type +from omegaconf import DictConfig + +# PyTorch Tensor type +from torch import Tensor + +# Runtime type checking decorator +from typeguard import typechecked as typechecker diff --git a/scripts/load_onnx.py b/scripts/load_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..954cd444b988d372314cc1c7983bfc9dca4e998e --- /dev/null +++ b/scripts/load_onnx.py @@ -0,0 +1,48 @@ +import onnxruntime +import torch + +providers = [ + ('TensorrtExecutionProvider', { + 'device_id': 0, + 'trt_max_workspace_size': 8 * 1024 * 1024 * 1024, + 'trt_fp16_enable': True, + 'trt_engine_cache_enable': True, + }), + ('CUDAExecutionProvider', { + 'device_id': 0, + 'arena_extend_strategy': 'kSameAsRequested', + 'gpu_mem_limit': 8 * 1024 * 1024 * 1024, + 'cudnn_conv_algo_search': 'HEURISTIC', + }) +] + +def load_onnx(file_path: str): + assert file_path.endswith(".onnx") + sess_opt = onnxruntime.SessionOptions() + ort_session = onnxruntime.InferenceSession(file_path, sess_opt=sess_opt, providers=providers) + return ort_session + + +def load_onnx_caller(file_path: str, single_output=False): + ort_session = load_onnx(file_path) + def caller(*args): + torch_input = isinstance(args[0], torch.Tensor) + if torch_input: + torch_input_dtype = args[0].dtype + torch_input_device = args[0].device + # check all are torch.Tensor and have same dtype and device + assert all([isinstance(arg, torch.Tensor) for arg in args]), "All inputs should be torch.Tensor, if first input is torch.Tensor" + assert all([arg.dtype == torch_input_dtype for arg in args]), "All inputs should have same dtype, if first input is torch.Tensor" + assert all([arg.device == torch_input_device for arg in args]), "All inputs should have same device, if first input is torch.Tensor" + args = [arg.cpu().float().numpy() for arg in args] + + ort_inputs = {ort_session.get_inputs()[idx].name: args[idx] for idx in range(len(args))} + ort_outs = ort_session.run(None, ort_inputs) + + if torch_input: + ort_outs = [torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device) for ort_out in ort_outs] + + if single_output: + return ort_outs[0] + return ort_outs + return caller diff --git a/scripts/mesh_init.py b/scripts/mesh_init.py new file mode 100644 index 0000000000000000000000000000000000000000..5a70af530f71a0c39bfe7f99ecb28c6103bf84ec --- /dev/null +++ b/scripts/mesh_init.py @@ -0,0 +1,132 @@ +from PIL import Image +import torch +import numpy as np +from pytorch3d.structures import Meshes +from pytorch3d.renderer import TexturesVertex +from scripts.utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh +import pymeshlab + +_MAX_THREAD = 8 + +# rgb and depth to mesh +def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"): + pixel_center = 0.5 if use_pixel_centers else 0 + i, j = np.meshgrid( + np.arange(W, dtype=np.float32) + pixel_center, + np.arange(H, dtype=np.float32) + pixel_center, + indexing='xy' + ) + i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device) + + origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1) # W, H, 3 + directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3 + + return origins, directions + +def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False): + if valid_HWC is None: + valid_HWC = torch.ones_like(pred_HWC).bool() + H, W = rgb_BCHW.shape[-2:] + rgb_BCHW = rgb_BCHW.flip(-2) + pred_HWC = pred_HWC.flip(0) + valid_HWC = valid_HWC.flip(0) + rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device) + verts = rays_o + rays_d * pred_HWC # [H, W, 3] + verts = verts.reshape(-1, 3) # [V, 3] + indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device) + faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1) + # faces1_valid = valid_HWC[:-1, :-1] | valid_HWC[:-1, 1:] | valid_HWC[1:, :-1] + faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1] + faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1) + # faces2_valid = valid_HWC[1:, 1:] | valid_HWC[1:, :-1] | valid_HWC[:-1, 1:] + faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:] + faces = torch.cat([faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3), faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)], dim=0) # (F, 3) + colors = (rgb_BCHW[0].permute((1,2,0)) / 2 + 0.5).reshape(-1, 3) # (V, 3) + if is_back: + verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device) + + used_verts = faces.unique() + old_to_new_mapping = torch.zeros_like(verts[..., 0]).long() + old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device) + new_faces = old_to_new_mapping[faces] + mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]])) + return mesh + +def normalmap_to_depthmap(normal_np): + from scripts.normal_to_height_map import estimate_height_map + height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96) + return height + +def transform_back_normal_to_front(normal_pil): + arr = np.array(normal_pil) # in [0, 255] + arr[..., 0] = 255-arr[..., 0] + arr[..., 2] = 255-arr[..., 2] + return Image.fromarray(arr.astype(np.uint8)) + +def calc_w_over_h(normal_pil): + if isinstance(normal_pil, Image.Image): + arr = np.array(normal_pil) + else: + assert isinstance(normal_pil, np.ndarray) + arr = normal_pil + if arr.shape[-1] == 4: + alpha = arr[..., -1] / 255. + alpha[alpha >= 0.5] = 1 + alpha[alpha < 0.5] = 0 + else: + alpha = ~(arr.min(axis=-1) >= 250) + h_min, w_min = np.min(np.where(alpha), axis=1) + h_max, w_max = np.max(np.where(alpha), axis=1) + return (w_max - w_min) / (h_max - h_min) + +def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0): + if is_back: + normal_pil = transform_back_normal_to_front(normal_pil) + normal_img = np.array(normal_pil) + rgb_img = np.array(rgb_pil) + if normal_img.shape[-1] == 4: + valid_HWC = normal_img[..., [3]] / 255 + elif rgb_img.shape[-1] == 4: + valid_HWC = rgb_img[..., [3]] / 255 + else: + raise ValueError("invalid input, either normal or rgb should have alpha channel") + + real_height_pix = np.max(np.where(valid_HWC>0.5)[0]) - np.min(np.where(valid_HWC>0.5)[0]) + + heights = normalmap_to_depthmap(normal_img) + rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2,0,1))[None] + valid_HWC[valid_HWC < 0.5] = 0 + valid_HWC[valid_HWC >= 0.5] = 1 + valid_HWC = torch.from_numpy(valid_HWC).bool() + if init_type == "std": + # accurate but not stable + pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None] + elif init_type == "thin": + heights = heights - heights.min() + heights = (heights / heights.max() * 0.2) + pred_HWC = torch.from_numpy(heights * scale).float()[..., None] + else: + # stable but not accurate + heights = heights - heights.min() + heights = (heights / heights.max() * (1-offset)) + offset # to [0.2, 1] + pred_HWC = torch.from_numpy(heights * scale).float()[..., None] + + # set the boarder pixels to 0 height + import cv2 + # edge filter + edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255) + edge = torch.from_numpy(edge).bool()[..., None] + pred_HWC[edge] = 0 + + valid_HWC[pred_HWC < clamp_min] = False + return depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back) + +def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0): + ms = pymeshlab.MeshSet() + ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh") + if simplification > 0: + ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True) + ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads = 6, depth = poissson_depth, preclean = True) + if simplification > 0: + ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True) + return meshlab_mesh_to_py3dmesh(ms.current_mesh()) diff --git a/scripts/multiview_inference.py b/scripts/multiview_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..9d620b73fda27607ae8f9ecd4c6299b59f54ca6c --- /dev/null +++ b/scripts/multiview_inference.py @@ -0,0 +1,98 @@ +import os +from PIL import Image +from scripts.mesh_init import build_mesh, calc_w_over_h, fix_border_with_pymeshlab_fast +from scripts.project_mesh import multiview_color_projection +from scripts.refine_lr_to_sr import run_sr_fast +from scripts.utils import simple_clean_mesh +from app.utils import simple_remove, split_image +from app.custom_models.normal_prediction import predict_normals +from mesh_reconstruction.recon import reconstruct_stage1 +from mesh_reconstruction.refine import run_mesh_refine +from scripts.project_mesh import get_cameras_list +from scripts.utils import from_py3d_mesh, to_pyml_mesh +from pytorch3d.structures import Meshes, join_meshes_as_scene +import numpy as np + +def fast_geo(front_normal: Image.Image, back_normal: Image.Image, side_normal: Image.Image, clamp=0., init_type="std"): + import time + if front_normal.mode == "RGB": + front_normal = simple_remove(front_normal, run_sr=False) + front_normal = front_normal.resize((192, 192)) + if back_normal.mode == "RGB": + back_normal = simple_remove(back_normal, run_sr=False) + back_normal = back_normal.resize((192, 192)) + if side_normal.mode == "RGB": + side_normal = simple_remove(side_normal, run_sr=False) + side_normal = side_normal.resize((192, 192)) + + # build mesh with front back projection # ~3s + side_w_over_h = calc_w_over_h(side_normal) + mesh_front = build_mesh(front_normal, front_normal, clamp_min=clamp, scale=side_w_over_h, init_type=init_type) + mesh_back = build_mesh(back_normal, back_normal, is_back=True, clamp_min=clamp, scale=side_w_over_h, init_type=init_type) + meshes = join_meshes_as_scene([mesh_front, mesh_back]) + meshes = fix_border_with_pymeshlab_fast(meshes, poissson_depth=6, simplification=2000) + return meshes + +def refine_rgb(rgb_pils, front_pil): + from scripts.refine_lr_to_sr import refine_lr_with_sd + from scripts.utils import NEG_PROMPT + from app.utils import make_image_grid + from app.all_models import model_zoo + from app.utils import rgba_to_rgb + rgb_pil = make_image_grid(rgb_pils, rows=2) + prompt = "4views, multiview" + neg_prompt = NEG_PROMPT + control_image = rgb_pil.resize((1024, 1024)) + refined_rgb = refine_lr_with_sd([rgb_pil], [rgba_to_rgb(front_pil)], [control_image], prompt_list=[prompt], neg_prompt_list=[neg_prompt], pipe=model_zoo.pipe_disney_controlnet_tile_ipadapter_i2i, strength=0.2, output_size=(1024, 1024))[0] + refined_rgbs = split_image(refined_rgb, rows=2) + return refined_rgbs + +def erode_alpha(img_list): + out_img_list = [] + for idx, img in enumerate(img_list): + arr = np.array(img) + alpha = (arr[:, :, 3] > 127).astype(np.uint8) + # erode 1px + import cv2 + alpha = cv2.erode(alpha, np.ones((3, 3), np.uint8), iterations=1) + alpha = (alpha * 255).astype(np.uint8) + img = Image.fromarray(np.concatenate([arr[:, :, :3], alpha[:, :, None]], axis=-1)) + out_img_list.append(img) + return out_img_list +import time +def geo_reconstruct(rgb_pils, normal_pils, front_pil, do_refine=False, predict_normal=True, expansion_weight=0.1, init_type="std"): + if front_pil.size[0] <= 512: + front_pil = run_sr_fast([front_pil])[0] + if do_refine: + refined_rgbs = refine_rgb(rgb_pils, front_pil) # 6s + else: + refined_rgbs = [rgb.resize((512, 512), resample=Image.LANCZOS) for rgb in rgb_pils] + img_list = [front_pil] + run_sr_fast(refined_rgbs[1:]) + + if predict_normal: + rm_normals = predict_normals([img.resize((512, 512), resample=Image.LANCZOS) for img in img_list], guidance_scale=1.5) + else: + rm_normals = simple_remove([img.resize((512, 512), resample=Image.LANCZOS) for img in normal_pils]) + # transfer the alpha channel of rm_normals to img_list + for idx, img in enumerate(rm_normals): + if idx == 0 and img_list[0].mode == "RGBA": + temp = img_list[0].resize((2048, 2048)) + rm_normals[0] = Image.fromarray(np.concatenate([np.array(rm_normals[0])[:, :, :3], np.array(temp)[:, :, 3:4]], axis=-1)) + continue + img_list[idx] = Image.fromarray(np.concatenate([np.array(img_list[idx]), np.array(img)[:, :, 3:4]], axis=-1)) + assert img_list[0].mode == "RGBA" + assert np.mean(np.array(img_list[0])[..., 3]) < 250 + + img_list = [img_list[0]] + erode_alpha(img_list[1:]) + normal_stg1 = [img.resize((512, 512)) for img in rm_normals] + if init_type in ["std", "thin"]: + meshes = fast_geo(normal_stg1[0], normal_stg1[2], normal_stg1[1], init_type=init_type) + _ = multiview_color_projection(meshes, rgb_pils, resolution=512, device="cuda", complete_unseen=False, confidence_threshold=0.1) # just check for validation, may throw error + vertices, faces, _ = from_py3d_mesh(meshes) + vertices, faces = reconstruct_stage1(normal_stg1, steps=200, vertices=vertices, faces=faces, start_edge_len=0.1, end_edge_len=0.02, gain=0.05, return_mesh=False, loss_expansion_weight=expansion_weight) + elif init_type in ["ball"]: + vertices, faces = reconstruct_stage1(normal_stg1, steps=200, end_edge_len=0.01, return_mesh=False, loss_expansion_weight=expansion_weight) + vertices, faces = run_mesh_refine(vertices, faces, rm_normals, steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=20, update_warmup=5, return_mesh=False, process_inputs=False, process_outputs=False) + meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=1, apply_sub_divide=True, sub_divide_threshold=0.25).to("cuda") + new_meshes = multiview_color_projection(meshes, img_list, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2, cameras_list = get_cameras_list([0, 90, 180, 270], "cuda", focal=1)) + return new_meshes diff --git a/scripts/normal_to_height_map.py b/scripts/normal_to_height_map.py new file mode 100644 index 0000000000000000000000000000000000000000..9733e9bc771108fc34ebfc670c06467d83347d8a --- /dev/null +++ b/scripts/normal_to_height_map.py @@ -0,0 +1,203 @@ +# code modified from https://github.com/YertleTurtleGit/depth-from-normals +import numpy as np +import cv2 as cv +from multiprocessing.pool import ThreadPool as Pool +from multiprocessing import cpu_count +from typing import Tuple, List, Union +import numba + + +def calculate_gradients( + normals: np.ndarray, mask: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + horizontal_angle_map = np.arccos(np.clip(normals[:, :, 0], -1, 1)) + left_gradients = np.zeros(normals.shape[:2]) + left_gradients[mask != 0] = (1 - np.sin(horizontal_angle_map[mask != 0])) * np.sign( + horizontal_angle_map[mask != 0] - np.pi / 2 + ) + + vertical_angle_map = np.arccos(np.clip(normals[:, :, 1], -1, 1)) + top_gradients = np.zeros(normals.shape[:2]) + top_gradients[mask != 0] = -(1 - np.sin(vertical_angle_map[mask != 0])) * np.sign( + vertical_angle_map[mask != 0] - np.pi / 2 + ) + + return left_gradients, top_gradients + + +@numba.jit(nopython=True) +def integrate_gradient_field( + gradient_field: np.ndarray, axis: int, mask: np.ndarray +) -> np.ndarray: + heights = np.zeros(gradient_field.shape) + + for d1 in numba.prange(heights.shape[1 - axis]): + sum_value = 0 + for d2 in range(heights.shape[axis]): + coordinates = (d1, d2) if axis == 1 else (d2, d1) + + if mask[coordinates] != 0: + sum_value = sum_value + gradient_field[coordinates] + heights[coordinates] = sum_value + else: + sum_value = 0 + + return heights + + +def calculate_heights( + left_gradients: np.ndarray, top_gradients, mask: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + left_heights = integrate_gradient_field(left_gradients, 1, mask) + right_heights = np.fliplr( + integrate_gradient_field(np.fliplr(-left_gradients), 1, np.fliplr(mask)) + ) + top_heights = integrate_gradient_field(top_gradients, 0, mask) + bottom_heights = np.flipud( + integrate_gradient_field(np.flipud(-top_gradients), 0, np.flipud(mask)) + ) + return left_heights, right_heights, top_heights, bottom_heights + + +def combine_heights(*heights: np.ndarray) -> np.ndarray: + return np.mean(np.stack(heights, axis=0), axis=0) + + +def rotate(matrix: np.ndarray, angle: float) -> np.ndarray: + h, w = matrix.shape[:2] + center = (w / 2, h / 2) + + rotation_matrix = cv.getRotationMatrix2D(center, angle, 1.0) + corners = cv.transform( + np.array([[[0, 0], [w, 0], [w, h], [0, h]]]), rotation_matrix + )[0] + + _, _, w, h = cv.boundingRect(corners) + + rotation_matrix[0, 2] += w / 2 - center[0] + rotation_matrix[1, 2] += h / 2 - center[1] + result = cv.warpAffine(matrix, rotation_matrix, (w, h), flags=cv.INTER_LINEAR) + + return result + + +def rotate_vector_field_normals(normals: np.ndarray, angle: float) -> np.ndarray: + angle = np.radians(angle) + cos_angle = np.cos(angle) + sin_angle = np.sin(angle) + + rotated_normals = np.empty_like(normals) + rotated_normals[:, :, 0] = ( + normals[:, :, 0] * cos_angle - normals[:, :, 1] * sin_angle + ) + rotated_normals[:, :, 1] = ( + normals[:, :, 0] * sin_angle + normals[:, :, 1] * cos_angle + ) + + return rotated_normals + + +def centered_crop(image: np.ndarray, target_resolution: Tuple[int, int]) -> np.ndarray: + return image[ + (image.shape[0] - target_resolution[0]) + // 2 : (image.shape[0] - target_resolution[0]) + // 2 + + target_resolution[0], + (image.shape[1] - target_resolution[1]) + // 2 : (image.shape[1] - target_resolution[1]) + // 2 + + target_resolution[1], + ] + + +def integrate_vector_field( + vector_field: np.ndarray, + mask: np.ndarray, + target_iteration_count: int, + thread_count: int, +) -> np.ndarray: + shape = vector_field.shape[:2] + angles = np.linspace(0, 90, target_iteration_count, endpoint=False) + + def integrate_vector_field_angles(angles: List[float]) -> np.ndarray: + all_combined_heights = np.zeros(shape) + + for angle in angles: + rotated_vector_field = rotate_vector_field_normals( + rotate(vector_field, angle), angle + ) + rotated_mask = rotate(mask, angle) + + left_gradients, top_gradients = calculate_gradients( + rotated_vector_field, rotated_mask + ) + ( + left_heights, + right_heights, + top_heights, + bottom_heights, + ) = calculate_heights(left_gradients, top_gradients, rotated_mask) + + combined_heights = combine_heights( + left_heights, right_heights, top_heights, bottom_heights + ) + combined_heights = centered_crop(rotate(combined_heights, -angle), shape) + all_combined_heights += combined_heights / len(angles) + + return all_combined_heights + + with Pool(processes=thread_count) as pool: + heights = pool.map( + integrate_vector_field_angles, + np.array( + np.array_split(angles, thread_count), + dtype=object, + ), + ) + pool.close() + pool.join() + + isotropic_height = np.zeros(shape) + for height in heights: + isotropic_height += height / thread_count + + return isotropic_height + + +def estimate_height_map( + normal_map: np.ndarray, + mask: Union[np.ndarray, None] = None, + height_divisor: float = 1, + target_iteration_count: int = 250, + thread_count: int = cpu_count(), + raw_values: bool = False, +) -> np.ndarray: + if mask is None: + if normal_map.shape[-1] == 4: + mask = normal_map[:, :, 3] / 255 + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + else: + mask = np.ones(normal_map.shape[:2], dtype=np.uint8) + + normals = ((normal_map[:, :, :3].astype(np.float64) / 255) - 0.5) * 2 + heights = integrate_vector_field( + normals, mask, target_iteration_count, thread_count + ) + + if raw_values: + return heights + + heights /= height_divisor + heights[mask > 0] += 1 / 2 + heights[mask == 0] = 1 / 2 + + heights *= 2**16 - 1 + + if np.min(heights) < 0 or np.max(heights) > 2**16 - 1: + raise OverflowError("Height values are clipping.") + + heights = np.clip(heights, 0, 2**16 - 1) + heights = heights.astype(np.uint16) + + return heights diff --git a/scripts/project_mesh.py b/scripts/project_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..07635c3f104a5888e853233bf20abd86e08a0203 --- /dev/null +++ b/scripts/project_mesh.py @@ -0,0 +1,378 @@ +from typing import List +import torch +import numpy as np +from PIL import Image +from pytorch3d.renderer.cameras import look_at_view_transform, OrthographicCameras, CamerasBase +from pytorch3d.renderer.mesh.rasterizer import Fragments +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + RasterizationSettings, + TexturesVertex, + FoVPerspectiveCameras, + FoVOrthographicCameras, +) +from pytorch3d.renderer import MeshRasterizer + +def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'): + # pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183 + R = world_to_cam[:3, :3].t()[None, ...] + T = world_to_cam[:3, 3][None, ...] + if cam_type == 'fov': + camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True) + else: + focal_length = 1 / focal_length + camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length) + return camera + +def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1): + """ + Renders pix2face of visible faces. + + :param mesh: Pytorch3d.structures.Meshes + :param cameras: pytorch3d.renderer.Cameras + :param H: target image height + :param W: target image width + :param blur_radius: Float distance in the range [0, 2] used to expand the face + bounding boxes for rasterization. Setting blur radius + results in blurred edges around the shape instead of a + hard boundary. Set to 0 for no blur. + :param faces_per_pixel: (int) Number of faces to keep track of per pixel. + We return the nearest faces_per_pixel faces along the z-axis. + """ + # Define the settings for rasterization and shading + raster_settings = RasterizationSettings( + image_size=(H, W), + blur_radius=blur_radius, + faces_per_pixel=faces_per_pixel + ) + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=raster_settings + ) + fragments: Fragments = rasterizer(meshes, cameras=cameras) + return { + "pix_to_face": fragments.pix_to_face[..., 0], + } + +import nvdiffrast.torch as dr + +def _warmup(glctx, device=None): + device = 'cuda' if device is None else device + #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59 + def tensor(*args, **kwargs): + return torch.tensor(*args, device=device, **kwargs) + pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32) + tri = tensor([[0, 1, 2]], dtype=torch.int32) + dr.rasterize(glctx, pos, tri, resolution=[256, 256]) + +class Pix2FacesRenderer: + def __init__(self, device="cuda"): + self._glctx = dr.RasterizeGLContext(output_db=False, device=device) + self.device = device + _warmup(self._glctx, device) + + def transform_vertices(self, meshes: Meshes, cameras: CamerasBase): + vertices = cameras.transform_points_ndc(meshes.verts_padded()) + + perspective_correct = cameras.is_perspective() + znear = cameras.get_znear() + if isinstance(znear, torch.Tensor): + znear = znear.min().item() + z_clip = None if not perspective_correct or znear is None else znear / 2 + + if z_clip: + vertices = vertices[vertices[..., 2] >= cameras.get_znear()][None] # clip + vertices = vertices * torch.tensor([-1, -1, 1]).to(vertices) + vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1).to(torch.float32) + return vertices + + def render_pix2faces_nvdiff(self, meshes: Meshes, cameras: CamerasBase, H=512, W=512): + meshes = meshes.to(self.device) + cameras = cameras.to(self.device) + vertices = self.transform_vertices(meshes, cameras) + faces = meshes.faces_packed().to(torch.int32) + rast_out,_ = dr.rasterize(self._glctx, vertices, faces, resolution=(H, W), grad_db=False) #C,H,W,4 + pix_to_face = rast_out[..., -1].to(torch.int32) - 1 + return pix_to_face + +pix2faces_renderer = Pix2FacesRenderer() + +def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024): + # pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face'] + pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution) + + unique_faces = torch.unique(pix_to_face.flatten()) + unique_faces = unique_faces[unique_faces != -1] + return unique_faces + +def project_color(meshes: Meshes, cameras: CamerasBase, pil_image: Image.Image, use_alpha=True, eps=0.05, resolution=1024, device="cuda") -> dict: + """ + Projects color from a given image onto a 3D mesh. + + Args: + meshes (pytorch3d.structures.Meshes): The 3D mesh object. + cameras (pytorch3d.renderer.cameras.CamerasBase): The camera object. + pil_image (PIL.Image.Image): The input image. + use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True. + eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05. + resolution (int, optional): The resolution of the projection. Defaults to 1024. + device (str, optional): The device to use for computation. Defaults to "cuda". + debug (bool, optional): Whether to save debug images. Defaults to False. + + Returns: + dict: A dictionary containing the following keys: + - "new_texture" (TexturesVertex): The updated texture with interpolated colors. + - "valid_verts" (Tensor of [M,3]): The indices of the vertices being projected. + - "valid_colors" (Tensor of [M,3]): The interpolated colors for the valid vertices. + """ + meshes = meshes.to(device) + cameras = cameras.to(device) + image = torch.from_numpy(np.array(pil_image.convert("RGBA")) / 255.).permute((2, 0, 1)).float().to(device) # in CHW format of [0, 1.] + unique_faces = get_visible_faces(meshes, cameras, resolution=resolution) + + # visible faces + faces_normals = meshes.faces_normals_packed()[unique_faces] + faces_normals = faces_normals / faces_normals.norm(dim=1, keepdim=True) + world_points = cameras.unproject_points(torch.tensor([[[0., 0., 0.1], [0., 0., 0.2]]]).to(device))[0] + view_direction = world_points[1] - world_points[0] + view_direction = view_direction / view_direction.norm(dim=0, keepdim=True) + + # find invalid faces + cos_angles = (faces_normals * view_direction).sum(dim=1) + assert cos_angles.mean() < 0, f"The view direction is not correct. cos_angles.mean()={cos_angles.mean()}" + selected_faces = unique_faces[cos_angles < -eps] + + # find verts + faces = meshes.faces_packed()[selected_faces] # [N, 3] + verts = torch.unique(faces.flatten()) # [N, 1] + verts_coordinates = meshes.verts_packed()[verts] # [N, 3] + + # compute color + pt_tensor = cameras.transform_points(verts_coordinates)[..., :2] # NDC space points + valid = ~((pt_tensor.isnan()|(pt_tensor<-1)|(1 dict: + """ + meshes: the mesh with vertex color to be completed. + valid_index: the index of the valid vertices, where valid means colors are fixed. [V, 1] + """ + valid_index = valid_index.to(meshes.device) + colors = meshes.textures.verts_features_packed() # [V, 3] + V = colors.shape[0] + + invalid_index = torch.ones_like(colors[:, 0]).bool() # [V] + invalid_index[valid_index] = False + invalid_index = torch.arange(V).to(meshes.device)[invalid_index] + + L = meshes.laplacian_packed() + E = torch.sparse_coo_tensor(torch.tensor([list(range(V))] * 2), torch.ones((V,)), size=(V, V)).to(meshes.device) + L = L + E + # E = torch.eye(V, layout=torch.sparse_coo, device=meshes.device) + # L = L + E + colored_count = torch.ones_like(colors[:, 0]) # [V] + colored_count[invalid_index] = 0 + L_invalid = torch.index_select(L, 0, invalid_index) # sparse [IV, V] + + total_colored = colored_count.sum() + coloring_round = 0 + stage = "uncolored" + from tqdm import tqdm + pbar = tqdm(miniters=100) + while stage == "uncolored" or coloring_round > 0: + new_color = torch.matmul(L_invalid, colors * colored_count[:, None]) # [IV, 3] + new_count = torch.matmul(L_invalid, colored_count)[:, None] # [IV, 1] + colors[invalid_index] = torch.where(new_count > 0, new_color / new_count, colors[invalid_index]) + colored_count[invalid_index] = (new_count[:, 0] > 0).float() + + new_total_colored = colored_count.sum() + if new_total_colored > total_colored: + total_colored = new_total_colored + coloring_round += 1 + else: + stage = "colored" + coloring_round -= 1 + pbar.update(1) + if coloring_round > 10000: + print("coloring_round > 10000, break") + break + assert not torch.isnan(colors).any() + meshes.textures = TexturesVertex(verts_features=[colors]) + return meshes + +def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], cameras_list: List[CamerasBase]=None, camera_focal: float = 2 / 1.35, weights=None, eps=0.05, resolution=1024, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy="smooth") -> Meshes: + """ + Projects color from a given image onto a 3D mesh. + + Args: + meshes (pytorch3d.structures.Meshes): The 3D mesh object, only one mesh. + image_list (PIL.Image.Image): List of images. + cameras_list (list): List of cameras. + camera_focal (float, optional): The focal length of the camera, if cameras_list is not passed. Defaults to 2 / 1.35. + weights (list, optional): List of weights for each image, for ['front', 'front_right', 'right', 'back', 'left', 'front_left']. Defaults to None. + eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05. + resolution (int, optional): The resolution of the projection. Defaults to 1024. + device (str, optional): The device to use for computation. Defaults to "cuda". + reweight_with_cosangle (str, optional): Whether to reweight the color with the angle between the view direction and the vertex normal. Defaults to None. + use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True. + confidence_threshold (float, optional): The threshold for the confidence of the projected color, if final projection weight is less than this, we will use the original color. Defaults to 0.1. + complete_unseen (bool, optional): Whether to complete the unseen vertex color using laplacian. Defaults to False. + + Returns: + Meshes: the colored mesh + """ + # 1. preprocess inputs + if image_list is None: + raise ValueError("image_list is None") + if cameras_list is None: + if len(image_list) == 8: + cameras_list = get_8view_cameras(device, focal=camera_focal) + elif len(image_list) == 6: + cameras_list = get_6view_cameras(device, focal=camera_focal) + elif len(image_list) == 4: + cameras_list = get_4view_cameras(device, focal=camera_focal) + elif len(image_list) == 2: + cameras_list = get_2view_cameras(device, focal=camera_focal) + else: + raise ValueError("cameras_list is None, and can not be guessed from image_list") + if weights is None: + if len(image_list) == 8: + weights = [2.0, 0.05, 0.2, 0.02, 1.0, 0.02, 0.2, 0.05] + elif len(image_list) == 6: + weights = [2.0, 0.05, 0.2, 1.0, 0.2, 0.05] + elif len(image_list) == 4: + weights = [2.0, 0.2, 1.0, 0.2] + elif len(image_list) == 2: + weights = [1.0, 1.0] + else: + raise ValueError("weights is None, and can not be guessed from image_list") + + # 2. run projection + meshes = meshes.clone().to(device) + if weights is None: + weights = [1. for _ in range(len(cameras_list))] + assert len(cameras_list) == len(image_list) == len(weights) + original_color = meshes.textures.verts_features_packed() + assert not torch.isnan(original_color).any() + texture_counts = torch.zeros_like(original_color[..., :1]) + texture_values = torch.zeros_like(original_color) + max_texture_counts = torch.zeros_like(original_color[..., :1]) + max_texture_values = torch.zeros_like(original_color) + for camera, image, weight in zip(cameras_list, image_list, weights): + ret = project_color(meshes, camera, image, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha) + if reweight_with_cosangle == "linear": + weight = (ret['cos_angles'].abs() * weight)[:, None] + elif reweight_with_cosangle == "square": + weight = (ret['cos_angles'].abs() ** 2 * weight)[:, None] + if use_alpha: + weight = weight * ret['valid_alpha'] + assert weight.min() > -0.0001 + texture_counts[ret['valid_verts']] += weight + texture_values[ret['valid_verts']] += ret['valid_colors'] * weight + max_texture_values[ret['valid_verts']] = torch.where(weight > max_texture_counts[ret['valid_verts']], ret['valid_colors'], max_texture_values[ret['valid_verts']]) + max_texture_counts[ret['valid_verts']] = torch.max(max_texture_counts[ret['valid_verts']], weight) + + # Method2 + texture_values = torch.where(texture_counts > confidence_threshold, texture_values / texture_counts, texture_values) + if below_confidence_strategy == "smooth": + texture_values = torch.where(texture_counts <= confidence_threshold, (original_color * (confidence_threshold - texture_counts) + texture_values) / confidence_threshold, texture_values) + elif below_confidence_strategy == "original": + texture_values = torch.where(texture_counts <= confidence_threshold, original_color, texture_values) + else: + raise ValueError(f"below_confidence_strategy={below_confidence_strategy} is not supported") + assert not torch.isnan(texture_values).any() + meshes.textures = TexturesVertex(verts_features=[texture_values]) + + if complete_unseen: + meshes = complete_unseen_vertex_color(meshes, torch.arange(texture_values.shape[0]).to(device)[texture_counts[:, 0] >= confidence_threshold]) + ret_mesh = meshes.detach() + del meshes + return ret_mesh + +def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1): + ret = [] + for azim in azim_list: + R, T = look_at_view_transform(dist, 0, azim) + w2c = torch.cat([R[0].T, T[0, :, None]], dim=1) + cameras: OrthographicCameras = get_camera(w2c, focal_length=focal, cam_type='orthogonal').to(device) + ret.append(cameras) + return ret + +def get_8view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 225, 270, 315, 0, 45, 90, 135], device=device, focal=focal) + +def get_6view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 225, 270, 0, 90, 135], device=device, focal=focal) + +def get_4view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 270, 0, 90], device=device, focal=focal) + +def get_2view_cameras(device, focal=2/1.35): + return get_cameras_list(azim_list = [180, 0], device=device, focal=focal) + +def get_multiple_view_cameras(device, focal=2/1.35, offset=180, num_views=8, dist=1.1): + return get_cameras_list(azim_list = (np.linspace(0, 360, num_views+1)[:-1] + offset) % 360, device=device, focal=focal, dist=dist) + +def align_with_alpha_bbox(source_img, target_img, final_size=1024): + # align source_img with target_img using alpha channel + # source_img and target_img are PIL.Image.Image + source_img = source_img.convert("RGBA") + target_img = target_img.convert("RGBA").resize((final_size, final_size)) + source_np = np.array(source_img) + target_np = np.array(target_img) + source_alpha = source_np[:, :, 3] + target_alpha = target_np[:, :, 3] + bbox_source_min, bbox_source_max = np.argwhere(source_alpha > 0).min(axis=0), np.argwhere(source_alpha > 0).max(axis=0) + bbox_target_min, bbox_target_max = np.argwhere(target_alpha > 0).min(axis=0), np.argwhere(target_alpha > 0).max(axis=0) + source_content = source_np[bbox_source_min[0]:bbox_source_max[0]+1, bbox_source_min[1]:bbox_source_max[1]+1, :] + # resize source_content to fit in the position of target_content + source_content = Image.fromarray(source_content).resize((bbox_target_max[1]-bbox_target_min[1]+1, bbox_target_max[0]-bbox_target_min[0]+1), resample=Image.BICUBIC) + target_np[bbox_target_min[0]:bbox_target_max[0]+1, bbox_target_min[1]:bbox_target_max[1]+1, :] = np.array(source_content) + return Image.fromarray(target_np) + +def load_image_list_from_mvdiffusion(mvdiffusion_path, front_from_pil_or_path=None): + import os + image_list = [] + for dir in ['front', 'front_right', 'right', 'back', 'left', 'front_left']: + image_path = os.path.join(mvdiffusion_path, f"rgb_000_{dir}.png") + pil = Image.open(image_path) + if dir == 'front': + if front_from_pil_or_path is not None: + if isinstance(front_from_pil_or_path, str): + replace_pil = Image.open(front_from_pil_or_path) + else: + replace_pil = front_from_pil_or_path + # align replace_pil with pil using bounding box in alpha channel + pil = align_with_alpha_bbox(replace_pil, pil, final_size=1024) + image_list.append(pil) + return image_list + +def load_image_list_from_img_grid(img_grid_path, resolution = 1024): + img_list = [] + grid = Image.open(img_grid_path) + w, h = grid.size + for row in range(0, h, resolution): + for col in range(0, w, resolution): + img_list.append(grid.crop((col, row, col + resolution, row + resolution))) + return img_list \ No newline at end of file diff --git a/scripts/refine_lr_to_sr.py b/scripts/refine_lr_to_sr.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a50cab8cd6d722ce0bf47ceb42ba168070fd9d --- /dev/null +++ b/scripts/refine_lr_to_sr.py @@ -0,0 +1,60 @@ +import torch +import os + +import numpy as np +from hashlib import md5 +def hash_img(img): + return md5(np.array(img).tobytes()).hexdigest() +def hash_any(obj): + return md5(str(obj).encode()).hexdigest() + +def refine_lr_with_sd(pil_image_list, concept_img_list, control_image_list, prompt_list, pipe=None, strength=0.35, neg_prompt_list="", output_size=(512, 512), controlnet_conditioning_scale=1.): + with torch.no_grad(): + images = pipe( + image=pil_image_list, + ip_adapter_image=concept_img_list, + prompt=prompt_list, + neg_prompt=neg_prompt_list, + num_inference_steps=50, + strength=strength, + height=output_size[0], + width=output_size[1], + control_image=control_image_list, + guidance_scale=5.0, + controlnet_conditioning_scale=controlnet_conditioning_scale, + generator=torch.manual_seed(233), + ).images + return images + +SR_cache = None + +def run_sr_fast(source_pils, scale=4): + from PIL import Image + from scripts.upsampler import RealESRGANer + import numpy as np + global SR_cache + if SR_cache is not None: + upsampler = SR_cache + else: + upsampler = RealESRGANer( + scale=4, + onnx_path="ckpt/realesrgan-x4.onnx", + tile=0, + tile_pad=10, + pre_pad=0, + half=True, + gpu_id=0, + ) + ret_pils = [] + for idx, img_pils in enumerate(source_pils): + np_in = isinstance(img_pils, np.ndarray) + assert isinstance(img_pils, (Image.Image, np.ndarray)) + img = np.array(img_pils) + output, _ = upsampler.enhance(img, outscale=scale) + if np_in: + ret_pils.append(output) + else: + ret_pils.append(Image.fromarray(output)) + if SR_cache is None: + SR_cache = upsampler + return ret_pils diff --git a/scripts/sd_model_zoo.py b/scripts/sd_model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5e271ca5b8f242998e151482b86984aac1c10d --- /dev/null +++ b/scripts/sd_model_zoo.py @@ -0,0 +1,131 @@ +from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline +from transformers import CLIPVisionModelWithProjection +import torch +from copy import deepcopy + +ENABLE_CPU_CACHE = False +DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5" + +cached_models = {} # cache for models to avoid repeated loading, key is model name +def cache_model(func): + def wrapper(*args, **kwargs): + if ENABLE_CPU_CACHE: + model_name = func.__name__ + str(args) + str(kwargs) + if model_name not in cached_models: + cached_models[model_name] = func(*args, **kwargs) + return cached_models[model_name] + else: + return func(*args, **kwargs) + return wrapper + +def copied_cache_model(func): + def wrapper(*args, **kwargs): + if ENABLE_CPU_CACHE: + model_name = func.__name__ + str(args) + str(kwargs) + if model_name not in cached_models: + cached_models[model_name] = func(*args, **kwargs) + return deepcopy(cached_models[model_name]) + else: + return func(*args, **kwargs) + return wrapper + +def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs): + if ckpt_or_pretrained.endswith(".safetensors"): + pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs) + else: + pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs) + return pipe + +@copied_cache_model +def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16): + model_kwargs = dict( + torch_dtype=torch_dtype, + requires_safety_checker=False, + safety_checker=None, + ) + pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained( + base_model, + StableDiffusionPipeline, + **model_kwargs + ) + pipe.to("cpu") + return pipe.components + +@cache_model +def load_controlnet(controlnet_path, torch_dtype=torch.float16): + controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype) + return controlnet + +@cache_model +def load_image_encoder(): + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "h94/IP-Adapter", + subfolder="models/image_encoder", + torch_dtype=torch.float16, + ) + return image_encoder + +def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="auto", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs): + model_kwargs = dict( + torch_dtype=torch_dtype, + device_map=device, + requires_safety_checker=False, + safety_checker=None, + ) + components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype) + model_kwargs.update(components) + model_kwargs.update(kwargs) + + if controlnet is not None: + if isinstance(controlnet, list): + controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet] + else: + controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype) + model_kwargs.update(controlnet=controlnet) + + if pipeline_class is None: + if controlnet is not None: + pipeline_class = StableDiffusionControlNetPipeline + else: + pipeline_class = StableDiffusionPipeline + + pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained( + base_model, + pipeline_class, + **model_kwargs + ) + + if ip_adapter: + image_encoder = load_image_encoder() + pipe.image_encoder = image_encoder + if plus_model: + pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors") + else: + pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors") + pipe.set_ip_adapter_scale(1.0) + else: + pipe.unload_ip_adapter() + + pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) + + if model_cpu_offload_seq is None: + if isinstance(pipe, StableDiffusionControlNetPipeline): + pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae" + elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline): + pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae" + else: + pipe.model_cpu_offload_seq = model_cpu_offload_seq + + if enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload() + else: + pipe = pipe.to("cuda") + pass + # pipe.enable_model_cpu_offload() + if vae_slicing: + pipe.enable_vae_slicing() + + import gc + gc.collect() + return pipe + diff --git a/scripts/upsampler.py b/scripts/upsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4f4999ab864c9eb0282832fb1ad02b63e6014926 --- /dev/null +++ b/scripts/upsampler.py @@ -0,0 +1,229 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torch.nn import functional as F +from scripts.load_onnx import load_onnx_caller +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__(self, + scale, + onnx_path, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + if gpu_id: + self.device = torch.device( + f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + self.model = load_onnx_caller(onnx_path, single_output=True) + # warm up + sample_input = torch.randn(1,3,512,512).cuda().float() + self.model(sample_input) + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. + + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img = self.post_process() + output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + diff --git a/scripts/utils.py b/scripts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a938160d84ee188d9160da70e61187a954d881e0 --- /dev/null +++ b/scripts/utils.py @@ -0,0 +1,319 @@ +import torch +import numpy as np +from PIL import Image +import pymeshlab +import pymeshlab as ml +from pymeshlab import PercentageValue +from pytorch3d.renderer import TexturesVertex +from pytorch3d.structures import Meshes +from rembg import new_session, remove +import torch +import torch.nn.functional as F +from typing import List, Tuple +from PIL import Image +import trimesh + +providers = [ + ('CUDAExecutionProvider', { + 'device_id': 0, + 'arena_extend_strategy': 'kSameAsRequested', + 'gpu_mem_limit': 8 * 1024 * 1024 * 1024, + 'cudnn_conv_algo_search': 'HEURISTIC', + }) +] + +session = new_session(providers=providers) + +NEG_PROMPT="sketch, sculpture, hand drawing, outline, single color, NSFW, lowres, bad anatomy,bad hands, text, error, missing fingers, yellow sleeves, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry,(worst quality:1.4),(low quality:1.4)" + +def load_mesh_with_trimesh(file_name, file_type=None): + import trimesh + mesh: trimesh.Trimesh = trimesh.load(file_name, file_type=file_type) + if isinstance(mesh, trimesh.Scene): + assert len(mesh.geometry) > 0 + # save to obj first and load again to avoid offset issue + from io import BytesIO + with BytesIO() as f: + mesh.export(f, file_type="obj") + f.seek(0) + mesh = trimesh.load(f, file_type="obj") + if isinstance(mesh, trimesh.Scene): + # we lose texture information here + mesh = trimesh.util.concatenate( + tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces) + for g in mesh.geometry.values())) + assert isinstance(mesh, trimesh.Trimesh) + + vertices = torch.from_numpy(mesh.vertices).T + faces = torch.from_numpy(mesh.faces).T + colors = None + if mesh.visual is not None: + if hasattr(mesh.visual, 'vertex_colors'): + colors = torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255. + if colors is None: + # print("Warning: no vertex color found in mesh! Filling it with gray.") + colors = torch.ones_like(vertices) * 0.5 + return vertices, faces, colors + +def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes: + verts = torch.from_numpy(mesh.vertex_matrix()).float() + faces = torch.from_numpy(mesh.face_matrix()).long() + colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float() + textures = TexturesVertex(verts_features=[colors]) + return Meshes(verts=[verts], faces=[faces], textures=textures) + + +def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh: + colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64) + m1 = pymeshlab.Mesh( + vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64), + face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32), + v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64), + v_color_matrix=colors_in) + return m1 + + +def to_pyml_mesh(vertices,faces): + m1 = pymeshlab.Mesh( + vertex_matrix=vertices.cpu().float().numpy().astype(np.float64), + face_matrix=faces.cpu().long().numpy().astype(np.int32), + ) + return m1 + + +def to_py3d_mesh(vertices, faces, normals=None): + from pytorch3d.structures import Meshes + from pytorch3d.renderer.mesh.textures import TexturesVertex + mesh = Meshes(verts=[vertices], faces=[faces], textures=None) + if normals is None: + normals = mesh.verts_normals_packed() + # set normals as vertext colors + mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5]) + return mesh + + +def from_py3d_mesh(mesh): + return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed() + +def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float): + """ + rotate along y-axis + normal_map: np.array, shape=(H, W, 3) in [-1, 1] + angle: float, in degree + """ + angle = angle / 180 * np.pi + R = np.array([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]]) + return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape) + +# from view coord to front view world coord +def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> np.ndarray: # [0, 255] + n_views = len(normal_pils) + ret = [] + for idx, rgba_normal in enumerate(normal_pils): + # rotate normal + normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1] + alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1] + normal_np = normal_np * 2 - 1 + normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views)) + normal_np = (normal_np + 1) / 2 + normal_np = normal_np * alpha_np[..., None] # make bg black + rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255] , axis=-1) + if return_types == 'np': + ret.append(rgba_normal_np) + elif return_types == 'pil': + ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8))) + else: + raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}") + return ret + + +def rotate_normalmap_by_angle_torch(normal_map, angle): + """ + rotate along y-axis + normal_map: torch.Tensor, shape=(H, W, 3) in [-1, 1], device='cuda' + angle: float, in degree + """ + angle = torch.tensor(angle / 180 * np.pi).to(normal_map) + R = torch.tensor([[torch.cos(angle), 0, torch.sin(angle)], + [0, 1, 0], + [-torch.sin(angle), 0, torch.cos(angle)]]).to(normal_map) + return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape) + +def do_rotate(rgba_normal, angle): + rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255 + rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle) + rotated_normal_tensor = (rotated_normal_tensor + 1) / 2 + rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black + rgba_normal_np = torch.cat([rotated_normal_tensor * 255, rgba_normal[:, :, [3]] * 255], dim=-1).cpu().numpy() + return rgba_normal_np + +def rotate_normals_torch(normal_pils, return_types='np', rotate_direction=1): + n_views = len(normal_pils) + ret = [] + for idx, rgba_normal in enumerate(normal_pils): + # rotate normal + angle = rotate_direction * idx * (360 / n_views) + rgba_normal_np = do_rotate(np.array(rgba_normal), angle) + if return_types == 'np': + ret.append(rgba_normal_np) + elif return_types == 'pil': + ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8))) + else: + raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}") + return ret + +def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)): + ret = [] + new_bkgd = np.array(new_bkgd).reshape(1, 1, 3) + for rgba_img in img_pils: + img_np = np.array(rgba_img)[:, :, :3] / 255 + alpha_np = np.array(rgba_img)[:, :, 3] / 255 + ori_bkgd = img_np[:1, :1] + # color = ori_color * alpha + bkgd * (1-alpha) + # ori_color = (color - bkgd * (1-alpha)) / alpha + alpha_np_clamp = np.clip(alpha_np, 1e-6, 1) # avoid divide by zero + ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None] + img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd) + rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1) + ret.append(Image.fromarray(rgba_img_np.astype(np.uint8))) + return ret + +def change_bkgd_to_normal(normal_pils) -> List[Image.Image]: + n_views = len(normal_pils) + ret = [] + for idx, rgba_normal in enumerate(normal_pils): + # calcuate background normal + target_bkgd = rotate_normalmap_by_angle(np.array([[[0., 0., 1.]]]), idx * (360 / n_views)) + normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1] + alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1] + normal_np = normal_np * 2 - 1 + old_bkgd = normal_np[:1,:1] + normal_np[alpha_np > 0.05] = (normal_np[alpha_np > 0.05] - old_bkgd * (1 - alpha_np[alpha_np > 0.05][..., None])) / alpha_np[alpha_np > 0.05][..., None] + normal_np = normal_np * alpha_np[..., None] + target_bkgd * (1 - alpha_np[..., None]) + normal_np = (normal_np + 1) / 2 + rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[..., None] * 255] , axis=-1) + ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8))) + return ret + + +def fix_vert_color_glb(mesh_path): + from pygltflib import GLTF2, Material, PbrMetallicRoughness + obj1 = GLTF2().load(mesh_path) + obj1.meshes[0].primitives[0].material = 0 + obj1.materials.append(Material( + pbrMetallicRoughness = PbrMetallicRoughness( + baseColorFactor = [1.0, 1.0, 1.0, 1.0], + metallicFactor = 0., + roughnessFactor = 1.0, + ), + emissiveFactor = [0.0, 0.0, 0.0], + doubleSided = True, + )) + obj1.save(mesh_path) + + +def srgb_to_linear(c_srgb): + c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4) + return c_linear.clip(0, 1.) + + +def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True): + # convert from pytorch3d meshes to trimesh mesh + vertices = meshes.verts_packed().cpu().float().numpy() + triangles = meshes.faces_packed().cpu().long().numpy() + np_color = meshes.textures.verts_features_packed().cpu().float().numpy() + if save_glb_path.endswith(".glb"): + # rotate 180 along +Y + vertices[:, [0, 2]] = -vertices[:, [0, 2]] + + if apply_sRGB_to_LinearRGB: + np_color = srgb_to_linear(np_color) + assert vertices.shape[0] == np_color.shape[0] + assert np_color.shape[1] == 3 + assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}" + mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color) + mesh.remove_unreferenced_vertices() + # save mesh + mesh.export(save_glb_path) + if save_glb_path.endswith(".glb"): + fix_vert_color_glb(save_glb_path) + print(f"saving to {save_glb_path}") + + +def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]: + import time + if '.' in save_mesh_prefix: + save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1]) + if with_timestamp: + save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}" + ret_mesh = save_mesh_prefix + ".glb" + # optimizied version + save_py3dmesh_with_trimesh_fast(meshes, ret_mesh) + return ret_mesh, None + + +def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25): + ms = ml.MeshSet() + ms.add_mesh(pyml_mesh, "cube_mesh") + + if apply_smooth: + ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False) + if apply_sub_divide: # 5s, slow + ms.apply_filter("meshing_repair_non_manifold_vertices") + ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces') + ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold)) + return meshlab_mesh_to_py3dmesh(ms.current_mesh()) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def simple_preprocess(input_image, rembg_session=session, background_color=255): + RES = 2048 + input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS) + if input_image.mode != 'RGBA': + image_rem = input_image.convert('RGBA') + input_image = remove(image_rem, alpha_matting=False, session=rembg_session) + + arr = np.asarray(input_image) + alpha = np.asarray(input_image)[:, :, -1] + x_nonzero = np.nonzero((alpha > 60).sum(axis=1)) + y_nonzero = np.nonzero((alpha > 60).sum(axis=0)) + x_min = int(x_nonzero[0].min()) + y_min = int(y_nonzero[0].min()) + x_max = int(x_nonzero[0].max()) + y_max = int(y_nonzero[0].max()) + arr = arr[x_min: x_max, y_min: y_max] + input_image = Image.fromarray(arr) + input_image = expand2square(input_image, (background_color, background_color, background_color, 0)) + return input_image + +def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"): + # Convert the background color to a PyTorch tensor + new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device) + + # Convert all images to PyTorch tensors and process them + imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255 + img_nps = imgs[..., :3] + alpha_nps = imgs[..., 3] + ori_bkgds = img_nps[:, :1, :1] + + # Avoid divide by zero and calculate the original image + alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1) + ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1) + ori_img_nps = torch.clamp(ori_img_nps, 0, 1) + img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd) + + rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1) + return rgba_img_np \ No newline at end of file