Spaces:
Running
Running
gartajackhats1985
commited on
Commit
•
028694a
1
Parent(s):
deb4eee
Upload 45 files
Browse files- ComfyUI-Advanced-ControlNet/.github/workflows/publish.yml +20 -0
- ComfyUI-Advanced-ControlNet/.gitignore +160 -0
- ComfyUI-Advanced-ControlNet/LICENSE +674 -0
- ComfyUI-Advanced-ControlNet/README.md +202 -0
- ComfyUI-Advanced-ControlNet/__init__.py +6 -0
- ComfyUI-Advanced-ControlNet/__pycache__/__init__.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_lllite.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_plusplus.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_reference.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_sparsectrl.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_svd.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/documentation.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/logger.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_deprecated.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_keyframes.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_loosecontrol.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_plusplus.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_reference.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_sparsectrl.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_weight.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/sampling.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/__pycache__/utils.cpython-312.pyc +0 -0
- ComfyUI-Advanced-ControlNet/adv_control/control.py +918 -0
- ComfyUI-Advanced-ControlNet/adv_control/control_lllite.py +462 -0
- ComfyUI-Advanced-ControlNet/adv_control/control_plusplus.py +485 -0
- ComfyUI-Advanced-ControlNet/adv_control/control_reference.py +1112 -0
- ComfyUI-Advanced-ControlNet/adv_control/control_sparsectrl.py +1078 -0
- ComfyUI-Advanced-ControlNet/adv_control/control_svd.py +518 -0
- ComfyUI-Advanced-ControlNet/adv_control/documentation.py +47 -0
- ComfyUI-Advanced-ControlNet/adv_control/logger.py +36 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes.py +331 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_deprecated.py +251 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_keyframes.py +468 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_loosecontrol.py +67 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_plusplus.py +85 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_reference.py +90 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_sparsectrl.py +186 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_weight.py +285 -0
- ComfyUI-Advanced-ControlNet/adv_control/sampling.py +216 -0
- ComfyUI-Advanced-ControlNet/adv_control/utils.py +981 -0
- ComfyUI-Advanced-ControlNet/pyproject.toml +15 -0
- ComfyUI-Advanced-ControlNet/web/js/autosize.js +53 -0
- ComfyUI-Advanced-ControlNet/web/js/documentation.js +293 -0
ComfyUI-Advanced-ControlNet/.github/workflows/publish.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Publish to Comfy registry
|
2 |
+
on:
|
3 |
+
workflow_dispatch:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
paths:
|
8 |
+
- "pyproject.toml"
|
9 |
+
|
10 |
+
jobs:
|
11 |
+
publish-node:
|
12 |
+
name: Publish Custom Node to registry
|
13 |
+
runs-on: ubuntu-latest
|
14 |
+
steps:
|
15 |
+
- name: Check out code
|
16 |
+
uses: actions/checkout@v4
|
17 |
+
- name: Publish Custom Node
|
18 |
+
uses: Comfy-Org/publish-node-action@main
|
19 |
+
with:
|
20 |
+
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here.
|
ComfyUI-Advanced-ControlNet/.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
ComfyUI-Advanced-ControlNet/LICENSE
ADDED
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GNU GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 29 June 2007
|
3 |
+
|
4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
6 |
+
of this license document, but changing it is not allowed.
|
7 |
+
|
8 |
+
Preamble
|
9 |
+
|
10 |
+
The GNU General Public License is a free, copyleft license for
|
11 |
+
software and other kinds of works.
|
12 |
+
|
13 |
+
The licenses for most software and other practical works are designed
|
14 |
+
to take away your freedom to share and change the works. By contrast,
|
15 |
+
the GNU General Public License is intended to guarantee your freedom to
|
16 |
+
share and change all versions of a program--to make sure it remains free
|
17 |
+
software for all its users. We, the Free Software Foundation, use the
|
18 |
+
GNU General Public License for most of our software; it applies also to
|
19 |
+
any other work released this way by its authors. You can apply it to
|
20 |
+
your programs, too.
|
21 |
+
|
22 |
+
When we speak of free software, we are referring to freedom, not
|
23 |
+
price. Our General Public Licenses are designed to make sure that you
|
24 |
+
have the freedom to distribute copies of free software (and charge for
|
25 |
+
them if you wish), that you receive source code or can get it if you
|
26 |
+
want it, that you can change the software or use pieces of it in new
|
27 |
+
free programs, and that you know you can do these things.
|
28 |
+
|
29 |
+
To protect your rights, we need to prevent others from denying you
|
30 |
+
these rights or asking you to surrender the rights. Therefore, you have
|
31 |
+
certain responsibilities if you distribute copies of the software, or if
|
32 |
+
you modify it: responsibilities to respect the freedom of others.
|
33 |
+
|
34 |
+
For example, if you distribute copies of such a program, whether
|
35 |
+
gratis or for a fee, you must pass on to the recipients the same
|
36 |
+
freedoms that you received. You must make sure that they, too, receive
|
37 |
+
or can get the source code. And you must show them these terms so they
|
38 |
+
know their rights.
|
39 |
+
|
40 |
+
Developers that use the GNU GPL protect your rights with two steps:
|
41 |
+
(1) assert copyright on the software, and (2) offer you this License
|
42 |
+
giving you legal permission to copy, distribute and/or modify it.
|
43 |
+
|
44 |
+
For the developers' and authors' protection, the GPL clearly explains
|
45 |
+
that there is no warranty for this free software. For both users' and
|
46 |
+
authors' sake, the GPL requires that modified versions be marked as
|
47 |
+
changed, so that their problems will not be attributed erroneously to
|
48 |
+
authors of previous versions.
|
49 |
+
|
50 |
+
Some devices are designed to deny users access to install or run
|
51 |
+
modified versions of the software inside them, although the manufacturer
|
52 |
+
can do so. This is fundamentally incompatible with the aim of
|
53 |
+
protecting users' freedom to change the software. The systematic
|
54 |
+
pattern of such abuse occurs in the area of products for individuals to
|
55 |
+
use, which is precisely where it is most unacceptable. Therefore, we
|
56 |
+
have designed this version of the GPL to prohibit the practice for those
|
57 |
+
products. If such problems arise substantially in other domains, we
|
58 |
+
stand ready to extend this provision to those domains in future versions
|
59 |
+
of the GPL, as needed to protect the freedom of users.
|
60 |
+
|
61 |
+
Finally, every program is threatened constantly by software patents.
|
62 |
+
States should not allow patents to restrict development and use of
|
63 |
+
software on general-purpose computers, but in those that do, we wish to
|
64 |
+
avoid the special danger that patents applied to a free program could
|
65 |
+
make it effectively proprietary. To prevent this, the GPL assures that
|
66 |
+
patents cannot be used to render the program non-free.
|
67 |
+
|
68 |
+
The precise terms and conditions for copying, distribution and
|
69 |
+
modification follow.
|
70 |
+
|
71 |
+
TERMS AND CONDITIONS
|
72 |
+
|
73 |
+
0. Definitions.
|
74 |
+
|
75 |
+
"This License" refers to version 3 of the GNU General Public License.
|
76 |
+
|
77 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
78 |
+
works, such as semiconductor masks.
|
79 |
+
|
80 |
+
"The Program" refers to any copyrightable work licensed under this
|
81 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
82 |
+
"recipients" may be individuals or organizations.
|
83 |
+
|
84 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
85 |
+
in a fashion requiring copyright permission, other than the making of an
|
86 |
+
exact copy. The resulting work is called a "modified version" of the
|
87 |
+
earlier work or a work "based on" the earlier work.
|
88 |
+
|
89 |
+
A "covered work" means either the unmodified Program or a work based
|
90 |
+
on the Program.
|
91 |
+
|
92 |
+
To "propagate" a work means to do anything with it that, without
|
93 |
+
permission, would make you directly or secondarily liable for
|
94 |
+
infringement under applicable copyright law, except executing it on a
|
95 |
+
computer or modifying a private copy. Propagation includes copying,
|
96 |
+
distribution (with or without modification), making available to the
|
97 |
+
public, and in some countries other activities as well.
|
98 |
+
|
99 |
+
To "convey" a work means any kind of propagation that enables other
|
100 |
+
parties to make or receive copies. Mere interaction with a user through
|
101 |
+
a computer network, with no transfer of a copy, is not conveying.
|
102 |
+
|
103 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
104 |
+
to the extent that it includes a convenient and prominently visible
|
105 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
106 |
+
tells the user that there is no warranty for the work (except to the
|
107 |
+
extent that warranties are provided), that licensees may convey the
|
108 |
+
work under this License, and how to view a copy of this License. If
|
109 |
+
the interface presents a list of user commands or options, such as a
|
110 |
+
menu, a prominent item in the list meets this criterion.
|
111 |
+
|
112 |
+
1. Source Code.
|
113 |
+
|
114 |
+
The "source code" for a work means the preferred form of the work
|
115 |
+
for making modifications to it. "Object code" means any non-source
|
116 |
+
form of a work.
|
117 |
+
|
118 |
+
A "Standard Interface" means an interface that either is an official
|
119 |
+
standard defined by a recognized standards body, or, in the case of
|
120 |
+
interfaces specified for a particular programming language, one that
|
121 |
+
is widely used among developers working in that language.
|
122 |
+
|
123 |
+
The "System Libraries" of an executable work include anything, other
|
124 |
+
than the work as a whole, that (a) is included in the normal form of
|
125 |
+
packaging a Major Component, but which is not part of that Major
|
126 |
+
Component, and (b) serves only to enable use of the work with that
|
127 |
+
Major Component, or to implement a Standard Interface for which an
|
128 |
+
implementation is available to the public in source code form. A
|
129 |
+
"Major Component", in this context, means a major essential component
|
130 |
+
(kernel, window system, and so on) of the specific operating system
|
131 |
+
(if any) on which the executable work runs, or a compiler used to
|
132 |
+
produce the work, or an object code interpreter used to run it.
|
133 |
+
|
134 |
+
The "Corresponding Source" for a work in object code form means all
|
135 |
+
the source code needed to generate, install, and (for an executable
|
136 |
+
work) run the object code and to modify the work, including scripts to
|
137 |
+
control those activities. However, it does not include the work's
|
138 |
+
System Libraries, or general-purpose tools or generally available free
|
139 |
+
programs which are used unmodified in performing those activities but
|
140 |
+
which are not part of the work. For example, Corresponding Source
|
141 |
+
includes interface definition files associated with source files for
|
142 |
+
the work, and the source code for shared libraries and dynamically
|
143 |
+
linked subprograms that the work is specifically designed to require,
|
144 |
+
such as by intimate data communication or control flow between those
|
145 |
+
subprograms and other parts of the work.
|
146 |
+
|
147 |
+
The Corresponding Source need not include anything that users
|
148 |
+
can regenerate automatically from other parts of the Corresponding
|
149 |
+
Source.
|
150 |
+
|
151 |
+
The Corresponding Source for a work in source code form is that
|
152 |
+
same work.
|
153 |
+
|
154 |
+
2. Basic Permissions.
|
155 |
+
|
156 |
+
All rights granted under this License are granted for the term of
|
157 |
+
copyright on the Program, and are irrevocable provided the stated
|
158 |
+
conditions are met. This License explicitly affirms your unlimited
|
159 |
+
permission to run the unmodified Program. The output from running a
|
160 |
+
covered work is covered by this License only if the output, given its
|
161 |
+
content, constitutes a covered work. This License acknowledges your
|
162 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
163 |
+
|
164 |
+
You may make, run and propagate covered works that you do not
|
165 |
+
convey, without conditions so long as your license otherwise remains
|
166 |
+
in force. You may convey covered works to others for the sole purpose
|
167 |
+
of having them make modifications exclusively for you, or provide you
|
168 |
+
with facilities for running those works, provided that you comply with
|
169 |
+
the terms of this License in conveying all material for which you do
|
170 |
+
not control copyright. Those thus making or running the covered works
|
171 |
+
for you must do so exclusively on your behalf, under your direction
|
172 |
+
and control, on terms that prohibit them from making any copies of
|
173 |
+
your copyrighted material outside their relationship with you.
|
174 |
+
|
175 |
+
Conveying under any other circumstances is permitted solely under
|
176 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
177 |
+
makes it unnecessary.
|
178 |
+
|
179 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
180 |
+
|
181 |
+
No covered work shall be deemed part of an effective technological
|
182 |
+
measure under any applicable law fulfilling obligations under article
|
183 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
184 |
+
similar laws prohibiting or restricting circumvention of such
|
185 |
+
measures.
|
186 |
+
|
187 |
+
When you convey a covered work, you waive any legal power to forbid
|
188 |
+
circumvention of technological measures to the extent such circumvention
|
189 |
+
is effected by exercising rights under this License with respect to
|
190 |
+
the covered work, and you disclaim any intention to limit operation or
|
191 |
+
modification of the work as a means of enforcing, against the work's
|
192 |
+
users, your or third parties' legal rights to forbid circumvention of
|
193 |
+
technological measures.
|
194 |
+
|
195 |
+
4. Conveying Verbatim Copies.
|
196 |
+
|
197 |
+
You may convey verbatim copies of the Program's source code as you
|
198 |
+
receive it, in any medium, provided that you conspicuously and
|
199 |
+
appropriately publish on each copy an appropriate copyright notice;
|
200 |
+
keep intact all notices stating that this License and any
|
201 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
202 |
+
keep intact all notices of the absence of any warranty; and give all
|
203 |
+
recipients a copy of this License along with the Program.
|
204 |
+
|
205 |
+
You may charge any price or no price for each copy that you convey,
|
206 |
+
and you may offer support or warranty protection for a fee.
|
207 |
+
|
208 |
+
5. Conveying Modified Source Versions.
|
209 |
+
|
210 |
+
You may convey a work based on the Program, or the modifications to
|
211 |
+
produce it from the Program, in the form of source code under the
|
212 |
+
terms of section 4, provided that you also meet all of these conditions:
|
213 |
+
|
214 |
+
a) The work must carry prominent notices stating that you modified
|
215 |
+
it, and giving a relevant date.
|
216 |
+
|
217 |
+
b) The work must carry prominent notices stating that it is
|
218 |
+
released under this License and any conditions added under section
|
219 |
+
7. This requirement modifies the requirement in section 4 to
|
220 |
+
"keep intact all notices".
|
221 |
+
|
222 |
+
c) You must license the entire work, as a whole, under this
|
223 |
+
License to anyone who comes into possession of a copy. This
|
224 |
+
License will therefore apply, along with any applicable section 7
|
225 |
+
additional terms, to the whole of the work, and all its parts,
|
226 |
+
regardless of how they are packaged. This License gives no
|
227 |
+
permission to license the work in any other way, but it does not
|
228 |
+
invalidate such permission if you have separately received it.
|
229 |
+
|
230 |
+
d) If the work has interactive user interfaces, each must display
|
231 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
232 |
+
interfaces that do not display Appropriate Legal Notices, your
|
233 |
+
work need not make them do so.
|
234 |
+
|
235 |
+
A compilation of a covered work with other separate and independent
|
236 |
+
works, which are not by their nature extensions of the covered work,
|
237 |
+
and which are not combined with it such as to form a larger program,
|
238 |
+
in or on a volume of a storage or distribution medium, is called an
|
239 |
+
"aggregate" if the compilation and its resulting copyright are not
|
240 |
+
used to limit the access or legal rights of the compilation's users
|
241 |
+
beyond what the individual works permit. Inclusion of a covered work
|
242 |
+
in an aggregate does not cause this License to apply to the other
|
243 |
+
parts of the aggregate.
|
244 |
+
|
245 |
+
6. Conveying Non-Source Forms.
|
246 |
+
|
247 |
+
You may convey a covered work in object code form under the terms
|
248 |
+
of sections 4 and 5, provided that you also convey the
|
249 |
+
machine-readable Corresponding Source under the terms of this License,
|
250 |
+
in one of these ways:
|
251 |
+
|
252 |
+
a) Convey the object code in, or embodied in, a physical product
|
253 |
+
(including a physical distribution medium), accompanied by the
|
254 |
+
Corresponding Source fixed on a durable physical medium
|
255 |
+
customarily used for software interchange.
|
256 |
+
|
257 |
+
b) Convey the object code in, or embodied in, a physical product
|
258 |
+
(including a physical distribution medium), accompanied by a
|
259 |
+
written offer, valid for at least three years and valid for as
|
260 |
+
long as you offer spare parts or customer support for that product
|
261 |
+
model, to give anyone who possesses the object code either (1) a
|
262 |
+
copy of the Corresponding Source for all the software in the
|
263 |
+
product that is covered by this License, on a durable physical
|
264 |
+
medium customarily used for software interchange, for a price no
|
265 |
+
more than your reasonable cost of physically performing this
|
266 |
+
conveying of source, or (2) access to copy the
|
267 |
+
Corresponding Source from a network server at no charge.
|
268 |
+
|
269 |
+
c) Convey individual copies of the object code with a copy of the
|
270 |
+
written offer to provide the Corresponding Source. This
|
271 |
+
alternative is allowed only occasionally and noncommercially, and
|
272 |
+
only if you received the object code with such an offer, in accord
|
273 |
+
with subsection 6b.
|
274 |
+
|
275 |
+
d) Convey the object code by offering access from a designated
|
276 |
+
place (gratis or for a charge), and offer equivalent access to the
|
277 |
+
Corresponding Source in the same way through the same place at no
|
278 |
+
further charge. You need not require recipients to copy the
|
279 |
+
Corresponding Source along with the object code. If the place to
|
280 |
+
copy the object code is a network server, the Corresponding Source
|
281 |
+
may be on a different server (operated by you or a third party)
|
282 |
+
that supports equivalent copying facilities, provided you maintain
|
283 |
+
clear directions next to the object code saying where to find the
|
284 |
+
Corresponding Source. Regardless of what server hosts the
|
285 |
+
Corresponding Source, you remain obligated to ensure that it is
|
286 |
+
available for as long as needed to satisfy these requirements.
|
287 |
+
|
288 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
289 |
+
you inform other peers where the object code and Corresponding
|
290 |
+
Source of the work are being offered to the general public at no
|
291 |
+
charge under subsection 6d.
|
292 |
+
|
293 |
+
A separable portion of the object code, whose source code is excluded
|
294 |
+
from the Corresponding Source as a System Library, need not be
|
295 |
+
included in conveying the object code work.
|
296 |
+
|
297 |
+
A "User Product" is either (1) a "consumer product", which means any
|
298 |
+
tangible personal property which is normally used for personal, family,
|
299 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
300 |
+
into a dwelling. In determining whether a product is a consumer product,
|
301 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
302 |
+
product received by a particular user, "normally used" refers to a
|
303 |
+
typical or common use of that class of product, regardless of the status
|
304 |
+
of the particular user or of the way in which the particular user
|
305 |
+
actually uses, or expects or is expected to use, the product. A product
|
306 |
+
is a consumer product regardless of whether the product has substantial
|
307 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
308 |
+
the only significant mode of use of the product.
|
309 |
+
|
310 |
+
"Installation Information" for a User Product means any methods,
|
311 |
+
procedures, authorization keys, or other information required to install
|
312 |
+
and execute modified versions of a covered work in that User Product from
|
313 |
+
a modified version of its Corresponding Source. The information must
|
314 |
+
suffice to ensure that the continued functioning of the modified object
|
315 |
+
code is in no case prevented or interfered with solely because
|
316 |
+
modification has been made.
|
317 |
+
|
318 |
+
If you convey an object code work under this section in, or with, or
|
319 |
+
specifically for use in, a User Product, and the conveying occurs as
|
320 |
+
part of a transaction in which the right of possession and use of the
|
321 |
+
User Product is transferred to the recipient in perpetuity or for a
|
322 |
+
fixed term (regardless of how the transaction is characterized), the
|
323 |
+
Corresponding Source conveyed under this section must be accompanied
|
324 |
+
by the Installation Information. But this requirement does not apply
|
325 |
+
if neither you nor any third party retains the ability to install
|
326 |
+
modified object code on the User Product (for example, the work has
|
327 |
+
been installed in ROM).
|
328 |
+
|
329 |
+
The requirement to provide Installation Information does not include a
|
330 |
+
requirement to continue to provide support service, warranty, or updates
|
331 |
+
for a work that has been modified or installed by the recipient, or for
|
332 |
+
the User Product in which it has been modified or installed. Access to a
|
333 |
+
network may be denied when the modification itself materially and
|
334 |
+
adversely affects the operation of the network or violates the rules and
|
335 |
+
protocols for communication across the network.
|
336 |
+
|
337 |
+
Corresponding Source conveyed, and Installation Information provided,
|
338 |
+
in accord with this section must be in a format that is publicly
|
339 |
+
documented (and with an implementation available to the public in
|
340 |
+
source code form), and must require no special password or key for
|
341 |
+
unpacking, reading or copying.
|
342 |
+
|
343 |
+
7. Additional Terms.
|
344 |
+
|
345 |
+
"Additional permissions" are terms that supplement the terms of this
|
346 |
+
License by making exceptions from one or more of its conditions.
|
347 |
+
Additional permissions that are applicable to the entire Program shall
|
348 |
+
be treated as though they were included in this License, to the extent
|
349 |
+
that they are valid under applicable law. If additional permissions
|
350 |
+
apply only to part of the Program, that part may be used separately
|
351 |
+
under those permissions, but the entire Program remains governed by
|
352 |
+
this License without regard to the additional permissions.
|
353 |
+
|
354 |
+
When you convey a copy of a covered work, you may at your option
|
355 |
+
remove any additional permissions from that copy, or from any part of
|
356 |
+
it. (Additional permissions may be written to require their own
|
357 |
+
removal in certain cases when you modify the work.) You may place
|
358 |
+
additional permissions on material, added by you to a covered work,
|
359 |
+
for which you have or can give appropriate copyright permission.
|
360 |
+
|
361 |
+
Notwithstanding any other provision of this License, for material you
|
362 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
363 |
+
that material) supplement the terms of this License with terms:
|
364 |
+
|
365 |
+
a) Disclaiming warranty or limiting liability differently from the
|
366 |
+
terms of sections 15 and 16 of this License; or
|
367 |
+
|
368 |
+
b) Requiring preservation of specified reasonable legal notices or
|
369 |
+
author attributions in that material or in the Appropriate Legal
|
370 |
+
Notices displayed by works containing it; or
|
371 |
+
|
372 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
373 |
+
requiring that modified versions of such material be marked in
|
374 |
+
reasonable ways as different from the original version; or
|
375 |
+
|
376 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
377 |
+
authors of the material; or
|
378 |
+
|
379 |
+
e) Declining to grant rights under trademark law for use of some
|
380 |
+
trade names, trademarks, or service marks; or
|
381 |
+
|
382 |
+
f) Requiring indemnification of licensors and authors of that
|
383 |
+
material by anyone who conveys the material (or modified versions of
|
384 |
+
it) with contractual assumptions of liability to the recipient, for
|
385 |
+
any liability that these contractual assumptions directly impose on
|
386 |
+
those licensors and authors.
|
387 |
+
|
388 |
+
All other non-permissive additional terms are considered "further
|
389 |
+
restrictions" within the meaning of section 10. If the Program as you
|
390 |
+
received it, or any part of it, contains a notice stating that it is
|
391 |
+
governed by this License along with a term that is a further
|
392 |
+
restriction, you may remove that term. If a license document contains
|
393 |
+
a further restriction but permits relicensing or conveying under this
|
394 |
+
License, you may add to a covered work material governed by the terms
|
395 |
+
of that license document, provided that the further restriction does
|
396 |
+
not survive such relicensing or conveying.
|
397 |
+
|
398 |
+
If you add terms to a covered work in accord with this section, you
|
399 |
+
must place, in the relevant source files, a statement of the
|
400 |
+
additional terms that apply to those files, or a notice indicating
|
401 |
+
where to find the applicable terms.
|
402 |
+
|
403 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
404 |
+
form of a separately written license, or stated as exceptions;
|
405 |
+
the above requirements apply either way.
|
406 |
+
|
407 |
+
8. Termination.
|
408 |
+
|
409 |
+
You may not propagate or modify a covered work except as expressly
|
410 |
+
provided under this License. Any attempt otherwise to propagate or
|
411 |
+
modify it is void, and will automatically terminate your rights under
|
412 |
+
this License (including any patent licenses granted under the third
|
413 |
+
paragraph of section 11).
|
414 |
+
|
415 |
+
However, if you cease all violation of this License, then your
|
416 |
+
license from a particular copyright holder is reinstated (a)
|
417 |
+
provisionally, unless and until the copyright holder explicitly and
|
418 |
+
finally terminates your license, and (b) permanently, if the copyright
|
419 |
+
holder fails to notify you of the violation by some reasonable means
|
420 |
+
prior to 60 days after the cessation.
|
421 |
+
|
422 |
+
Moreover, your license from a particular copyright holder is
|
423 |
+
reinstated permanently if the copyright holder notifies you of the
|
424 |
+
violation by some reasonable means, this is the first time you have
|
425 |
+
received notice of violation of this License (for any work) from that
|
426 |
+
copyright holder, and you cure the violation prior to 30 days after
|
427 |
+
your receipt of the notice.
|
428 |
+
|
429 |
+
Termination of your rights under this section does not terminate the
|
430 |
+
licenses of parties who have received copies or rights from you under
|
431 |
+
this License. If your rights have been terminated and not permanently
|
432 |
+
reinstated, you do not qualify to receive new licenses for the same
|
433 |
+
material under section 10.
|
434 |
+
|
435 |
+
9. Acceptance Not Required for Having Copies.
|
436 |
+
|
437 |
+
You are not required to accept this License in order to receive or
|
438 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
439 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
440 |
+
to receive a copy likewise does not require acceptance. However,
|
441 |
+
nothing other than this License grants you permission to propagate or
|
442 |
+
modify any covered work. These actions infringe copyright if you do
|
443 |
+
not accept this License. Therefore, by modifying or propagating a
|
444 |
+
covered work, you indicate your acceptance of this License to do so.
|
445 |
+
|
446 |
+
10. Automatic Licensing of Downstream Recipients.
|
447 |
+
|
448 |
+
Each time you convey a covered work, the recipient automatically
|
449 |
+
receives a license from the original licensors, to run, modify and
|
450 |
+
propagate that work, subject to this License. You are not responsible
|
451 |
+
for enforcing compliance by third parties with this License.
|
452 |
+
|
453 |
+
An "entity transaction" is a transaction transferring control of an
|
454 |
+
organization, or substantially all assets of one, or subdividing an
|
455 |
+
organization, or merging organizations. If propagation of a covered
|
456 |
+
work results from an entity transaction, each party to that
|
457 |
+
transaction who receives a copy of the work also receives whatever
|
458 |
+
licenses to the work the party's predecessor in interest had or could
|
459 |
+
give under the previous paragraph, plus a right to possession of the
|
460 |
+
Corresponding Source of the work from the predecessor in interest, if
|
461 |
+
the predecessor has it or can get it with reasonable efforts.
|
462 |
+
|
463 |
+
You may not impose any further restrictions on the exercise of the
|
464 |
+
rights granted or affirmed under this License. For example, you may
|
465 |
+
not impose a license fee, royalty, or other charge for exercise of
|
466 |
+
rights granted under this License, and you may not initiate litigation
|
467 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
468 |
+
any patent claim is infringed by making, using, selling, offering for
|
469 |
+
sale, or importing the Program or any portion of it.
|
470 |
+
|
471 |
+
11. Patents.
|
472 |
+
|
473 |
+
A "contributor" is a copyright holder who authorizes use under this
|
474 |
+
License of the Program or a work on which the Program is based. The
|
475 |
+
work thus licensed is called the contributor's "contributor version".
|
476 |
+
|
477 |
+
A contributor's "essential patent claims" are all patent claims
|
478 |
+
owned or controlled by the contributor, whether already acquired or
|
479 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
480 |
+
by this License, of making, using, or selling its contributor version,
|
481 |
+
but do not include claims that would be infringed only as a
|
482 |
+
consequence of further modification of the contributor version. For
|
483 |
+
purposes of this definition, "control" includes the right to grant
|
484 |
+
patent sublicenses in a manner consistent with the requirements of
|
485 |
+
this License.
|
486 |
+
|
487 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
488 |
+
patent license under the contributor's essential patent claims, to
|
489 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
490 |
+
propagate the contents of its contributor version.
|
491 |
+
|
492 |
+
In the following three paragraphs, a "patent license" is any express
|
493 |
+
agreement or commitment, however denominated, not to enforce a patent
|
494 |
+
(such as an express permission to practice a patent or covenant not to
|
495 |
+
sue for patent infringement). To "grant" such a patent license to a
|
496 |
+
party means to make such an agreement or commitment not to enforce a
|
497 |
+
patent against the party.
|
498 |
+
|
499 |
+
If you convey a covered work, knowingly relying on a patent license,
|
500 |
+
and the Corresponding Source of the work is not available for anyone
|
501 |
+
to copy, free of charge and under the terms of this License, through a
|
502 |
+
publicly available network server or other readily accessible means,
|
503 |
+
then you must either (1) cause the Corresponding Source to be so
|
504 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
505 |
+
patent license for this particular work, or (3) arrange, in a manner
|
506 |
+
consistent with the requirements of this License, to extend the patent
|
507 |
+
license to downstream recipients. "Knowingly relying" means you have
|
508 |
+
actual knowledge that, but for the patent license, your conveying the
|
509 |
+
covered work in a country, or your recipient's use of the covered work
|
510 |
+
in a country, would infringe one or more identifiable patents in that
|
511 |
+
country that you have reason to believe are valid.
|
512 |
+
|
513 |
+
If, pursuant to or in connection with a single transaction or
|
514 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
515 |
+
covered work, and grant a patent license to some of the parties
|
516 |
+
receiving the covered work authorizing them to use, propagate, modify
|
517 |
+
or convey a specific copy of the covered work, then the patent license
|
518 |
+
you grant is automatically extended to all recipients of the covered
|
519 |
+
work and works based on it.
|
520 |
+
|
521 |
+
A patent license is "discriminatory" if it does not include within
|
522 |
+
the scope of its coverage, prohibits the exercise of, or is
|
523 |
+
conditioned on the non-exercise of one or more of the rights that are
|
524 |
+
specifically granted under this License. You may not convey a covered
|
525 |
+
work if you are a party to an arrangement with a third party that is
|
526 |
+
in the business of distributing software, under which you make payment
|
527 |
+
to the third party based on the extent of your activity of conveying
|
528 |
+
the work, and under which the third party grants, to any of the
|
529 |
+
parties who would receive the covered work from you, a discriminatory
|
530 |
+
patent license (a) in connection with copies of the covered work
|
531 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
532 |
+
for and in connection with specific products or compilations that
|
533 |
+
contain the covered work, unless you entered into that arrangement,
|
534 |
+
or that patent license was granted, prior to 28 March 2007.
|
535 |
+
|
536 |
+
Nothing in this License shall be construed as excluding or limiting
|
537 |
+
any implied license or other defenses to infringement that may
|
538 |
+
otherwise be available to you under applicable patent law.
|
539 |
+
|
540 |
+
12. No Surrender of Others' Freedom.
|
541 |
+
|
542 |
+
If conditions are imposed on you (whether by court order, agreement or
|
543 |
+
otherwise) that contradict the conditions of this License, they do not
|
544 |
+
excuse you from the conditions of this License. If you cannot convey a
|
545 |
+
covered work so as to satisfy simultaneously your obligations under this
|
546 |
+
License and any other pertinent obligations, then as a consequence you may
|
547 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
548 |
+
to collect a royalty for further conveying from those to whom you convey
|
549 |
+
the Program, the only way you could satisfy both those terms and this
|
550 |
+
License would be to refrain entirely from conveying the Program.
|
551 |
+
|
552 |
+
13. Use with the GNU Affero General Public License.
|
553 |
+
|
554 |
+
Notwithstanding any other provision of this License, you have
|
555 |
+
permission to link or combine any covered work with a work licensed
|
556 |
+
under version 3 of the GNU Affero General Public License into a single
|
557 |
+
combined work, and to convey the resulting work. The terms of this
|
558 |
+
License will continue to apply to the part which is the covered work,
|
559 |
+
but the special requirements of the GNU Affero General Public License,
|
560 |
+
section 13, concerning interaction through a network will apply to the
|
561 |
+
combination as such.
|
562 |
+
|
563 |
+
14. Revised Versions of this License.
|
564 |
+
|
565 |
+
The Free Software Foundation may publish revised and/or new versions of
|
566 |
+
the GNU General Public License from time to time. Such new versions will
|
567 |
+
be similar in spirit to the present version, but may differ in detail to
|
568 |
+
address new problems or concerns.
|
569 |
+
|
570 |
+
Each version is given a distinguishing version number. If the
|
571 |
+
Program specifies that a certain numbered version of the GNU General
|
572 |
+
Public License "or any later version" applies to it, you have the
|
573 |
+
option of following the terms and conditions either of that numbered
|
574 |
+
version or of any later version published by the Free Software
|
575 |
+
Foundation. If the Program does not specify a version number of the
|
576 |
+
GNU General Public License, you may choose any version ever published
|
577 |
+
by the Free Software Foundation.
|
578 |
+
|
579 |
+
If the Program specifies that a proxy can decide which future
|
580 |
+
versions of the GNU General Public License can be used, that proxy's
|
581 |
+
public statement of acceptance of a version permanently authorizes you
|
582 |
+
to choose that version for the Program.
|
583 |
+
|
584 |
+
Later license versions may give you additional or different
|
585 |
+
permissions. However, no additional obligations are imposed on any
|
586 |
+
author or copyright holder as a result of your choosing to follow a
|
587 |
+
later version.
|
588 |
+
|
589 |
+
15. Disclaimer of Warranty.
|
590 |
+
|
591 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
592 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
593 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
594 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
595 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
596 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
597 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
598 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
599 |
+
|
600 |
+
16. Limitation of Liability.
|
601 |
+
|
602 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
603 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
604 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
605 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
606 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
607 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
608 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
609 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
610 |
+
SUCH DAMAGES.
|
611 |
+
|
612 |
+
17. Interpretation of Sections 15 and 16.
|
613 |
+
|
614 |
+
If the disclaimer of warranty and limitation of liability provided
|
615 |
+
above cannot be given local legal effect according to their terms,
|
616 |
+
reviewing courts shall apply local law that most closely approximates
|
617 |
+
an absolute waiver of all civil liability in connection with the
|
618 |
+
Program, unless a warranty or assumption of liability accompanies a
|
619 |
+
copy of the Program in return for a fee.
|
620 |
+
|
621 |
+
END OF TERMS AND CONDITIONS
|
622 |
+
|
623 |
+
How to Apply These Terms to Your New Programs
|
624 |
+
|
625 |
+
If you develop a new program, and you want it to be of the greatest
|
626 |
+
possible use to the public, the best way to achieve this is to make it
|
627 |
+
free software which everyone can redistribute and change under these terms.
|
628 |
+
|
629 |
+
To do so, attach the following notices to the program. It is safest
|
630 |
+
to attach them to the start of each source file to most effectively
|
631 |
+
state the exclusion of warranty; and each file should have at least
|
632 |
+
the "copyright" line and a pointer to where the full notice is found.
|
633 |
+
|
634 |
+
<one line to give the program's name and a brief idea of what it does.>
|
635 |
+
Copyright (C) <year> <name of author>
|
636 |
+
|
637 |
+
This program is free software: you can redistribute it and/or modify
|
638 |
+
it under the terms of the GNU General Public License as published by
|
639 |
+
the Free Software Foundation, either version 3 of the License, or
|
640 |
+
(at your option) any later version.
|
641 |
+
|
642 |
+
This program is distributed in the hope that it will be useful,
|
643 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
644 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
645 |
+
GNU General Public License for more details.
|
646 |
+
|
647 |
+
You should have received a copy of the GNU General Public License
|
648 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
649 |
+
|
650 |
+
Also add information on how to contact you by electronic and paper mail.
|
651 |
+
|
652 |
+
If the program does terminal interaction, make it output a short
|
653 |
+
notice like this when it starts in an interactive mode:
|
654 |
+
|
655 |
+
<program> Copyright (C) <year> <name of author>
|
656 |
+
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
657 |
+
This is free software, and you are welcome to redistribute it
|
658 |
+
under certain conditions; type `show c' for details.
|
659 |
+
|
660 |
+
The hypothetical commands `show w' and `show c' should show the appropriate
|
661 |
+
parts of the General Public License. Of course, your program's commands
|
662 |
+
might be different; for a GUI interface, you would use an "about box".
|
663 |
+
|
664 |
+
You should also get your employer (if you work as a programmer) or school,
|
665 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
666 |
+
For more information on this, and how to apply and follow the GNU GPL, see
|
667 |
+
<https://www.gnu.org/licenses/>.
|
668 |
+
|
669 |
+
The GNU General Public License does not permit incorporating your program
|
670 |
+
into proprietary programs. If your program is a subroutine library, you
|
671 |
+
may consider it more useful to permit linking proprietary applications with
|
672 |
+
the library. If this is what you want to do, use the GNU Lesser General
|
673 |
+
Public License instead of this License. But first, please read
|
674 |
+
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
ComfyUI-Advanced-ControlNet/README.md
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ComfyUI-Advanced-ControlNet
|
2 |
+
Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks. The ControlNet nodes here fully support sliding context sampling, like the one used in the [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) nodes. Currently supports ControlNets, T2IAdapters, ControlLoRAs, ControlLLLite, SparseCtrls, SVD-ControlNets, and Reference.
|
3 |
+
|
4 |
+
Custom weights allow replication of the "My prompt is more important" feature of Auto1111's sd-webui ControlNet extension via Soft Weights, and the "ControlNet is more important" feature can be granularly controlled by changing the uncond_multiplier on the same Soft Weights.
|
5 |
+
|
6 |
+
ControlNet preprocessors are available through [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux) nodes.
|
7 |
+
|
8 |
+
## Features
|
9 |
+
- Timestep and latent strength scheduling
|
10 |
+
- Attention masks
|
11 |
+
- Replicate ***"My prompt is more important"*** feature from sd-webui-controlnet extension via ***Soft Weights***, and allow softness to be tweaked via ***base_multiplier***
|
12 |
+
- Replicate ***"ControlNet is more important"*** feature from sd-webui-controlnet extension via ***uncond_multiplier*** on ***Soft Weights***
|
13 |
+
- uncond_multiplier=0.0 gives identical results of auto1111's feature, but values between 0.0 and 1.0 can be used without issue to granularly control the setting.
|
14 |
+
- ControlNet, T2IAdapter, and ControlLoRA support for sliding context windows
|
15 |
+
- ControlLLLite support
|
16 |
+
- SparseCtrl support
|
17 |
+
- SVD-ControlNet support
|
18 |
+
- Stable Video Diffusion ControlNets trained by **CiaraRowles**: [Depth](https://huggingface.co/CiaraRowles/temporal-controlnet-depth-svd-v1/tree/main/controlnet), [Lineart](https://huggingface.co/CiaraRowles/temporal-controlnet-lineart-svd-v1/tree/main/controlnet)
|
19 |
+
- Reference support
|
20 |
+
- Supports ```reference_attn```, ```reference_adain```, and ```refrence_adain+attn``` modes. ```style_fidelity``` and ```ref_weight``` are equivalent to style_fidelity and control_weight in Auto1111, respectively, and strength of the Apply ControlNet is the balance between ref-influenced result and no-ref result. There is also a Reference ControlNet (Finetune) node that allows adjust the style_fidelity, weight, and strength of attn and adain separately.
|
21 |
+
|
22 |
+
## Table of Contents:
|
23 |
+
- [Scheduling Explanation](#scheduling-explanation)
|
24 |
+
- [Nodes](#nodes)
|
25 |
+
- [Usage](#usage) (will fill this out soon)
|
26 |
+
|
27 |
+
|
28 |
+
# Scheduling Explanation
|
29 |
+
|
30 |
+
The two core concepts for scheduling are ***Timestep Keyframes*** and ***Latent Keyframes***.
|
31 |
+
|
32 |
+
***Timestep Keyframes*** hold the values that guide the settings for a controlnet, and begin to take effect based on their start_percent, which corresponds to the percentage of the sampling process. They can contain masks for the strengths of each latent, control_net_weights, and latent_keyframes (specific strengths for each latent), all optional.
|
33 |
+
|
34 |
+
***Latent Keyframes*** determine the strength of the controlnet for specific latents - all they contain is the batch_index of the latent, and the strength the controlnet should apply for that latent. As a concept, latent keyframes achieve the same affect as a uniform mask with the chosen strength value.
|
35 |
+
|
36 |
+
![advcn_image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/e6275264-6c3f-4246-a319-111ee48f4cd9)
|
37 |
+
|
38 |
+
# Nodes
|
39 |
+
|
40 |
+
The ControlNet nodes provided here are the ***Apply Advanced ControlNet*** and ***Load Advanced ControlNet Model*** (or diff) nodes. The vanilla ControlNet nodes are also compatible, and can be used almost interchangeably - the only difference is that **at least one of these nodes must be used** for Advanced versions of ControlNets to be used (important for sliding context sampling, like with AnimateDiff-Evolved).
|
41 |
+
|
42 |
+
Key:
|
43 |
+
- 🟩 - required inputs
|
44 |
+
- 🟨 - optional inputs
|
45 |
+
- 🟦 - start as widgets, can be converted to inputs
|
46 |
+
- 🟥 - optional input/output, but not recommended to use unless needed
|
47 |
+
- 🟪 - output
|
48 |
+
|
49 |
+
## Apply Advanced ControlNet
|
50 |
+
![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/dc541d41-70df-4a71-b832-efa65af98f06)
|
51 |
+
|
52 |
+
Same functionality as the vanilla Apply Advanced ControlNet (Advanced) node, except with Advanced ControlNet features added to it. Automatically converts any ControlNet from ControlNet loaders into Advanced versions.
|
53 |
+
|
54 |
+
### Inputs
|
55 |
+
- 🟩***positive***: conditioning (positive).
|
56 |
+
- 🟩***negative***: conditioning (negative).
|
57 |
+
- 🟩***control_net***: loaded controlnet; will be converted to Advanced version automatically by this node, if it's a supported type.
|
58 |
+
- 🟩***image***: images to guide controlnets - if the loaded controlnet requires it, they must preprocessed images. If one image provided, will be used for all latents. If more images provided, will use each image separately for each latent. If not enough images to meet latent count, will repeat the images from the beginning to match vanilla ControlNet functionality.
|
59 |
+
- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as image input, if you provide more than one mask, each can apply to a different latent.
|
60 |
+
- 🟨***timestep_kf***: timestep keyframes to guide controlnet effect throughout sampling steps.
|
61 |
+
- 🟨***latent_kf_override***: override for latent keyframes, useful if no other features from timestep keyframes is needed. *NOTE: this latent keyframe will be applied to ALL timesteps, regardless if there are other latent keyframes attached to connected timestep keyframes.*
|
62 |
+
- 🟨***weights_override***: override for weights, useful if no other features from timestep keyframes is needed. *NOTE: this weight will be applied to ALL timesteps, regardless if there are other weights attached to connected timestep keyframes.*
|
63 |
+
- 🟦***strength***: strength of controlnet; 1.0 is full strength, 0.0 is no effect at all.
|
64 |
+
- 🟦***start_percent***: sampling step percentage at which controlnet should start to be applied - no matter what start_percent is set on timestep keyframes, they won't take effect until this start_percent is reached.
|
65 |
+
- 🟦***stop_percent***: sampling step percentage at which controlnet should stop being applied - no matter what start_percent is set on timestep keyframes, they won't take effect once this end_percent is reached.
|
66 |
+
|
67 |
+
### Outputs
|
68 |
+
- 🟪***positive***: conditioning (positive) with applied controlnets
|
69 |
+
- 🟪***negative***: conditioning (negative) with applied controlnets
|
70 |
+
|
71 |
+
## Load Advanced ControlNet Model
|
72 |
+
![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/4a7f58a9-783d-4da4-bf82-bc9c167e4722)
|
73 |
+
|
74 |
+
Loads a ControlNet model and converts it into an Advanced version that supports all the features in this repo. When used with **Apply Advanced ControlNet** node, there is no reason to use the timestep_keyframe input on this node - use timestep_kf on the Apply node instead.
|
75 |
+
|
76 |
+
### Inputs
|
77 |
+
- 🟥***timestep_keyframe***: optional and likely unnecessary input to have ControlNet use selected timestep_keyframes - should not be used unless you need to. Useful if this node is not attached to **Apply Advanced ControlNet** node, but still want to use Timestep Keyframe, or to use TK_SHORTCUT outputs from ControlWeights in the same scenario. Will be overriden by the timestep_kf input on **Apply Advanced ControlNet** node, if one is provided there.
|
78 |
+
- 🟨***model***: model to plug into the diff version of the node. Some controlnets are designed for receive the model; if you don't know what this does, you probably don't want tot use the diff version of the node.
|
79 |
+
|
80 |
+
### Outputs
|
81 |
+
- 🟪***CONTROL_NET***: loaded Advanced ControlNet
|
82 |
+
|
83 |
+
## Timestep Keyframe
|
84 |
+
![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/404f3cfe-5852-4eed-935b-37e32493d1b5)
|
85 |
+
|
86 |
+
Scheduling node across timesteps (sampling steps) based on the set start_percent. Chaining Timestep Keyframes allows ControlNet scheduling across sampling steps (percentage-wise), through a timestep keyframe schedule.
|
87 |
+
|
88 |
+
### Inputs
|
89 |
+
- 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
|
90 |
+
- 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
|
91 |
+
- 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
|
92 |
+
- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
|
93 |
+
- 🟦***start_percent***: sampling step percentage at which this Timestep Keyframe qualifies to be used. Acts as the 'key' for the Timestep Keyframe in the timestep keyframe schedule.
|
94 |
+
- 🟦***strength***: strength of the controlnet; multiplies the controlnet by this value, basically, applied alongside the strength on the Apply ControlNet node. If set to 0.0 will not have any effect during the duration of this Timestep Keyframe's effect, and will increase sampling speed by not doing any work.
|
95 |
+
- 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
|
96 |
+
- 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
|
97 |
+
- 🟦***guarantee_steps***: when 1 or greater, even if a Timestep Keyframe's start_percent ahead of this one in the schedule is closer to current sampling percentage, this Timestep Keyframe will still be used for the specified amount of steps before moving on to the next selected Timestep Keyframe in the following step. Whether the Timestep Keyframe is used or not, its inputs will still be accounted for inherit_missing purposes.
|
98 |
+
|
99 |
+
### Outputs
|
100 |
+
- 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
|
101 |
+
|
102 |
+
## Timestep Keyframe Interpolation
|
103 |
+
![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/9789617c-202c-4271-92a2-0909bcf9b108)
|
104 |
+
|
105 |
+
Allows to create Timestep Keyframe with interpolated strength values in a given percent range. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
|
106 |
+
|
107 |
+
### Inputs
|
108 |
+
- 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
|
109 |
+
- 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
|
110 |
+
- 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
|
111 |
+
- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
|
112 |
+
- 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
|
113 |
+
- 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
|
114 |
+
- 🟦***strength_start***: strength of the Timestep Keyframe at start of range.
|
115 |
+
- 🟦***strength_end***: strength of the Timestep Keyframe at end of range.
|
116 |
+
- 🟦***interpolation***: the method of interpolation.
|
117 |
+
- 🟦***intervals***: the amount of keyframes to generate in total - the first will have its start_percent equal to start_percent, the last will have its start_percent equal to end_percent.
|
118 |
+
- 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
|
119 |
+
- 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
|
120 |
+
- 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
|
121 |
+
|
122 |
+
### Outputs
|
123 |
+
- 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
|
124 |
+
|
125 |
+
## Timestep Keyframe From List
|
126 |
+
![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/9e9c23bf-6f82-4ce7-b4d1-3016fd14707d)
|
127 |
+
|
128 |
+
Allows to create Timestep Keyframe via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
|
129 |
+
|
130 |
+
### Inputs
|
131 |
+
- 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
|
132 |
+
- 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
|
133 |
+
- 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
|
134 |
+
- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
|
135 |
+
- 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Timestep Keyframe; first will be assigned to start_percent, last will be assigned to end_percent, and the rest spread linearly between.
|
136 |
+
- 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
|
137 |
+
- 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
|
138 |
+
- 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
|
139 |
+
- 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
|
140 |
+
- 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
|
141 |
+
|
142 |
+
### Outputs
|
143 |
+
- 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
|
144 |
+
|
145 |
+
## Latent Keyframe
|
146 |
+
![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/7eb2cc4c-255c-4f32-b09b-699f713fada3)
|
147 |
+
|
148 |
+
A singular Latent Keyframe, selects the strength for a specific batch_index. If batch_index is not present during sampling, will simply have no effect. Can be chained with any other Latent Keyframe-type node to create a latent keyframe schedule.
|
149 |
+
|
150 |
+
### Inputs
|
151 |
+
- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If a Latent Keyframe contained in prev_latent_keyframes have the same batch_index as this Latent Keyframe, they will take priority over this node's value.*
|
152 |
+
- 🟦***batch_index***: index of latent in batch to apply controlnet strength to. Acts as the 'key' for the Latent Keyframe in the latent keyframe schedule.
|
153 |
+
- 🟦***strength***: strength of controlnet to apply to the corresponding latent.
|
154 |
+
|
155 |
+
### Outputs
|
156 |
+
- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
|
157 |
+
|
158 |
+
## Latent Keyframe Group
|
159 |
+
![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/5ce3b795-f5fc-4dc3-ae30-a4c7f87e278c)
|
160 |
+
|
161 |
+
Allows to create Latent Keyframes via individual indeces or python-style ranges.
|
162 |
+
|
163 |
+
### Inputs
|
164 |
+
- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
|
165 |
+
- 🟨***latent_optional***: the latents expected to be passed in for sampling; only required if you wish to use negative indeces (will be automatically converted to real values).
|
166 |
+
- 🟦***index_strengths***: string list of indeces or python-style ranges of indeces to assign strengths to. If latent_optional is passed in, can contain negative indeces or ranges that contain negative numbers, python-style. The different indeces must be comma separated. Individual latents can be specified by ```batch_index=strength```, like ```0=0.9```. Ranges can be specified by ```start_index_inclusive:end_index_exclusive=strength```, like ```0:8=strength```. Negative indeces are possible when latents_optional has an input, with a string such as ```0,-4=0.25```.
|
167 |
+
- 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
|
168 |
+
|
169 |
+
### Outputs
|
170 |
+
- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
|
171 |
+
|
172 |
+
## Latent Keyframe Interpolation
|
173 |
+
![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/7986c737-83b9-46bc-aab0-ae4c368df446)
|
174 |
+
|
175 |
+
Allows to create Latent Keyframes with interpolated values in a range.
|
176 |
+
|
177 |
+
### Inputs
|
178 |
+
- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
|
179 |
+
- 🟦***batch_index_from***: starting batch_index of range, included.
|
180 |
+
- 🟦***batch_index_to***: end batch_index of range, excluded (python-style range).
|
181 |
+
- 🟦***strength_from***: starting strength of interpolation.
|
182 |
+
- 🟦***strength_to***: end strength of interpolation.
|
183 |
+
- 🟦***interpolation***: the method of interpolation.
|
184 |
+
- 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
|
185 |
+
|
186 |
+
### Outputs
|
187 |
+
- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
|
188 |
+
|
189 |
+
## Latent Keyframe From List
|
190 |
+
![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/6cec701f-6183-4aeb-af5c-cac76f5591b7)
|
191 |
+
|
192 |
+
Allows to create Latent Keyframes via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes.
|
193 |
+
|
194 |
+
### Inputs
|
195 |
+
- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
|
196 |
+
- 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Latent Keyframe; the batch_index is the index of each float value in the list.
|
197 |
+
- 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
|
198 |
+
|
199 |
+
### Outputs
|
200 |
+
- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
|
201 |
+
|
202 |
+
# There are more nodes to document and show usage - will add this soon! TODO
|
ComfyUI-Advanced-ControlNet/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .adv_control.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
2 |
+
from .adv_control import documentation
|
3 |
+
|
4 |
+
WEB_DIRECTORY = "./web"
|
5 |
+
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', "WEB_DIRECTORY"]
|
6 |
+
documentation.format_descriptions(NODE_CLASS_MAPPINGS)
|
ComfyUI-Advanced-ControlNet/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (501 Bytes). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control.cpython-312.pyc
ADDED
Binary file (52.6 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_lllite.cpython-312.pyc
ADDED
Binary file (25 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_plusplus.cpython-312.pyc
ADDED
Binary file (31.4 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_reference.cpython-312.pyc
ADDED
Binary file (59.4 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_sparsectrl.cpython-312.pyc
ADDED
Binary file (52.5 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/control_svd.cpython-312.pyc
ADDED
Binary file (22 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/documentation.cpython-312.pyc
ADDED
Binary file (2.61 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/logger.cpython-312.pyc
ADDED
Binary file (1.79 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes.cpython-312.pyc
ADDED
Binary file (13.8 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_deprecated.cpython-312.pyc
ADDED
Binary file (10.8 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_keyframes.cpython-312.pyc
ADDED
Binary file (19 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_loosecontrol.cpython-312.pyc
ADDED
Binary file (3.26 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_plusplus.cpython-312.pyc
ADDED
Binary file (3.86 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_reference.cpython-312.pyc
ADDED
Binary file (4.63 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_sparsectrl.cpython-312.pyc
ADDED
Binary file (9.45 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/nodes_weight.cpython-312.pyc
ADDED
Binary file (12.3 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/sampling.cpython-312.pyc
ADDED
Binary file (12.5 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (54.1 kB). View file
|
|
ComfyUI-Advanced-ControlNet/adv_control/control.py
ADDED
@@ -0,0 +1,918 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Union
|
2 |
+
from torch import Tensor
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
|
6 |
+
import comfy.ops
|
7 |
+
import comfy.utils
|
8 |
+
import comfy.model_management
|
9 |
+
import comfy.model_detection
|
10 |
+
import comfy.controlnet as comfy_cn
|
11 |
+
from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, StrengthType
|
12 |
+
from comfy.model_patcher import ModelPatcher
|
13 |
+
|
14 |
+
from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseSettings, SparseConst
|
15 |
+
from .control_lllite import LLLiteModule, LLLitePatch, load_controllllite
|
16 |
+
from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
|
17 |
+
from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, AbstractPreprocWrapper, ControlWeightType, ControlWeights, WeightTypeException,
|
18 |
+
manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory,
|
19 |
+
broadcast_image_to_extend, extend_to_batch_size, ORIG_PREVIOUS_CONTROLNET, CONTROL_INIT_BY_ACN)
|
20 |
+
from .logger import logger
|
21 |
+
|
22 |
+
|
23 |
+
class ControlNetAdvanced(ControlNet, AdvancedControlBase):
|
24 |
+
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
|
25 |
+
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
26 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
|
27 |
+
self.is_flux = False
|
28 |
+
self.x_noisy_shape = None
|
29 |
+
|
30 |
+
def get_universal_weights(self) -> ControlWeights:
|
31 |
+
def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
|
32 |
+
if key == "middle":
|
33 |
+
return 1.0
|
34 |
+
c_len = len(control[key])
|
35 |
+
raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
|
36 |
+
raw_weights = raw_weights[:-1]
|
37 |
+
if key == "input":
|
38 |
+
raw_weights.reverse()
|
39 |
+
return raw_weights[idx]
|
40 |
+
return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)
|
41 |
+
|
42 |
+
def get_control_advanced(self, x_noisy, t, cond, batched_number):
|
43 |
+
# perform special version of get_control that supports sliding context and masks
|
44 |
+
return self.sliding_get_control(x_noisy, t, cond, batched_number)
|
45 |
+
|
46 |
+
def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
|
47 |
+
control_prev = None
|
48 |
+
if self.previous_controlnet is not None:
|
49 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
50 |
+
|
51 |
+
if self.timestep_range is not None:
|
52 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
53 |
+
if control_prev is not None:
|
54 |
+
return control_prev
|
55 |
+
else:
|
56 |
+
return None
|
57 |
+
|
58 |
+
dtype = self.control_model.dtype
|
59 |
+
if self.manual_cast_dtype is not None:
|
60 |
+
dtype = self.manual_cast_dtype
|
61 |
+
|
62 |
+
# make cond_hint appropriate dimensions
|
63 |
+
# TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
|
64 |
+
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
65 |
+
if self.cond_hint is not None:
|
66 |
+
del self.cond_hint
|
67 |
+
self.cond_hint = None
|
68 |
+
compression_ratio = self.compression_ratio
|
69 |
+
if self.vae is not None:
|
70 |
+
compression_ratio *= self.vae.downscale_ratio
|
71 |
+
# if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
|
72 |
+
if self.sub_idxs is not None:
|
73 |
+
actual_cond_hint_orig = self.cond_hint_original
|
74 |
+
if self.cond_hint_original.size(0) < self.full_latent_length:
|
75 |
+
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
|
76 |
+
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
77 |
+
else:
|
78 |
+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
79 |
+
if self.vae is not None:
|
80 |
+
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
81 |
+
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
82 |
+
comfy.model_management.load_models_gpu(loaded_models)
|
83 |
+
if self.latent_format is not None:
|
84 |
+
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
85 |
+
self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype)
|
86 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
87 |
+
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
|
88 |
+
|
89 |
+
# prepare mask_cond_hint
|
90 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
|
91 |
+
|
92 |
+
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
93 |
+
extra = self.extra_args.copy()
|
94 |
+
for c in self.extra_conds:
|
95 |
+
temp = cond.get(c, None)
|
96 |
+
if temp is not None:
|
97 |
+
extra[c] = temp.to(dtype)
|
98 |
+
|
99 |
+
timestep = self.model_sampling_current.timestep(t)
|
100 |
+
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
101 |
+
self.x_noisy_shape = x_noisy.shape
|
102 |
+
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
103 |
+
return self.control_merge(control, control_prev, output_dtype=None)
|
104 |
+
|
105 |
+
def pre_run_advanced(self, *args, **kwargs):
|
106 |
+
self.is_flux = "Flux" in str(type(self.control_model).__name__)
|
107 |
+
return super().pre_run_advanced(*args, **kwargs)
|
108 |
+
|
109 |
+
def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, flux_shape=None):
|
110 |
+
if self.is_flux:
|
111 |
+
flux_shape = self.x_noisy_shape
|
112 |
+
return super().apply_advanced_strengths_and_masks(x, batched_number, flux_shape)
|
113 |
+
|
114 |
+
def copy(self):
|
115 |
+
c = ControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
116 |
+
c.control_model = self.control_model
|
117 |
+
c.control_model_wrapped = self.control_model_wrapped
|
118 |
+
self.copy_to(c)
|
119 |
+
self.copy_to_advanced(c)
|
120 |
+
return c
|
121 |
+
|
122 |
+
def cleanup_advanced(self):
|
123 |
+
self.x_noisy_shape = None
|
124 |
+
return super().cleanup_advanced()
|
125 |
+
|
126 |
+
@staticmethod
|
127 |
+
def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
|
128 |
+
to_return = ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
|
129 |
+
global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, load_device=v.load_device,
|
130 |
+
manual_cast_dtype=v.manual_cast_dtype)
|
131 |
+
v.copy_to(to_return)
|
132 |
+
return to_return
|
133 |
+
|
134 |
+
|
135 |
+
class T2IAdapterAdvanced(T2IAdapter, AdvancedControlBase):
|
136 |
+
def __init__(self, t2i_model, timestep_keyframes: TimestepKeyframeGroup, channels_in, compression_ratio=8, upscale_algorithm="nearest_exact", device=None):
|
137 |
+
super().__init__(t2i_model=t2i_model, channels_in=channels_in, compression_ratio=compression_ratio, upscale_algorithm=upscale_algorithm, device=device)
|
138 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.t2iadapter())
|
139 |
+
|
140 |
+
def control_merge_inject(self, control: dict[str, list[Tensor]], control_prev, output_dtype):
|
141 |
+
# match batch_size
|
142 |
+
# TODO: make this more efficient by modifying the cached self.control_input val instead of doing this every step
|
143 |
+
for key in control:
|
144 |
+
control_current = control[key]
|
145 |
+
for i in range(len(control_current)):
|
146 |
+
x = control_current[i]
|
147 |
+
if x is not None and x.size(0) == 1 and x.size(0) != self.batch_size:
|
148 |
+
control_current[i] = x.repeat(self.batch_size, 1, 1, 1)[:self.batch_size]
|
149 |
+
return AdvancedControlBase.control_merge_inject(self, control, control_prev, output_dtype)
|
150 |
+
|
151 |
+
def get_universal_weights(self) -> ControlWeights:
|
152 |
+
def t2i_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
|
153 |
+
if key == "middle":
|
154 |
+
return 1.0
|
155 |
+
c_len = 8 #len(control[key])
|
156 |
+
raw_weights = [(self.weights.base_multiplier ** float((c_len-1) - i)) for i in range(c_len)]
|
157 |
+
raw_weights = [raw_weights[-c_len], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
|
158 |
+
raw_weights = get_properly_arranged_t2i_weights(raw_weights)
|
159 |
+
if key == "input":
|
160 |
+
raw_weights.reverse()
|
161 |
+
return raw_weights[idx]
|
162 |
+
return self.weights.copy_with_new_weights(new_weight_func=t2i_weights_func)
|
163 |
+
|
164 |
+
def get_calc_pow(self, idx: int, control: dict[str, list[Tensor]], key: str) -> int:
|
165 |
+
if key == "middle":
|
166 |
+
return 0
|
167 |
+
# match how T2IAdapterAdvanced deals with universal weights
|
168 |
+
c_len = 8 #len(control[key])
|
169 |
+
indeces = [(c_len-1) - i for i in range(c_len)]
|
170 |
+
indeces = [indeces[-c_len], indeces[-3], indeces[-2], indeces[-1]]
|
171 |
+
indeces = get_properly_arranged_t2i_weights(indeces)
|
172 |
+
if key == "input":
|
173 |
+
indeces.reverse() # need to reverse to match recent ComfyUI changes
|
174 |
+
return indeces[idx]
|
175 |
+
|
176 |
+
def get_control_advanced(self, x_noisy, t, cond, batched_number):
|
177 |
+
try:
|
178 |
+
# if sub indexes present, replace original hint with subsection
|
179 |
+
if self.sub_idxs is not None:
|
180 |
+
# cond hints
|
181 |
+
full_cond_hint_original = self.cond_hint_original
|
182 |
+
actual_cond_hint_orig = full_cond_hint_original
|
183 |
+
del self.cond_hint
|
184 |
+
self.cond_hint = None
|
185 |
+
if full_cond_hint_original.size(0) < self.full_latent_length:
|
186 |
+
actual_cond_hint_orig = extend_to_batch_size(tensor=full_cond_hint_original, batch_size=full_cond_hint_original.size(0))
|
187 |
+
self.cond_hint_original = actual_cond_hint_orig[self.sub_idxs]
|
188 |
+
# mask hints
|
189 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
|
190 |
+
return super().get_control(x_noisy, t, cond, batched_number)
|
191 |
+
finally:
|
192 |
+
if self.sub_idxs is not None:
|
193 |
+
# replace original cond hint
|
194 |
+
self.cond_hint_original = full_cond_hint_original
|
195 |
+
del full_cond_hint_original
|
196 |
+
|
197 |
+
def copy(self):
|
198 |
+
c = T2IAdapterAdvanced(self.t2i_model, self.timestep_keyframes, self.channels_in, self.compression_ratio, self.upscale_algorithm)
|
199 |
+
self.copy_to(c)
|
200 |
+
self.copy_to_advanced(c)
|
201 |
+
return c
|
202 |
+
|
203 |
+
def cleanup(self):
|
204 |
+
super().cleanup()
|
205 |
+
self.cleanup_advanced()
|
206 |
+
|
207 |
+
@staticmethod
|
208 |
+
def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced':
|
209 |
+
to_return = T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
|
210 |
+
compression_ratio=v.compression_ratio, upscale_algorithm=v.upscale_algorithm, device=v.device)
|
211 |
+
v.copy_to(to_return)
|
212 |
+
return to_return
|
213 |
+
|
214 |
+
|
215 |
+
class ControlLoraAdvanced(ControlLora, AdvancedControlBase):
|
216 |
+
def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False):
|
217 |
+
super().__init__(control_weights=control_weights, global_average_pooling=global_average_pooling)
|
218 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllora())
|
219 |
+
# use some functions from ControlNetAdvanced
|
220 |
+
self.get_control_advanced = ControlNetAdvanced.get_control_advanced.__get__(self, type(self))
|
221 |
+
self.sliding_get_control = ControlNetAdvanced.sliding_get_control.__get__(self, type(self))
|
222 |
+
|
223 |
+
def get_universal_weights(self) -> ControlWeights:
|
224 |
+
raw_weights = [(self.weights.base_multiplier ** float(9 - i)) for i in range(10)]
|
225 |
+
return self.weights.copy_with_new_weights(raw_weights)
|
226 |
+
|
227 |
+
def copy(self):
|
228 |
+
c = ControlLoraAdvanced(self.control_weights, self.timestep_keyframes, global_average_pooling=self.global_average_pooling)
|
229 |
+
self.copy_to(c)
|
230 |
+
self.copy_to_advanced(c)
|
231 |
+
return c
|
232 |
+
|
233 |
+
def cleanup(self):
|
234 |
+
super().cleanup()
|
235 |
+
self.cleanup_advanced()
|
236 |
+
|
237 |
+
@staticmethod
|
238 |
+
def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
|
239 |
+
to_return = ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
|
240 |
+
global_average_pooling=v.global_average_pooling)
|
241 |
+
v.copy_to(to_return)
|
242 |
+
return to_return
|
243 |
+
|
244 |
+
|
245 |
+
class SVDControlNetAdvanced(ControlNetAdvanced):
|
246 |
+
def __init__(self, control_model: SVDControlNet, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
|
247 |
+
super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
248 |
+
|
249 |
+
def set_cond_hint_inject(self, *args, **kwargs):
|
250 |
+
to_return = super().set_cond_hint_inject(*args, **kwargs)
|
251 |
+
# cond hint for SVD-ControlNet needs to be scaled between (-1, 1) instead of (0, 1)
|
252 |
+
self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
|
253 |
+
return to_return
|
254 |
+
|
255 |
+
def get_control_advanced(self, x_noisy, t, cond, batched_number):
|
256 |
+
control_prev = None
|
257 |
+
if self.previous_controlnet is not None:
|
258 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
259 |
+
|
260 |
+
if self.timestep_range is not None:
|
261 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
262 |
+
if control_prev is not None:
|
263 |
+
return control_prev
|
264 |
+
else:
|
265 |
+
return None
|
266 |
+
|
267 |
+
dtype = self.control_model.dtype
|
268 |
+
if self.manual_cast_dtype is not None:
|
269 |
+
dtype = self.manual_cast_dtype
|
270 |
+
|
271 |
+
output_dtype = x_noisy.dtype
|
272 |
+
# make cond_hint appropriate dimensions
|
273 |
+
# TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
|
274 |
+
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
275 |
+
if self.cond_hint is not None:
|
276 |
+
del self.cond_hint
|
277 |
+
self.cond_hint = None
|
278 |
+
# if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
|
279 |
+
if self.sub_idxs is not None:
|
280 |
+
actual_cond_hint_orig = self.cond_hint_original
|
281 |
+
if self.cond_hint_original.size(0) < self.full_latent_length:
|
282 |
+
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
|
283 |
+
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
|
284 |
+
else:
|
285 |
+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
|
286 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
287 |
+
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
|
288 |
+
|
289 |
+
# prepare mask_cond_hint
|
290 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
|
291 |
+
|
292 |
+
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
293 |
+
# uses 'y' in new ComfyUI update
|
294 |
+
y = cond.get('y', None)
|
295 |
+
if y is not None:
|
296 |
+
y = y.to(dtype)
|
297 |
+
timestep = self.model_sampling_current.timestep(t)
|
298 |
+
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
299 |
+
# concat c_concat if exists (should exist for SVD), doubling channels to 8
|
300 |
+
if cond.get('c_concat', None) is not None:
|
301 |
+
x_noisy = torch.cat([x_noisy] + [cond['c_concat']], dim=1)
|
302 |
+
|
303 |
+
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, cond=cond)
|
304 |
+
return self.control_merge(control, control_prev, output_dtype)
|
305 |
+
|
306 |
+
def copy(self):
|
307 |
+
c = SVDControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
308 |
+
self.copy_to(c)
|
309 |
+
self.copy_to_advanced(c)
|
310 |
+
return c
|
311 |
+
|
312 |
+
|
313 |
+
class SparseCtrlAdvanced(ControlNetAdvanced):
|
314 |
+
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, sparse_settings: SparseSettings=None, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
|
315 |
+
super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
316 |
+
self.control_model_wrapped = SparseModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
317 |
+
self.add_compatible_weight(ControlWeightType.SPARSECTRL)
|
318 |
+
self.control_model: SparseControlNet = self.control_model # does nothing except help with IDE hints
|
319 |
+
if self.control_model.use_simplified_conditioning_embedding:
|
320 |
+
# TODO: allow vae_optional to be used instead of preprocessor
|
321 |
+
#self.require_vae = True
|
322 |
+
self.allow_condhint_latents = True
|
323 |
+
self.sparse_settings = sparse_settings if sparse_settings is not None else SparseSettings.default()
|
324 |
+
self.model_latent_format = None # latent format for active SD model, NOT controlnet
|
325 |
+
self.preprocessed = False
|
326 |
+
|
327 |
+
def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
|
328 |
+
# normal ControlNet stuff
|
329 |
+
control_prev = None
|
330 |
+
if self.previous_controlnet is not None:
|
331 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
332 |
+
|
333 |
+
if self.timestep_range is not None:
|
334 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
335 |
+
if control_prev is not None:
|
336 |
+
return control_prev
|
337 |
+
else:
|
338 |
+
return None
|
339 |
+
|
340 |
+
dtype = self.control_model.dtype
|
341 |
+
if self.manual_cast_dtype is not None:
|
342 |
+
dtype = self.manual_cast_dtype
|
343 |
+
output_dtype = x_noisy.dtype
|
344 |
+
# set actual input length on motion model
|
345 |
+
actual_length = x_noisy.size(0)//batched_number
|
346 |
+
full_length = actual_length if self.sub_idxs is None else self.full_latent_length
|
347 |
+
self.control_model.set_actual_length(actual_length=actual_length, full_length=full_length)
|
348 |
+
# prepare cond_hint, if needed
|
349 |
+
dim_mult = 1 if self.control_model.use_simplified_conditioning_embedding else 8
|
350 |
+
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2]*dim_mult != self.cond_hint.shape[2] or x_noisy.shape[3]*dim_mult != self.cond_hint.shape[3]:
|
351 |
+
# clear out cond_hint and conditioning_mask
|
352 |
+
if self.cond_hint is not None:
|
353 |
+
del self.cond_hint
|
354 |
+
self.cond_hint = None
|
355 |
+
# first, figure out which cond idxs are relevant, and where they fit in
|
356 |
+
cond_idxs, hint_order = self.sparse_settings.sparse_method.get_indexes(hint_length=self.cond_hint_original.size(0), full_length=full_length,
|
357 |
+
sub_idxs=self.sub_idxs if self.sparse_settings.is_context_aware() else None)
|
358 |
+
range_idxs = list(range(full_length)) if self.sub_idxs is None else self.sub_idxs
|
359 |
+
hint_idxs = [] # idxs in cond_idxs
|
360 |
+
local_idxs = [] # idx to put in final cond_hint
|
361 |
+
for i,cond_idx in enumerate(cond_idxs):
|
362 |
+
if cond_idx in range_idxs:
|
363 |
+
hint_idxs.append(i)
|
364 |
+
local_idxs.append(range_idxs.index(cond_idx))
|
365 |
+
# log_string = f"cond_idxs: {cond_idxs}, local_idxs: {local_idxs}, hint_idxs: {hint_idxs}, hint_order: {hint_order}"
|
366 |
+
# if self.sub_idxs is not None:
|
367 |
+
# log_string += f" sub_idxs: {self.sub_idxs[0]}-{self.sub_idxs[-1]}"
|
368 |
+
# logger.warn(log_string)
|
369 |
+
# determine cond/uncond indexes that will get masked
|
370 |
+
self.local_sparse_idxs = []
|
371 |
+
self.local_sparse_idxs_inverse = list(range(x_noisy.size(0)))
|
372 |
+
for batch_idx in range(batched_number):
|
373 |
+
for i in local_idxs:
|
374 |
+
actual_i = i+(batch_idx*actual_length)
|
375 |
+
self.local_sparse_idxs.append(actual_i)
|
376 |
+
if actual_i in self.local_sparse_idxs_inverse:
|
377 |
+
self.local_sparse_idxs_inverse.remove(actual_i)
|
378 |
+
# sub_cond_hint now contains the hints relevant to current x_noisy
|
379 |
+
if hint_order is None:
|
380 |
+
sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(x_noisy.device)
|
381 |
+
else:
|
382 |
+
sub_cond_hint = self.cond_hint_original[hint_order][hint_idxs].to(dtype).to(x_noisy.device)
|
383 |
+
# scale cond_hints to match noisy input
|
384 |
+
if self.control_model.use_simplified_conditioning_embedding:
|
385 |
+
# RGB SparseCtrl; the inputs are latents - use bilinear to avoid blocky artifacts
|
386 |
+
sub_cond_hint = self.model_latent_format.process_in(sub_cond_hint) # multiplies by model scale factor
|
387 |
+
sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3], x_noisy.shape[2], "nearest-exact", "center").to(dtype).to(x_noisy.device)
|
388 |
+
else:
|
389 |
+
# other SparseCtrl; inputs are typical images
|
390 |
+
sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
|
391 |
+
# prepare cond_hint (b, c, h ,w)
|
392 |
+
cond_shape = list(sub_cond_hint.shape)
|
393 |
+
cond_shape[0] = len(range_idxs)
|
394 |
+
self.cond_hint = torch.zeros(cond_shape).to(dtype).to(x_noisy.device)
|
395 |
+
self.cond_hint[local_idxs] = sub_cond_hint[:]
|
396 |
+
# prepare cond_mask (b, 1, h, w)
|
397 |
+
cond_shape[1] = 1
|
398 |
+
cond_mask = torch.zeros(cond_shape).to(dtype).to(x_noisy.device)
|
399 |
+
cond_mask[local_idxs] = self.sparse_settings.sparse_mask_mult * self.weights.extras.get(SparseConst.MASK_MULT, 1.0)
|
400 |
+
# combine cond_hint and cond_mask into (b, c+1, h, w)
|
401 |
+
if not self.sparse_settings.merged:
|
402 |
+
self.cond_hint = torch.cat([self.cond_hint, cond_mask], dim=1)
|
403 |
+
del sub_cond_hint
|
404 |
+
del cond_mask
|
405 |
+
# make cond_hint match x_noisy batch
|
406 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
407 |
+
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
|
408 |
+
|
409 |
+
# prepare mask_cond_hint
|
410 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
|
411 |
+
|
412 |
+
context = cond['c_crossattn']
|
413 |
+
y = cond.get('y', None)
|
414 |
+
if y is not None:
|
415 |
+
y = y.to(dtype)
|
416 |
+
timestep = self.model_sampling_current.timestep(t)
|
417 |
+
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
418 |
+
|
419 |
+
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
420 |
+
return self.control_merge(control, control_prev, output_dtype)
|
421 |
+
|
422 |
+
def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, *args, **kwargs):
|
423 |
+
# apply mults to indexes with and without a direct condhint
|
424 |
+
x[self.local_sparse_idxs] *= self.sparse_settings.sparse_hint_mult * self.weights.extras.get(SparseConst.HINT_MULT, 1.0)
|
425 |
+
x[self.local_sparse_idxs_inverse] *= self.sparse_settings.sparse_nonhint_mult * self.weights.extras.get(SparseConst.NONHINT_MULT, 1.0)
|
426 |
+
return super().apply_advanced_strengths_and_masks(x, batched_number, *args, **kwargs)
|
427 |
+
|
428 |
+
def pre_run_advanced(self, model, percent_to_timestep_function):
|
429 |
+
super().pre_run_advanced(model, percent_to_timestep_function)
|
430 |
+
if isinstance(self.cond_hint_original, AbstractPreprocWrapper):
|
431 |
+
if not self.control_model.use_simplified_conditioning_embedding:
|
432 |
+
raise ValueError("Any model besides RGB SparseCtrl should NOT have its images go through the RGB SparseCtrl preprocessor.")
|
433 |
+
self.cond_hint_original = self.cond_hint_original.condhint
|
434 |
+
self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond hint
|
435 |
+
if self.control_model.motion_wrapper is not None:
|
436 |
+
self.control_model.motion_wrapper.reset()
|
437 |
+
self.control_model.motion_wrapper.set_strength(self.sparse_settings.motion_strength)
|
438 |
+
self.control_model.motion_wrapper.set_scale_multiplier(self.sparse_settings.motion_scale)
|
439 |
+
|
440 |
+
def cleanup_advanced(self):
|
441 |
+
super().cleanup_advanced()
|
442 |
+
if self.model_latent_format is not None:
|
443 |
+
del self.model_latent_format
|
444 |
+
self.model_latent_format = None
|
445 |
+
self.local_sparse_idxs = None
|
446 |
+
self.local_sparse_idxs_inverse = None
|
447 |
+
|
448 |
+
def copy(self):
|
449 |
+
c = SparseCtrlAdvanced(self.control_model, self.timestep_keyframes, self.sparse_settings, self.global_average_pooling, self.load_device, self.manual_cast_dtype)
|
450 |
+
self.copy_to(c)
|
451 |
+
self.copy_to_advanced(c)
|
452 |
+
return c
|
453 |
+
|
454 |
+
|
455 |
+
def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
|
456 |
+
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
457 |
+
# from pathlib import Path
|
458 |
+
# log_name = ckpt_path.split('\\')[-1]
|
459 |
+
# with open(Path(__file__).parent.parent.parent / rf"keys_{log_name}.txt", "w") as afile:
|
460 |
+
# for key, value in controlnet_data.items():
|
461 |
+
# afile.write(f"{key}:\t{value.shape}\n")
|
462 |
+
control = None
|
463 |
+
# check if a non-vanilla ControlNet
|
464 |
+
controlnet_type = ControlWeightType.DEFAULT
|
465 |
+
has_controlnet_key = False
|
466 |
+
has_motion_modules_key = False
|
467 |
+
has_temporal_res_block_key = False
|
468 |
+
for key in controlnet_data:
|
469 |
+
# LLLite check
|
470 |
+
if "lllite" in key:
|
471 |
+
controlnet_type = ControlWeightType.CONTROLLLLITE
|
472 |
+
break
|
473 |
+
# SparseCtrl check
|
474 |
+
elif "motion_modules" in key:
|
475 |
+
has_motion_modules_key = True
|
476 |
+
elif "controlnet" in key:
|
477 |
+
has_controlnet_key = True
|
478 |
+
# SVD-ControlNet check
|
479 |
+
elif "temporal_res_block" in key:
|
480 |
+
has_temporal_res_block_key = True
|
481 |
+
# ControlNet++ check
|
482 |
+
elif "task_embedding" in key:
|
483 |
+
pass
|
484 |
+
|
485 |
+
if has_controlnet_key and has_motion_modules_key:
|
486 |
+
controlnet_type = ControlWeightType.SPARSECTRL
|
487 |
+
elif has_controlnet_key and has_temporal_res_block_key:
|
488 |
+
controlnet_type = ControlWeightType.SVD_CONTROLNET
|
489 |
+
|
490 |
+
if controlnet_type != ControlWeightType.DEFAULT:
|
491 |
+
if controlnet_type == ControlWeightType.CONTROLLLLITE:
|
492 |
+
control = load_controllllite(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe)
|
493 |
+
elif controlnet_type == ControlWeightType.SPARSECTRL:
|
494 |
+
control = load_sparsectrl(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe, model=model)
|
495 |
+
elif controlnet_type == ControlWeightType.SVD_CONTROLNET:
|
496 |
+
control = load_svdcontrolnet(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe)
|
497 |
+
# otherwise, load vanilla ControlNet
|
498 |
+
else:
|
499 |
+
try:
|
500 |
+
# hacky way of getting load_torch_file in load_controlnet to use already-present controlnet_data and not redo loading
|
501 |
+
orig_load_torch_file = comfy.utils.load_torch_file
|
502 |
+
comfy.utils.load_torch_file = load_torch_file_with_dict_factory(controlnet_data, orig_load_torch_file)
|
503 |
+
control = comfy_cn.load_controlnet(ckpt_path, model=model)
|
504 |
+
finally:
|
505 |
+
comfy.utils.load_torch_file = orig_load_torch_file
|
506 |
+
return convert_to_advanced(control, timestep_keyframe=timestep_keyframe)
|
507 |
+
|
508 |
+
|
509 |
+
def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
|
510 |
+
# if already advanced, leave it be
|
511 |
+
if is_advanced_controlnet(control):
|
512 |
+
return control
|
513 |
+
# if exactly ControlNet returned, transform it into ControlNetAdvanced
|
514 |
+
if type(control) == ControlNet:
|
515 |
+
control = ControlNetAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
|
516 |
+
if is_sd3_advanced_controlnet(control):
|
517 |
+
control.require_vae = True
|
518 |
+
return control
|
519 |
+
# if exactly ControlLora returned, transform it into ControlLoraAdvanced
|
520 |
+
elif type(control) == ControlLora:
|
521 |
+
return ControlLoraAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
|
522 |
+
# if T2IAdapter returned, transform it into T2IAdapterAdvanced
|
523 |
+
elif isinstance(control, T2IAdapter):
|
524 |
+
return T2IAdapterAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
|
525 |
+
# otherwise, leave it be - might be something I am not supporting yet
|
526 |
+
return control
|
527 |
+
|
528 |
+
|
529 |
+
def convert_all_to_advanced(conds: list[list[dict[str]]]) -> tuple[bool, list]:
|
530 |
+
cache = {}
|
531 |
+
modified = False
|
532 |
+
new_conds = []
|
533 |
+
for cond in conds:
|
534 |
+
converted_cond = None
|
535 |
+
if cond is not None:
|
536 |
+
need_to_convert = False
|
537 |
+
# first, check if there is even a need to convert
|
538 |
+
for sub_cond in cond:
|
539 |
+
actual_cond = sub_cond[1]
|
540 |
+
if "control" in actual_cond:
|
541 |
+
if not are_all_advanced_controlnet(actual_cond["control"]):
|
542 |
+
need_to_convert = True
|
543 |
+
break
|
544 |
+
if not need_to_convert:
|
545 |
+
converted_cond = cond
|
546 |
+
else:
|
547 |
+
converted_cond = []
|
548 |
+
for sub_cond in cond:
|
549 |
+
new_sub_cond: list = []
|
550 |
+
for actual_cond in sub_cond:
|
551 |
+
if not type(actual_cond) == dict:
|
552 |
+
new_sub_cond.append(actual_cond)
|
553 |
+
continue
|
554 |
+
if "control" not in actual_cond:
|
555 |
+
new_sub_cond.append(actual_cond)
|
556 |
+
elif are_all_advanced_controlnet(actual_cond["control"]):
|
557 |
+
new_sub_cond.append(actual_cond)
|
558 |
+
else:
|
559 |
+
actual_cond = actual_cond.copy()
|
560 |
+
actual_cond["control"] = _convert_all_control_to_advanced(actual_cond["control"], cache)
|
561 |
+
new_sub_cond.append(actual_cond)
|
562 |
+
modified = True
|
563 |
+
converted_cond.append(new_sub_cond)
|
564 |
+
new_conds.append(converted_cond)
|
565 |
+
return modified, new_conds
|
566 |
+
|
567 |
+
|
568 |
+
def _convert_all_control_to_advanced(input_object: ControlBase, cache: dict):
|
569 |
+
output_object = input_object
|
570 |
+
# iteratively convert to advanced, if needed
|
571 |
+
next_cn = None
|
572 |
+
curr_cn = input_object
|
573 |
+
iter = 0
|
574 |
+
while curr_cn is not None:
|
575 |
+
if not is_advanced_controlnet(curr_cn):
|
576 |
+
# if already in cache, then conversion was done before, so just link it and exit
|
577 |
+
if curr_cn in cache:
|
578 |
+
new_cn = cache[curr_cn]
|
579 |
+
if next_cn is not None:
|
580 |
+
setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
|
581 |
+
next_cn.previous_controlnet = new_cn
|
582 |
+
if iter == 0: # if was top-level controlnet, that's the new output
|
583 |
+
output_object = new_cn
|
584 |
+
break
|
585 |
+
try:
|
586 |
+
# convert to advanced, and assign previous_controlnet (convert doesn't transfer it)
|
587 |
+
new_cn = convert_to_advanced(curr_cn)
|
588 |
+
except Exception as e:
|
589 |
+
raise Exception("Failed to automatically convert a ControlNet to Advanced to support sliding window context.", e)
|
590 |
+
new_cn.previous_controlnet = curr_cn.previous_controlnet
|
591 |
+
if iter == 0: # if was top-level controlnet, that's the new output
|
592 |
+
output_object = new_cn
|
593 |
+
# if next_cn is present, then it needs to be pointed to new_cn
|
594 |
+
if next_cn is not None:
|
595 |
+
setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
|
596 |
+
next_cn.previous_controlnet = new_cn
|
597 |
+
# add to cache
|
598 |
+
cache[curr_cn] = new_cn
|
599 |
+
curr_cn = new_cn
|
600 |
+
next_cn = curr_cn
|
601 |
+
curr_cn = curr_cn.previous_controlnet
|
602 |
+
iter += 1
|
603 |
+
return output_object
|
604 |
+
|
605 |
+
|
606 |
+
def restore_all_controlnet_conns(conds: list[list[dict[str]]]):
|
607 |
+
# if a cn has an _orig_previous_controlnet property, restore it and delete
|
608 |
+
for main_cond in conds:
|
609 |
+
if main_cond is not None:
|
610 |
+
for cond in main_cond:
|
611 |
+
if "control" in cond[1]:
|
612 |
+
# if ACN is the one to have initialized it, delete it
|
613 |
+
# TODO: maybe check if someone else did a similar hack, and carefully pluck out our stuff?
|
614 |
+
if CONTROL_INIT_BY_ACN in cond[1]:
|
615 |
+
cond[1].pop("control")
|
616 |
+
cond[1].pop(CONTROL_INIT_BY_ACN)
|
617 |
+
else:
|
618 |
+
_restore_all_controlnet_conns(cond[1]["control"])
|
619 |
+
|
620 |
+
|
621 |
+
def _restore_all_controlnet_conns(input_object: ControlBase):
|
622 |
+
# restore original previous_controlnet if needed
|
623 |
+
curr_cn = input_object
|
624 |
+
while curr_cn is not None:
|
625 |
+
if hasattr(curr_cn, ORIG_PREVIOUS_CONTROLNET):
|
626 |
+
curr_cn.previous_controlnet = getattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
|
627 |
+
delattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
|
628 |
+
curr_cn = curr_cn.previous_controlnet
|
629 |
+
|
630 |
+
|
631 |
+
def are_all_advanced_controlnet(input_object: ControlBase):
|
632 |
+
# iteratively check if linked controlnets objects are all advanced
|
633 |
+
curr_cn = input_object
|
634 |
+
while curr_cn is not None:
|
635 |
+
if not is_advanced_controlnet(curr_cn):
|
636 |
+
return False
|
637 |
+
curr_cn = curr_cn.previous_controlnet
|
638 |
+
return True
|
639 |
+
|
640 |
+
|
641 |
+
def is_advanced_controlnet(input_object):
|
642 |
+
return hasattr(input_object, "sub_idxs")
|
643 |
+
|
644 |
+
|
645 |
+
def is_sd3_advanced_controlnet(input_object: ControlNetAdvanced):
|
646 |
+
return type(input_object) == ControlNetAdvanced and input_object.latent_format is not None
|
647 |
+
|
648 |
+
|
649 |
+
def load_sparsectrl(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, sparse_settings=SparseSettings.default(), model=None) -> SparseCtrlAdvanced:
|
650 |
+
if controlnet_data is None:
|
651 |
+
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
652 |
+
# first, separate out motion part from normal controlnet part and attempt to load that portion
|
653 |
+
motion_data = {}
|
654 |
+
for key in list(controlnet_data.keys()):
|
655 |
+
if "temporal" in key:
|
656 |
+
motion_data[key] = controlnet_data.pop(key)
|
657 |
+
if len(motion_data) == 0:
|
658 |
+
raise ValueError(f"No motion-related keys in '{ckpt_path}'; not a valid SparseCtrl model!")
|
659 |
+
|
660 |
+
# now, load as if it was a normal controlnet - mostly copied from comfy load_controlnet function
|
661 |
+
controlnet_config: dict[str] = None
|
662 |
+
is_diffusers = False
|
663 |
+
use_simplified_conditioning_embedding = False
|
664 |
+
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data:
|
665 |
+
is_diffusers = True
|
666 |
+
if "controlnet_cond_embedding.weight" in controlnet_data:
|
667 |
+
is_diffusers = True
|
668 |
+
use_simplified_conditioning_embedding = True
|
669 |
+
if is_diffusers: #diffusers format
|
670 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
671 |
+
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
672 |
+
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
673 |
+
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
674 |
+
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
675 |
+
|
676 |
+
count = 0
|
677 |
+
loop = True
|
678 |
+
while loop:
|
679 |
+
suffix = [".weight", ".bias"]
|
680 |
+
for s in suffix:
|
681 |
+
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
682 |
+
k_out = "zero_convs.{}.0{}".format(count, s)
|
683 |
+
if k_in not in controlnet_data:
|
684 |
+
loop = False
|
685 |
+
break
|
686 |
+
diffusers_keys[k_in] = k_out
|
687 |
+
count += 1
|
688 |
+
# normal conditioning embedding
|
689 |
+
if not use_simplified_conditioning_embedding:
|
690 |
+
count = 0
|
691 |
+
loop = True
|
692 |
+
while loop:
|
693 |
+
suffix = [".weight", ".bias"]
|
694 |
+
for s in suffix:
|
695 |
+
if count == 0:
|
696 |
+
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
697 |
+
else:
|
698 |
+
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
699 |
+
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
700 |
+
if k_in not in controlnet_data:
|
701 |
+
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
702 |
+
loop = False
|
703 |
+
diffusers_keys[k_in] = k_out
|
704 |
+
count += 1
|
705 |
+
# simplified conditioning embedding
|
706 |
+
else:
|
707 |
+
count = 0
|
708 |
+
suffix = [".weight", ".bias"]
|
709 |
+
for s in suffix:
|
710 |
+
k_in = "controlnet_cond_embedding{}".format(s)
|
711 |
+
k_out = "input_hint_block.{}{}".format(count, s)
|
712 |
+
diffusers_keys[k_in] = k_out
|
713 |
+
|
714 |
+
new_sd = {}
|
715 |
+
for k in diffusers_keys:
|
716 |
+
if k in controlnet_data:
|
717 |
+
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
718 |
+
|
719 |
+
leftover_keys = controlnet_data.keys()
|
720 |
+
if len(leftover_keys) > 0:
|
721 |
+
logger.info("leftover keys:", leftover_keys)
|
722 |
+
controlnet_data = new_sd
|
723 |
+
|
724 |
+
pth_key = 'control_model.zero_convs.0.0.weight'
|
725 |
+
pth = False
|
726 |
+
key = 'zero_convs.0.0.weight'
|
727 |
+
if pth_key in controlnet_data:
|
728 |
+
pth = True
|
729 |
+
key = pth_key
|
730 |
+
prefix = "control_model."
|
731 |
+
elif key in controlnet_data:
|
732 |
+
prefix = ""
|
733 |
+
else:
|
734 |
+
raise ValueError("The provided model is not a valid SparseCtrl model! [ErrorCode: HORSERADISH]")
|
735 |
+
|
736 |
+
if controlnet_config is None:
|
737 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
738 |
+
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
739 |
+
load_device = comfy.model_management.get_torch_device()
|
740 |
+
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
741 |
+
if manual_cast_dtype is not None:
|
742 |
+
controlnet_config["operations"] = manual_cast_clean_groupnorm
|
743 |
+
else:
|
744 |
+
controlnet_config["operations"] = disable_weight_init_clean_groupnorm
|
745 |
+
controlnet_config.pop("out_channels")
|
746 |
+
# get proper hint channels
|
747 |
+
if use_simplified_conditioning_embedding:
|
748 |
+
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
749 |
+
controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding
|
750 |
+
else:
|
751 |
+
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
752 |
+
controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding
|
753 |
+
control_model = SparseControlNet(**controlnet_config)
|
754 |
+
|
755 |
+
if pth:
|
756 |
+
if 'difference' in controlnet_data:
|
757 |
+
if model is not None:
|
758 |
+
comfy.model_management.load_models_gpu([model])
|
759 |
+
model_sd = model.model_state_dict()
|
760 |
+
for x in controlnet_data:
|
761 |
+
c_m = "control_model."
|
762 |
+
if x.startswith(c_m):
|
763 |
+
sd_key = "diffusion_model.{}".format(x[len(c_m):])
|
764 |
+
if sd_key in model_sd:
|
765 |
+
cd = controlnet_data[x]
|
766 |
+
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
767 |
+
else:
|
768 |
+
logger.warning("WARNING: Loaded a diff SparseCtrl without a model. It will very likely not work.")
|
769 |
+
|
770 |
+
class WeightsLoader(torch.nn.Module):
|
771 |
+
pass
|
772 |
+
w = WeightsLoader()
|
773 |
+
w.control_model = control_model
|
774 |
+
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
775 |
+
else:
|
776 |
+
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
777 |
+
if len(missing) > 0 or len(unexpected) > 0:
|
778 |
+
logger.info(f"SparseCtrl ControlNet: {missing}, {unexpected}")
|
779 |
+
|
780 |
+
global_average_pooling = False
|
781 |
+
filename = os.path.splitext(ckpt_path)[0]
|
782 |
+
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
783 |
+
global_average_pooling = True
|
784 |
+
|
785 |
+
# actually load motion portion of model now
|
786 |
+
motion_wrapper: SparseCtrlMotionWrapper = SparseCtrlMotionWrapper(motion_data, ops=controlnet_config.get("operations", None)).to(comfy.model_management.unet_dtype())
|
787 |
+
missing, unexpected = motion_wrapper.load_state_dict(motion_data)
|
788 |
+
if len(missing) > 0 or len(unexpected) > 0:
|
789 |
+
logger.info(f"SparseCtrlMotionWrapper: {missing}, {unexpected}")
|
790 |
+
|
791 |
+
# both motion portion and controlnet portions are loaded; bring them together if using motion model
|
792 |
+
if sparse_settings.use_motion:
|
793 |
+
motion_wrapper.inject(control_model)
|
794 |
+
|
795 |
+
control = SparseCtrlAdvanced(control_model, timestep_keyframes=timestep_keyframe, sparse_settings=sparse_settings, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
796 |
+
return control
|
797 |
+
|
798 |
+
|
799 |
+
def load_svdcontrolnet(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
|
800 |
+
if controlnet_data is None:
|
801 |
+
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
802 |
+
|
803 |
+
controlnet_config = None
|
804 |
+
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
805 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
806 |
+
controlnet_config = svd_unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
807 |
+
diffusers_keys = svd_unet_to_diffusers(controlnet_config)
|
808 |
+
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
809 |
+
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
810 |
+
|
811 |
+
count = 0
|
812 |
+
loop = True
|
813 |
+
while loop:
|
814 |
+
suffix = [".weight", ".bias"]
|
815 |
+
for s in suffix:
|
816 |
+
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
817 |
+
k_out = "zero_convs.{}.0{}".format(count, s)
|
818 |
+
if k_in not in controlnet_data:
|
819 |
+
loop = False
|
820 |
+
break
|
821 |
+
diffusers_keys[k_in] = k_out
|
822 |
+
count += 1
|
823 |
+
|
824 |
+
count = 0
|
825 |
+
loop = True
|
826 |
+
while loop:
|
827 |
+
suffix = [".weight", ".bias"]
|
828 |
+
for s in suffix:
|
829 |
+
if count == 0:
|
830 |
+
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
831 |
+
else:
|
832 |
+
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
833 |
+
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
834 |
+
if k_in not in controlnet_data:
|
835 |
+
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
836 |
+
loop = False
|
837 |
+
diffusers_keys[k_in] = k_out
|
838 |
+
count += 1
|
839 |
+
|
840 |
+
new_sd = {}
|
841 |
+
for k in diffusers_keys:
|
842 |
+
if k in controlnet_data:
|
843 |
+
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
844 |
+
|
845 |
+
leftover_keys = controlnet_data.keys()
|
846 |
+
if len(leftover_keys) > 0:
|
847 |
+
spatial_leftover_keys = []
|
848 |
+
temporal_leftover_keys = []
|
849 |
+
other_leftover_keys = []
|
850 |
+
for key in leftover_keys:
|
851 |
+
if "spatial" in key:
|
852 |
+
spatial_leftover_keys.append(key)
|
853 |
+
elif "temporal" in key:
|
854 |
+
temporal_leftover_keys.append(key)
|
855 |
+
else:
|
856 |
+
other_leftover_keys.append(key)
|
857 |
+
logger.warn(f"spatial_leftover_keys ({len(spatial_leftover_keys)}): {spatial_leftover_keys}")
|
858 |
+
logger.warn(f"temporal_leftover_keys ({len(temporal_leftover_keys)}): {temporal_leftover_keys}")
|
859 |
+
logger.warn(f"other_leftover_keys ({len(other_leftover_keys)}): {other_leftover_keys}")
|
860 |
+
#print("leftover keys:", leftover_keys)
|
861 |
+
controlnet_data = new_sd
|
862 |
+
|
863 |
+
pth_key = 'control_model.zero_convs.0.0.weight'
|
864 |
+
pth = False
|
865 |
+
key = 'zero_convs.0.0.weight'
|
866 |
+
if pth_key in controlnet_data:
|
867 |
+
pth = True
|
868 |
+
key = pth_key
|
869 |
+
prefix = "control_model."
|
870 |
+
elif key in controlnet_data:
|
871 |
+
prefix = ""
|
872 |
+
else:
|
873 |
+
raise ValueError("The provided model is not a valid SVD-ControlNet model! [ErrorCode: MUSTARD]")
|
874 |
+
|
875 |
+
if controlnet_config is None:
|
876 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
877 |
+
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
878 |
+
load_device = comfy.model_management.get_torch_device()
|
879 |
+
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
880 |
+
if manual_cast_dtype is not None:
|
881 |
+
controlnet_config["operations"] = comfy.ops.manual_cast
|
882 |
+
controlnet_config.pop("out_channels")
|
883 |
+
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
884 |
+
control_model = SVDControlNet(**controlnet_config)
|
885 |
+
|
886 |
+
if pth:
|
887 |
+
if 'difference' in controlnet_data:
|
888 |
+
if model is not None:
|
889 |
+
comfy.model_management.load_models_gpu([model])
|
890 |
+
model_sd = model.model_state_dict()
|
891 |
+
for x in controlnet_data:
|
892 |
+
c_m = "control_model."
|
893 |
+
if x.startswith(c_m):
|
894 |
+
sd_key = "diffusion_model.{}".format(x[len(c_m):])
|
895 |
+
if sd_key in model_sd:
|
896 |
+
cd = controlnet_data[x]
|
897 |
+
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
898 |
+
else:
|
899 |
+
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
900 |
+
|
901 |
+
class WeightsLoader(torch.nn.Module):
|
902 |
+
pass
|
903 |
+
w = WeightsLoader()
|
904 |
+
w.control_model = control_model
|
905 |
+
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
906 |
+
else:
|
907 |
+
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
908 |
+
if len(missing) > 0 or len(unexpected) > 0:
|
909 |
+
logger.info(f"SVD-ControlNet: {missing}, {unexpected}")
|
910 |
+
|
911 |
+
global_average_pooling = False
|
912 |
+
filename = os.path.splitext(ckpt_path)[0]
|
913 |
+
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
914 |
+
global_average_pooling = True
|
915 |
+
|
916 |
+
control = SVDControlNetAdvanced(control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
917 |
+
return control
|
918 |
+
|
ComfyUI-Advanced-ControlNet/adv_control/control_lllite.py
ADDED
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
|
2 |
+
# basically, all the LLLite core code is from there, which I then combined with
|
3 |
+
# Advanced-ControlNet features and QoL
|
4 |
+
import math
|
5 |
+
from typing import Union
|
6 |
+
from torch import Tensor
|
7 |
+
import torch
|
8 |
+
import os
|
9 |
+
|
10 |
+
import comfy.utils
|
11 |
+
import comfy.ops
|
12 |
+
import comfy.model_management
|
13 |
+
from comfy.model_patcher import ModelPatcher
|
14 |
+
from comfy.controlnet import ControlBase
|
15 |
+
|
16 |
+
from .logger import logger
|
17 |
+
from .utils import (AdvancedControlBase, TimestepKeyframeGroup, ControlWeights, broadcast_image_to_extend, extend_to_batch_size,
|
18 |
+
deepcopy_with_sharing, prepare_mask_batch)
|
19 |
+
|
20 |
+
|
21 |
+
# based on set_model_patch code in comfy/model_patcher.py
|
22 |
+
def set_model_patch(model_options, patch, name):
|
23 |
+
to = model_options["transformer_options"]
|
24 |
+
# check if patch was already added
|
25 |
+
if "patches" in to:
|
26 |
+
current_patches = to["patches"].get(name, [])
|
27 |
+
if patch in current_patches:
|
28 |
+
return
|
29 |
+
if "patches" not in to:
|
30 |
+
to["patches"] = {}
|
31 |
+
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
32 |
+
|
33 |
+
def set_model_attn1_patch(model_options, patch):
|
34 |
+
set_model_patch(model_options, patch, "attn1_patch")
|
35 |
+
|
36 |
+
def set_model_attn2_patch(model_options, patch):
|
37 |
+
set_model_patch(model_options, patch, "attn2_patch")
|
38 |
+
|
39 |
+
|
40 |
+
def extra_options_to_module_prefix(extra_options):
|
41 |
+
# extra_options = {'transformer_index': 2, 'block_index': 8, 'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64}
|
42 |
+
|
43 |
+
# block is: [('input', 4), ('input', 5), ('input', 7), ('input', 8), ('middle', 0),
|
44 |
+
# ('output', 0), ('output', 1), ('output', 2), ('output', 3), ('output', 4), ('output', 5)]
|
45 |
+
# transformer_index is: [0, 1, 2, 3, 4, 5, 6, 7, 8], for each block
|
46 |
+
# block_index is: 0-1 or 0-9, depends on the block
|
47 |
+
# input 7 and 8, middle has 10 blocks
|
48 |
+
|
49 |
+
# make module name from extra_options
|
50 |
+
block = extra_options["block"]
|
51 |
+
block_index = extra_options["block_index"]
|
52 |
+
if block[0] == "input":
|
53 |
+
module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}"
|
54 |
+
elif block[0] == "middle":
|
55 |
+
module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
|
56 |
+
elif block[0] == "output":
|
57 |
+
module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}"
|
58 |
+
else:
|
59 |
+
raise Exception(f"ControlLLLite: invalid block name '{block[0]}'. Expected 'input', 'middle', or 'output'.")
|
60 |
+
return module_pfx
|
61 |
+
|
62 |
+
|
63 |
+
class LLLitePatch:
|
64 |
+
ATTN1 = "attn1"
|
65 |
+
ATTN2 = "attn2"
|
66 |
+
def __init__(self, modules: dict[str, 'LLLiteModule'], patch_type: str, control: Union[AdvancedControlBase, ControlBase]=None):
|
67 |
+
self.modules = modules
|
68 |
+
self.control = control
|
69 |
+
self.patch_type = patch_type
|
70 |
+
#logger.error(f"create LLLitePatch: {id(self)},{control}")
|
71 |
+
|
72 |
+
def __call__(self, q, k, v, extra_options):
|
73 |
+
#logger.error(f"in __call__: {id(self)}")
|
74 |
+
# determine if have anything to run
|
75 |
+
if self.control.timestep_range is not None:
|
76 |
+
# it turns out comparing single-value tensors to floats is extremely slow
|
77 |
+
# a: Tensor = extra_options["sigmas"][0]
|
78 |
+
if self.control.t > self.control.timestep_range[0] or self.control.t < self.control.timestep_range[1]:
|
79 |
+
return q, k, v
|
80 |
+
|
81 |
+
module_pfx = extra_options_to_module_prefix(extra_options)
|
82 |
+
|
83 |
+
is_attn1 = q.shape[-1] == k.shape[-1] # self attention
|
84 |
+
if is_attn1:
|
85 |
+
module_pfx = module_pfx + "_attn1"
|
86 |
+
else:
|
87 |
+
module_pfx = module_pfx + "_attn2"
|
88 |
+
|
89 |
+
module_pfx_to_q = module_pfx + "_to_q"
|
90 |
+
module_pfx_to_k = module_pfx + "_to_k"
|
91 |
+
module_pfx_to_v = module_pfx + "_to_v"
|
92 |
+
|
93 |
+
if module_pfx_to_q in self.modules:
|
94 |
+
q = q + self.modules[module_pfx_to_q](q, self.control)
|
95 |
+
if module_pfx_to_k in self.modules:
|
96 |
+
k = k + self.modules[module_pfx_to_k](k, self.control)
|
97 |
+
if module_pfx_to_v in self.modules:
|
98 |
+
v = v + self.modules[module_pfx_to_v](v, self.control)
|
99 |
+
|
100 |
+
return q, k, v
|
101 |
+
|
102 |
+
def to(self, device):
|
103 |
+
#logger.info(f"to... has control? {self.control}")
|
104 |
+
for d in self.modules.keys():
|
105 |
+
self.modules[d] = self.modules[d].to(device)
|
106 |
+
return self
|
107 |
+
|
108 |
+
def set_control(self, control: Union[AdvancedControlBase, ControlBase]) -> 'LLLitePatch':
|
109 |
+
self.control = control
|
110 |
+
return self
|
111 |
+
#logger.error(f"set control for LLLitePatch: {id(self)}, cn: {id(control)}")
|
112 |
+
|
113 |
+
def clone_with_control(self, control: AdvancedControlBase):
|
114 |
+
#logger.error(f"clone-set control for LLLitePatch: {id(self)},{id(control)}")
|
115 |
+
return LLLitePatch(self.modules, self.patch_type, control)
|
116 |
+
|
117 |
+
def cleanup(self):
|
118 |
+
#total_cleaned = 0
|
119 |
+
for module in self.modules.values():
|
120 |
+
module.cleanup()
|
121 |
+
# total_cleaned += 1
|
122 |
+
#logger.info(f"cleaned modules: {total_cleaned}, {id(self)}")
|
123 |
+
#logger.error(f"cleanup LLLitePatch: {id(self)}")
|
124 |
+
|
125 |
+
# make sure deepcopy does not copy control, and deepcopied LLLitePatch should be assigned to control
|
126 |
+
# def __deepcopy__(self, memo):
|
127 |
+
# self.cleanup()
|
128 |
+
# to_return: LLLitePatch = deepcopy_with_sharing(self, shared_attribute_names = ['control'], memo=memo)
|
129 |
+
# #logger.warn(f"patch {id(self)} turned into {id(to_return)}")
|
130 |
+
# try:
|
131 |
+
# if self.patch_type == self.ATTN1:
|
132 |
+
# to_return.control.patch_attn1 = to_return
|
133 |
+
# elif self.patch_type == self.ATTN2:
|
134 |
+
# to_return.control.patch_attn2 = to_return
|
135 |
+
# except Exception:
|
136 |
+
# pass
|
137 |
+
# return to_return
|
138 |
+
|
139 |
+
|
140 |
+
# TODO: use comfy.ops to support fp8 properly
|
141 |
+
class LLLiteModule(torch.nn.Module):
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
name: str,
|
145 |
+
is_conv2d: bool,
|
146 |
+
in_dim: int,
|
147 |
+
depth: int,
|
148 |
+
cond_emb_dim: int,
|
149 |
+
mlp_dim: int,
|
150 |
+
):
|
151 |
+
super().__init__()
|
152 |
+
self.name = name
|
153 |
+
self.is_conv2d = is_conv2d
|
154 |
+
self.is_first = False
|
155 |
+
|
156 |
+
modules = []
|
157 |
+
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
|
158 |
+
if depth == 1:
|
159 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
160 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
161 |
+
elif depth == 2:
|
162 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
163 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
164 |
+
elif depth == 3:
|
165 |
+
# kernel size 8 is too large, so set it to 4
|
166 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
167 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
168 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
169 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
170 |
+
|
171 |
+
self.conditioning1 = torch.nn.Sequential(*modules)
|
172 |
+
|
173 |
+
if self.is_conv2d:
|
174 |
+
self.down = torch.nn.Sequential(
|
175 |
+
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
176 |
+
torch.nn.ReLU(inplace=True),
|
177 |
+
)
|
178 |
+
self.mid = torch.nn.Sequential(
|
179 |
+
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
180 |
+
torch.nn.ReLU(inplace=True),
|
181 |
+
)
|
182 |
+
self.up = torch.nn.Sequential(
|
183 |
+
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
|
184 |
+
)
|
185 |
+
else:
|
186 |
+
self.down = torch.nn.Sequential(
|
187 |
+
torch.nn.Linear(in_dim, mlp_dim),
|
188 |
+
torch.nn.ReLU(inplace=True),
|
189 |
+
)
|
190 |
+
self.mid = torch.nn.Sequential(
|
191 |
+
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
|
192 |
+
torch.nn.ReLU(inplace=True),
|
193 |
+
)
|
194 |
+
self.up = torch.nn.Sequential(
|
195 |
+
torch.nn.Linear(mlp_dim, in_dim),
|
196 |
+
)
|
197 |
+
|
198 |
+
self.depth = depth
|
199 |
+
self.cond_emb = None
|
200 |
+
self.cx_shape = None
|
201 |
+
self.prev_batch = 0
|
202 |
+
self.prev_sub_idxs = None
|
203 |
+
|
204 |
+
def cleanup(self):
|
205 |
+
del self.cond_emb
|
206 |
+
self.cond_emb = None
|
207 |
+
self.cx_shape = None
|
208 |
+
self.prev_batch = 0
|
209 |
+
self.prev_sub_idxs = None
|
210 |
+
|
211 |
+
def forward(self, x: Tensor, control: Union[AdvancedControlBase, ControlBase]):
|
212 |
+
mask = None
|
213 |
+
mask_tk = None
|
214 |
+
#logger.info(x.shape)
|
215 |
+
if self.cond_emb is None or control.sub_idxs != self.prev_sub_idxs or x.shape[0] != self.prev_batch:
|
216 |
+
# print(f"cond_emb is None, {self.name}")
|
217 |
+
cond_hint = control.cond_hint.to(x.device, dtype=x.dtype)
|
218 |
+
if control.latent_dims_div2 is not None and x.shape[-1] != 1280:
|
219 |
+
cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div2[0] * 8, control.latent_dims_div2[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
|
220 |
+
elif control.latent_dims_div4 is not None and x.shape[-1] == 1280:
|
221 |
+
cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div4[0] * 8, control.latent_dims_div4[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
|
222 |
+
cx = self.conditioning1(cond_hint)
|
223 |
+
self.cx_shape = cx.shape
|
224 |
+
if not self.is_conv2d:
|
225 |
+
# reshape / b,c,h,w -> b,h*w,c
|
226 |
+
n, c, h, w = cx.shape
|
227 |
+
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
228 |
+
self.cond_emb = cx
|
229 |
+
# save prev values
|
230 |
+
self.prev_batch = x.shape[0]
|
231 |
+
self.prev_sub_idxs = control.sub_idxs
|
232 |
+
|
233 |
+
cx: torch.Tensor = self.cond_emb
|
234 |
+
# print(f"forward {self.name}, {cx.shape}, {x.shape}")
|
235 |
+
|
236 |
+
# TODO: make masks work for conv2d (could not find any ControlLLLites at this time that use them)
|
237 |
+
# create masks
|
238 |
+
if not self.is_conv2d:
|
239 |
+
n, c, h, w = self.cx_shape
|
240 |
+
if control.mask_cond_hint is not None:
|
241 |
+
mask = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
|
242 |
+
mask = mask.view(mask.shape[0], 1, h * w).permute(0, 2, 1)
|
243 |
+
if control.tk_mask_cond_hint is not None:
|
244 |
+
mask_tk = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
|
245 |
+
mask_tk = mask_tk.view(mask_tk.shape[0], 1, h * w).permute(0, 2, 1)
|
246 |
+
|
247 |
+
# x in uncond/cond doubles batch size
|
248 |
+
if x.shape[0] != cx.shape[0]:
|
249 |
+
if self.is_conv2d:
|
250 |
+
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
|
251 |
+
else:
|
252 |
+
# print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
|
253 |
+
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
|
254 |
+
if mask is not None:
|
255 |
+
mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1)
|
256 |
+
if mask_tk is not None:
|
257 |
+
mask_tk = mask_tk.repeat(x.shape[0] // mask_tk.shape[0], 1, 1)
|
258 |
+
|
259 |
+
if mask is None:
|
260 |
+
mask = 1.0
|
261 |
+
elif mask_tk is not None:
|
262 |
+
mask = mask * mask_tk
|
263 |
+
|
264 |
+
#logger.info(f"cs: {cx.shape}, x: {x.shape}, is_conv2d: {self.is_conv2d}")
|
265 |
+
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
|
266 |
+
cx = self.mid(cx)
|
267 |
+
cx = self.up(cx)
|
268 |
+
if control.latent_keyframes is not None:
|
269 |
+
cx = cx * control.calc_latent_keyframe_mults(x=cx, batched_number=control.batched_number)
|
270 |
+
if control.weights is not None and control.weights.has_uncond_multiplier:
|
271 |
+
cond_or_uncond = control.batched_number.cond_or_uncond
|
272 |
+
actual_length = cx.size(0) // control.batched_number
|
273 |
+
for idx, cond_type in enumerate(cond_or_uncond):
|
274 |
+
# if uncond, set to weight's uncond_multiplier
|
275 |
+
if cond_type == 1:
|
276 |
+
cx[actual_length*idx:actual_length*(idx+1)] *= control.weights.uncond_multiplier
|
277 |
+
return cx * mask * control.strength * control._current_timestep_keyframe.strength
|
278 |
+
|
279 |
+
|
280 |
+
class ControlLLLiteModules(torch.nn.Module):
|
281 |
+
def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch):
|
282 |
+
super().__init__()
|
283 |
+
self.patch_attn1_modules = torch.nn.Sequential(*list(patch_attn1.modules.values()))
|
284 |
+
self.patch_attn2_modules = torch.nn.Sequential(*list(patch_attn2.modules.values()))
|
285 |
+
|
286 |
+
|
287 |
+
class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
|
288 |
+
# This ControlNet is more of an attention patch than a traditional controlnet
|
289 |
+
def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device, ops: comfy.ops.disable_weight_init):
|
290 |
+
super().__init__()
|
291 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
|
292 |
+
self.device = device
|
293 |
+
self.ops = ops
|
294 |
+
self.patch_attn1 = patch_attn1.clone_with_control(self)
|
295 |
+
self.patch_attn2 = patch_attn2.clone_with_control(self)
|
296 |
+
self.control_model = ControlLLLiteModules(self.patch_attn1, self.patch_attn2)
|
297 |
+
self.control_model_wrapped = ModelPatcher(self.control_model, load_device=device, offload_device=comfy.model_management.unet_offload_device())
|
298 |
+
self.latent_dims_div2 = None
|
299 |
+
self.latent_dims_div4 = None
|
300 |
+
|
301 |
+
def live_model_patches(self, model_options):
|
302 |
+
set_model_attn1_patch(model_options, self.patch_attn1.set_control(self))
|
303 |
+
set_model_attn2_patch(model_options, self.patch_attn2.set_control(self))
|
304 |
+
|
305 |
+
# def patch_model(self, model: ModelPatcher):
|
306 |
+
# model.set_model_attn1_patch(self.patch_attn1)
|
307 |
+
# model.set_model_attn2_patch(self.patch_attn2)
|
308 |
+
|
309 |
+
def set_cond_hint_inject(self, *args, **kwargs):
|
310 |
+
to_return = super().set_cond_hint_inject(*args, **kwargs)
|
311 |
+
# cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
|
312 |
+
self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
|
313 |
+
return to_return
|
314 |
+
|
315 |
+
def pre_run_advanced(self, *args, **kwargs):
|
316 |
+
AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
|
317 |
+
#logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
|
318 |
+
self.patch_attn1.set_control(self)
|
319 |
+
self.patch_attn2.set_control(self)
|
320 |
+
#logger.warn(f"in pre_run_advanced: {id(self)}")
|
321 |
+
|
322 |
+
def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
|
323 |
+
# normal ControlNet stuff
|
324 |
+
control_prev = None
|
325 |
+
if self.previous_controlnet is not None:
|
326 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
327 |
+
|
328 |
+
if self.timestep_range is not None:
|
329 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
330 |
+
return control_prev
|
331 |
+
|
332 |
+
dtype = x_noisy.dtype
|
333 |
+
# prepare cond_hint
|
334 |
+
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
335 |
+
if self.cond_hint is not None:
|
336 |
+
del self.cond_hint
|
337 |
+
self.cond_hint = None
|
338 |
+
# if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
|
339 |
+
if self.sub_idxs is not None:
|
340 |
+
actual_cond_hint_orig = self.cond_hint_original
|
341 |
+
if self.cond_hint_original.size(0) < self.full_latent_length:
|
342 |
+
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
|
343 |
+
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
|
344 |
+
else:
|
345 |
+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
|
346 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
347 |
+
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
|
348 |
+
# some special logic here compared to other controlnets:
|
349 |
+
# * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
|
350 |
+
# * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisble by 2 or 4
|
351 |
+
divisible_by_2_h = x_noisy.shape[2]%2==0
|
352 |
+
divisible_by_2_w = x_noisy.shape[3]%2==0
|
353 |
+
if not (divisible_by_2_h and divisible_by_2_w):
|
354 |
+
#logger.warn(f"{x_noisy.shape} not divisible by 2!")
|
355 |
+
new_h = (x_noisy.shape[2]//2)*2
|
356 |
+
new_w = (x_noisy.shape[3]//2)*2
|
357 |
+
if not divisible_by_2_h:
|
358 |
+
new_h += 2
|
359 |
+
if not divisible_by_2_w:
|
360 |
+
new_w += 2
|
361 |
+
self.latent_dims_div2 = (new_h, new_w)
|
362 |
+
divisible_by_4_h = x_noisy.shape[2]%4==0
|
363 |
+
divisible_by_4_w = x_noisy.shape[3]%4==0
|
364 |
+
if not (divisible_by_4_h and divisible_by_4_w):
|
365 |
+
#logger.warn(f"{x_noisy.shape} not divisible by 4!")
|
366 |
+
new_h = (x_noisy.shape[2]//4)*4
|
367 |
+
new_w = (x_noisy.shape[3]//4)*4
|
368 |
+
if not divisible_by_4_h:
|
369 |
+
new_h += 4
|
370 |
+
if not divisible_by_4_w:
|
371 |
+
new_w += 4
|
372 |
+
self.latent_dims_div4 = (new_h, new_w)
|
373 |
+
# prepare mask
|
374 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
|
375 |
+
# done preparing; model patches will take care of everything now.
|
376 |
+
# return normal controlnet stuff
|
377 |
+
return control_prev
|
378 |
+
|
379 |
+
def get_models(self):
|
380 |
+
to_return: list = super().get_models()
|
381 |
+
to_return.append(self.control_model_wrapped)
|
382 |
+
return to_return
|
383 |
+
|
384 |
+
def cleanup_advanced(self):
|
385 |
+
super().cleanup_advanced()
|
386 |
+
self.patch_attn1.cleanup()
|
387 |
+
self.patch_attn2.cleanup()
|
388 |
+
self.latent_dims_div2 = None
|
389 |
+
self.latent_dims_div4 = None
|
390 |
+
|
391 |
+
def copy(self):
|
392 |
+
c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes, self.device, self.ops)
|
393 |
+
self.copy_to(c)
|
394 |
+
self.copy_to_advanced(c)
|
395 |
+
return c
|
396 |
+
|
397 |
+
# deepcopy needs to properly keep track of objects to work between model.clone calls!
|
398 |
+
# def __deepcopy__(self, *args, **kwargs):
|
399 |
+
# self.cleanup_advanced()
|
400 |
+
# return self
|
401 |
+
|
402 |
+
# def get_models(self):
|
403 |
+
# # get_models is called once at the start of every KSampler run - use to reset already_patched status
|
404 |
+
# out = super().get_models()
|
405 |
+
# logger.error(f"in get_models! {id(self)}")
|
406 |
+
# return out
|
407 |
+
|
408 |
+
|
409 |
+
def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
|
410 |
+
if controlnet_data is None:
|
411 |
+
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
412 |
+
# adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
|
413 |
+
# first, split weights for each module
|
414 |
+
module_weights = {}
|
415 |
+
for key, value in controlnet_data.items():
|
416 |
+
fragments = key.split(".")
|
417 |
+
module_name = fragments[0]
|
418 |
+
weight_name = ".".join(fragments[1:])
|
419 |
+
|
420 |
+
if module_name not in module_weights:
|
421 |
+
module_weights[module_name] = {}
|
422 |
+
module_weights[module_name][weight_name] = value
|
423 |
+
|
424 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
425 |
+
load_device = comfy.model_management.get_torch_device()
|
426 |
+
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
427 |
+
ops = comfy.ops.disable_weight_init
|
428 |
+
if manual_cast_dtype is not None:
|
429 |
+
ops = comfy.ops.manual_cast
|
430 |
+
|
431 |
+
# next, load each module
|
432 |
+
modules = {}
|
433 |
+
for module_name, weights in module_weights.items():
|
434 |
+
# kohya planned to do something about how these should be chosen, so I'm not touching this
|
435 |
+
# since I am not familiar with the logic for this
|
436 |
+
if "conditioning1.4.weight" in weights:
|
437 |
+
depth = 3
|
438 |
+
elif weights["conditioning1.2.weight"].shape[-1] == 4:
|
439 |
+
depth = 2
|
440 |
+
else:
|
441 |
+
depth = 1
|
442 |
+
|
443 |
+
module = LLLiteModule(
|
444 |
+
name=module_name,
|
445 |
+
is_conv2d=weights["down.0.weight"].ndim == 4,
|
446 |
+
in_dim=weights["down.0.weight"].shape[1],
|
447 |
+
depth=depth,
|
448 |
+
cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
|
449 |
+
mlp_dim=weights["down.0.weight"].shape[0],
|
450 |
+
)
|
451 |
+
# load weights into module
|
452 |
+
module.load_state_dict(weights)
|
453 |
+
modules[module_name] = module.to(dtype=unet_dtype)
|
454 |
+
if len(modules) == 1:
|
455 |
+
module.is_first = True
|
456 |
+
|
457 |
+
#logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
|
458 |
+
|
459 |
+
patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
|
460 |
+
patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
|
461 |
+
control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe, device=load_device, ops=ops)
|
462 |
+
return control
|
ComfyUI-Advanced-ControlNet/adv_control/control_plusplus.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Code ported and modified from the diffusers ControlNetPlus repo by Qi Xin:
|
2 |
+
# https://github.com/xinsir6/ControlNetPlus/blob/main/models/controlnet_union.py
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
import torch as th
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch import Tensor
|
10 |
+
from collections import OrderedDict
|
11 |
+
|
12 |
+
|
13 |
+
from comfy.ldm.modules.diffusionmodules.util import (zero_module, timestep_embedding)
|
14 |
+
|
15 |
+
from comfy.cldm.cldm import ControlNet as ControlNetCLDM
|
16 |
+
import comfy.cldm.cldm
|
17 |
+
from comfy.controlnet import ControlNet
|
18 |
+
#from comfy.t2i_adapter.adapter import ResidualAttentionBlock
|
19 |
+
from comfy.ldm.modules.attention import optimized_attention
|
20 |
+
import comfy.ops
|
21 |
+
import comfy.model_management
|
22 |
+
import comfy.model_detection
|
23 |
+
import comfy.utils
|
24 |
+
|
25 |
+
from .utils import (AdvancedControlBase, ControlWeights, ControlWeightType, TimestepKeyframeGroup, AbstractPreprocWrapper,
|
26 |
+
extend_to_batch_size, broadcast_image_to_extend)
|
27 |
+
from .logger import logger
|
28 |
+
|
29 |
+
|
30 |
+
class PlusPlusType:
|
31 |
+
OPENPOSE = "openpose"
|
32 |
+
DEPTH = "depth"
|
33 |
+
THICKLINE = "hed/pidi/scribble/ted"
|
34 |
+
THINLINE = "canny/lineart/mlsd"
|
35 |
+
NORMAL = "normal"
|
36 |
+
SEGMENT = "segment"
|
37 |
+
TILE = "tile"
|
38 |
+
REPAINT = "inpaint/outpaint"
|
39 |
+
NONE = "none"
|
40 |
+
_LIST_WITH_NONE = [OPENPOSE, DEPTH, THICKLINE, THINLINE, NORMAL, SEGMENT, TILE, REPAINT, NONE]
|
41 |
+
_LIST = [OPENPOSE, DEPTH, THICKLINE, THINLINE, NORMAL, SEGMENT, TILE, REPAINT]
|
42 |
+
_DICT = {OPENPOSE: 0, DEPTH: 1, THICKLINE: 2, THINLINE: 3, NORMAL: 4, SEGMENT: 5, TILE: 6, REPAINT: 7, NONE: -1}
|
43 |
+
|
44 |
+
@classmethod
|
45 |
+
def to_idx(cls, control_type: str):
|
46 |
+
try:
|
47 |
+
return cls._DICT[control_type]
|
48 |
+
except KeyError:
|
49 |
+
raise Exception(f"Unknown control type '{control_type}'.")
|
50 |
+
|
51 |
+
|
52 |
+
class PlusPlusInput:
|
53 |
+
def __init__(self, image: Tensor, control_type: str, strength: float):
|
54 |
+
self.image = image
|
55 |
+
self.control_type = control_type
|
56 |
+
self.strength = strength
|
57 |
+
|
58 |
+
def clone(self):
|
59 |
+
return PlusPlusInput(self.image, self.control_type, self.strength)
|
60 |
+
|
61 |
+
|
62 |
+
class PlusPlusInputGroup:
|
63 |
+
def __init__(self):
|
64 |
+
self.controls: dict[str, PlusPlusInput] = {}
|
65 |
+
|
66 |
+
def add(self, pp_input: PlusPlusInput):
|
67 |
+
if pp_input.control_type in self.controls:
|
68 |
+
raise Exception(f"Control type '{pp_input.control_type}' is already present; ControlNet++ does not allow more than 1 of each type.")
|
69 |
+
self.controls[pp_input.control_type] = pp_input
|
70 |
+
|
71 |
+
def clone(self) -> 'PlusPlusInputGroup':
|
72 |
+
cloned = PlusPlusInputGroup()
|
73 |
+
for key, value in self.controls.items():
|
74 |
+
cloned.controls[key] = value.clone()
|
75 |
+
return cloned
|
76 |
+
|
77 |
+
|
78 |
+
class PlusPlusImageWrapper(AbstractPreprocWrapper):
|
79 |
+
error_msg = error_msg = "Invalid use of ControlNet++ Image Wrapper. The output of ControlNet++ Image Wrapper is NOT a usual image, but an object holding the images and extra info - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
|
80 |
+
def __init__(self, condhint: PlusPlusInputGroup):
|
81 |
+
super().__init__(condhint)
|
82 |
+
# just an IDE type hint
|
83 |
+
self.condhint: PlusPlusInputGroup
|
84 |
+
|
85 |
+
def movedim(self, source: int, destination: int):
|
86 |
+
condhint = self.condhint.clone()
|
87 |
+
for pp_input in condhint.controls.values():
|
88 |
+
pp_input.image = pp_input.image.movedim(source, destination)
|
89 |
+
return PlusPlusImageWrapper(condhint)
|
90 |
+
|
91 |
+
# parts taken from comfy/cldm/cldm.py
|
92 |
+
class OptimizedAttention(nn.Module):
|
93 |
+
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
94 |
+
super().__init__()
|
95 |
+
self.heads = nhead
|
96 |
+
self.c = c
|
97 |
+
|
98 |
+
self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
|
99 |
+
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
100 |
+
|
101 |
+
def forward(self, x):
|
102 |
+
x = self.in_proj(x)
|
103 |
+
q, k, v = x.split(self.c, dim=2)
|
104 |
+
out = optimized_attention(q, k, v, self.heads)
|
105 |
+
return self.out_proj(out)
|
106 |
+
|
107 |
+
class QuickGELU(nn.Module):
|
108 |
+
def forward(self, x: torch.Tensor):
|
109 |
+
return x * torch.sigmoid(1.702 * x)
|
110 |
+
|
111 |
+
class ResBlockUnionControlnet(nn.Module):
|
112 |
+
def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
|
113 |
+
super().__init__()
|
114 |
+
self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
|
115 |
+
self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
|
116 |
+
self.mlp = nn.Sequential(
|
117 |
+
OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
|
118 |
+
("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
|
119 |
+
self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
|
120 |
+
|
121 |
+
def attention(self, x: torch.Tensor):
|
122 |
+
return self.attn(x)
|
123 |
+
|
124 |
+
def forward(self, x: torch.Tensor):
|
125 |
+
x = x + self.attention(self.ln_1(x))
|
126 |
+
x = x + self.mlp(self.ln_2(x))
|
127 |
+
return x
|
128 |
+
|
129 |
+
|
130 |
+
class ControlAddEmbeddingAdv(nn.Module):
|
131 |
+
def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations: comfy.ops.disable_weight_init=None):
|
132 |
+
super().__init__()
|
133 |
+
self.num_control_type = num_control_type
|
134 |
+
self.in_dim = in_dim
|
135 |
+
self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
|
136 |
+
self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
|
137 |
+
|
138 |
+
def forward(self, control_type, dtype, device):
|
139 |
+
if control_type is None:
|
140 |
+
control_type = torch.zeros((self.num_control_type,), device=device)
|
141 |
+
c_type = timestep_embedding(control_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
|
142 |
+
return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
|
143 |
+
|
144 |
+
|
145 |
+
class ControlNetPlusPlus(ControlNetCLDM):
|
146 |
+
def __init__(self, *args,**kwargs):
|
147 |
+
super().__init__(*args, **kwargs)
|
148 |
+
|
149 |
+
operations: comfy.ops.disable_weight_init = kwargs.get("operations", comfy.ops.disable_weight_init)
|
150 |
+
device = kwargs.get("device", None)
|
151 |
+
|
152 |
+
time_embed_dim = self.model_channels * 4
|
153 |
+
control_add_embed_dim = 256
|
154 |
+
|
155 |
+
self.control_add_embedding = ControlAddEmbeddingAdv(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
|
156 |
+
|
157 |
+
def union_controlnet_merge(self, hint: list[Tensor], control_type, emb, context):
|
158 |
+
# Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
|
159 |
+
indexes = torch.nonzero(control_type[0])
|
160 |
+
inputs = []
|
161 |
+
condition_list = []
|
162 |
+
|
163 |
+
for idx in range(indexes.shape[0]):
|
164 |
+
controlnet_cond = self.input_hint_block(hint[indexes[idx][0]], emb, context)
|
165 |
+
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
|
166 |
+
if idx < indexes.shape[0]:
|
167 |
+
feat_seq += self.task_embedding[indexes[idx][0]].to(dtype=feat_seq.dtype, device=feat_seq.device)
|
168 |
+
|
169 |
+
inputs.append(feat_seq.unsqueeze(1))
|
170 |
+
condition_list.append(controlnet_cond)
|
171 |
+
|
172 |
+
x = torch.cat(inputs, dim=1)
|
173 |
+
x = self.transformer_layes(x)
|
174 |
+
|
175 |
+
controlnet_cond_fuser = None
|
176 |
+
for idx in range(indexes.shape[0]):
|
177 |
+
alpha = self.spatial_ch_projs(x[:, idx])
|
178 |
+
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
|
179 |
+
o = condition_list[idx] + alpha
|
180 |
+
if controlnet_cond_fuser is None:
|
181 |
+
controlnet_cond_fuser = o
|
182 |
+
else:
|
183 |
+
controlnet_cond_fuser += o
|
184 |
+
return controlnet_cond_fuser
|
185 |
+
|
186 |
+
def forward(self, x: Tensor, hint: list[Tensor], timesteps, context, y: Tensor=None, **kwargs):
|
187 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
188 |
+
emb = self.time_embed(t_emb)
|
189 |
+
|
190 |
+
guided_hint = None
|
191 |
+
if self.control_add_embedding is not None:
|
192 |
+
control_type = kwargs.get("control_type", None)
|
193 |
+
|
194 |
+
emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
|
195 |
+
if control_type is not None:
|
196 |
+
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
|
197 |
+
|
198 |
+
if guided_hint is None:
|
199 |
+
guided_hint = self.input_hint_block(hint[0], emb, context)
|
200 |
+
|
201 |
+
out_output = []
|
202 |
+
out_middle = []
|
203 |
+
|
204 |
+
hs = []
|
205 |
+
if self.num_classes is not None:
|
206 |
+
assert y.shape[0] == x.shape[0]
|
207 |
+
emb = emb + self.label_emb(y)
|
208 |
+
|
209 |
+
h = x
|
210 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
211 |
+
if guided_hint is not None:
|
212 |
+
h = module(h, emb, context)
|
213 |
+
h += guided_hint
|
214 |
+
guided_hint = None
|
215 |
+
else:
|
216 |
+
h = module(h, emb, context)
|
217 |
+
out_output.append(zero_conv(h, emb, context))
|
218 |
+
|
219 |
+
h = self.middle_block(h, emb, context)
|
220 |
+
out_middle.append(self.middle_block_out(h, emb, context))
|
221 |
+
|
222 |
+
return {"middle": out_middle, "output": out_output}
|
223 |
+
|
224 |
+
|
225 |
+
class ControlNetPlusPlusAdvanced(ControlNet, AdvancedControlBase):
|
226 |
+
def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
|
227 |
+
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
228 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
|
229 |
+
self.add_compatible_weight(ControlWeightType.CONTROLNETPLUSPLUS)
|
230 |
+
# for IDE type hint purposes
|
231 |
+
self.control_model: ControlNetPlusPlus
|
232 |
+
self.cond_hint_original: Union[PlusPlusImageWrapper, PlusPlusInputGroup]
|
233 |
+
self.cond_hint: list[Union[Tensor, None]]
|
234 |
+
self.cond_hint_shape: Tensor = None
|
235 |
+
self.cond_hint_types: Tensor = None
|
236 |
+
# in case it is using the single loader
|
237 |
+
self.single_control_type: str = None
|
238 |
+
|
239 |
+
def get_universal_weights(self) -> ControlWeights:
|
240 |
+
def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
|
241 |
+
if key == "middle":
|
242 |
+
return 1.0
|
243 |
+
c_len = len(control[key])
|
244 |
+
raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
|
245 |
+
raw_weights = raw_weights[:-1]
|
246 |
+
if key == "input":
|
247 |
+
raw_weights.reverse()
|
248 |
+
return raw_weights[idx]
|
249 |
+
return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)
|
250 |
+
|
251 |
+
def verify_control_type(self, model_name: str, pp_group: PlusPlusInputGroup=None):
|
252 |
+
if pp_group is not None:
|
253 |
+
for pp_input in pp_group.controls.values():
|
254 |
+
if PlusPlusType.to_idx(pp_input.control_type) >= self.control_model.num_control_type:
|
255 |
+
raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{pp_input.control_type}'.")
|
256 |
+
if self.single_control_type is not None:
|
257 |
+
if PlusPlusType.to_idx(self.single_control_type) >= self.control_model.num_control_type:
|
258 |
+
raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{self.single_control_type}'.")
|
259 |
+
|
260 |
+
def set_cond_hint_inject(self, *args, **kwargs):
|
261 |
+
to_return = super().set_cond_hint_inject(*args, **kwargs)
|
262 |
+
# if not single_control_type, expect PlusPlusImageWrapper
|
263 |
+
if self.single_control_type is None:
|
264 |
+
# check that cond_hint is wrapped, and unwrap it
|
265 |
+
if type(self.cond_hint_original) != PlusPlusImageWrapper:
|
266 |
+
raise Exception("ControlNet++ (Multi) expects image input from the Load ControlNet++ Model node, NOT from anything else. Images are provided to that node via ControlNet++ Input nodes.")
|
267 |
+
self.cond_hint_original = self.cond_hint_original.condhint.clone()
|
268 |
+
# otherwise, expect single image input (AKA, usual controlnet input)
|
269 |
+
else:
|
270 |
+
# check that cond_hint is not a PlusPlusImageWrapper
|
271 |
+
if type(self.cond_hint_original) == PlusPlusImageWrapper:
|
272 |
+
raise Exception("ControlNet++ (Single) expects usual image input, NOT the image input from a Load ControlNet++ Model (Multi) node.")
|
273 |
+
pp_group = PlusPlusInputGroup()
|
274 |
+
pp_input = PlusPlusInput(self.cond_hint_original, self.single_control_type, 1.0)
|
275 |
+
pp_group.add(pp_input)
|
276 |
+
self.cond_hint_original = pp_group
|
277 |
+
return to_return
|
278 |
+
|
279 |
+
def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number):
|
280 |
+
control_prev = None
|
281 |
+
if self.previous_controlnet is not None:
|
282 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
283 |
+
|
284 |
+
if self.timestep_range is not None:
|
285 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
286 |
+
if control_prev is not None:
|
287 |
+
return control_prev
|
288 |
+
else:
|
289 |
+
return None
|
290 |
+
|
291 |
+
dtype = self.control_model.dtype
|
292 |
+
if self.manual_cast_dtype is not None:
|
293 |
+
dtype = self.manual_cast_dtype
|
294 |
+
|
295 |
+
output_dtype = x_noisy.dtype
|
296 |
+
|
297 |
+
# make all cond_hints appropriate dimensions
|
298 |
+
# TODO: change this to not require cond_hint upscaling every step when self.sub_idxs is present
|
299 |
+
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint_shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint_shape[3]:
|
300 |
+
if self.cond_hint is not None:
|
301 |
+
del self.cond_hint
|
302 |
+
self.cond_hint = [None] * self.control_model.num_control_type
|
303 |
+
self.cond_hint_types = torch.tensor([0.0] * self.control_model.num_control_type)
|
304 |
+
self.cond_hint_shape = None
|
305 |
+
compression_ratio = self.compression_ratio
|
306 |
+
# unlike normal controlnet, need to handle each input image tensor (for each type)
|
307 |
+
for pp_type, pp_input in self.cond_hint_original.controls.items():
|
308 |
+
pp_idx = PlusPlusType.to_idx(pp_type)
|
309 |
+
# if negative, means no type should be selected (single only)
|
310 |
+
if pp_idx < 0:
|
311 |
+
pp_idx = 0
|
312 |
+
else:
|
313 |
+
self.cond_hint_types[pp_idx] = pp_input.strength
|
314 |
+
# if self.cond_hint_original lengths greater or equal to latent count, subdivide
|
315 |
+
if self.sub_idxs is not None:
|
316 |
+
actual_cond_hint_orig = pp_input.image
|
317 |
+
if pp_input.image.size(0) < self.full_latent_length:
|
318 |
+
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
|
319 |
+
self.cond_hint[pp_idx] = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
|
320 |
+
else:
|
321 |
+
self.cond_hint[pp_idx] = comfy.utils.common_upscale(pp_input.image, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
|
322 |
+
self.cond_hint[pp_idx] = self.cond_hint[pp_idx].to(device=x_noisy.device, dtype=dtype)
|
323 |
+
self.cond_hint_shape = self.cond_hint[pp_idx].shape
|
324 |
+
# prepare cond_hint_controls to match batchsize
|
325 |
+
if self.cond_hint_types.count_nonzero() == 0:
|
326 |
+
self.cond_hint_types = None
|
327 |
+
else:
|
328 |
+
self.cond_hint_types = self.cond_hint_types.unsqueeze(0).to(device=x_noisy.device, dtype=dtype).repeat(x_noisy.shape[0], 1)
|
329 |
+
for i in range(len(self.cond_hint)):
|
330 |
+
if self.cond_hint[i] is not None:
|
331 |
+
if x_noisy.shape[0] != self.cond_hint[i].shape[0]:
|
332 |
+
self.cond_hint[i] = broadcast_image_to_extend(self.cond_hint[i], x_noisy.shape[0], batched_number)
|
333 |
+
if self.cond_hint_types is not None and x_noisy.shape[0] != self.cond_hint_types.shape[0]:
|
334 |
+
self.cond_hint_types = broadcast_image_to_extend(self.cond_hint_types, x_noisy.shape[0], batched_number, False)
|
335 |
+
|
336 |
+
# prepare mask_cond_hint
|
337 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
|
338 |
+
|
339 |
+
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
340 |
+
y = cond.get('y', None)
|
341 |
+
if y is not None:
|
342 |
+
y = y.to(dtype)
|
343 |
+
timestep = self.model_sampling_current.timestep(t)
|
344 |
+
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
345 |
+
|
346 |
+
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, control_type=self.cond_hint_types)
|
347 |
+
return self.control_merge(control, control_prev, output_dtype)
|
348 |
+
|
349 |
+
def copy(self):
|
350 |
+
c = ControlNetPlusPlusAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
351 |
+
self.copy_to(c)
|
352 |
+
self.copy_to_advanced(c)
|
353 |
+
c.single_control_type = self.single_control_type
|
354 |
+
return c
|
355 |
+
|
356 |
+
|
357 |
+
def load_controlnetplusplus(ckpt_path: str, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
|
358 |
+
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
359 |
+
# check that actually is ControlNet++ model
|
360 |
+
if "task_embedding" not in controlnet_data:
|
361 |
+
raise Exception(f"'{ckpt_path}' is not a valid ControlNet++ model.")
|
362 |
+
|
363 |
+
controlnet_config = None
|
364 |
+
supported_inference_dtypes = None
|
365 |
+
|
366 |
+
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
367 |
+
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
|
368 |
+
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
369 |
+
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
370 |
+
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
371 |
+
|
372 |
+
count = 0
|
373 |
+
loop = True
|
374 |
+
while loop:
|
375 |
+
suffix = [".weight", ".bias"]
|
376 |
+
for s in suffix:
|
377 |
+
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
378 |
+
k_out = "zero_convs.{}.0{}".format(count, s)
|
379 |
+
if k_in not in controlnet_data:
|
380 |
+
loop = False
|
381 |
+
break
|
382 |
+
diffusers_keys[k_in] = k_out
|
383 |
+
count += 1
|
384 |
+
|
385 |
+
count = 0
|
386 |
+
loop = True
|
387 |
+
while loop:
|
388 |
+
suffix = [".weight", ".bias"]
|
389 |
+
for s in suffix:
|
390 |
+
if count == 0:
|
391 |
+
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
392 |
+
else:
|
393 |
+
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
394 |
+
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
395 |
+
if k_in not in controlnet_data:
|
396 |
+
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
397 |
+
loop = False
|
398 |
+
diffusers_keys[k_in] = k_out
|
399 |
+
count += 1
|
400 |
+
|
401 |
+
new_sd = {}
|
402 |
+
for k in diffusers_keys:
|
403 |
+
if k in controlnet_data:
|
404 |
+
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
405 |
+
|
406 |
+
if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
|
407 |
+
controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
|
408 |
+
for k in list(controlnet_data.keys()):
|
409 |
+
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
|
410 |
+
new_sd[new_k] = controlnet_data.pop(k)
|
411 |
+
|
412 |
+
leftover_keys = controlnet_data.keys()
|
413 |
+
if len(leftover_keys) > 0:
|
414 |
+
logger.warning("leftover ControlNet++ keys: {}".format(leftover_keys))
|
415 |
+
controlnet_data = new_sd
|
416 |
+
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
417 |
+
raise Exception("Unexpected SD3 diffusers format for ControlNet++ model. Something is very wrong.")
|
418 |
+
|
419 |
+
pth_key = 'control_model.zero_convs.0.0.weight'
|
420 |
+
pth = False
|
421 |
+
key = 'zero_convs.0.0.weight'
|
422 |
+
if pth_key in controlnet_data:
|
423 |
+
pth = True
|
424 |
+
key = pth_key
|
425 |
+
prefix = "control_model."
|
426 |
+
elif key in controlnet_data:
|
427 |
+
prefix = ""
|
428 |
+
else:
|
429 |
+
raise Exception("Unexpected T2IAdapter format for ControlNet++ model. Something is very wrong.")
|
430 |
+
|
431 |
+
if controlnet_config is None:
|
432 |
+
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
433 |
+
supported_inference_dtypes = model_config.supported_inference_dtypes
|
434 |
+
controlnet_config = model_config.unet_config
|
435 |
+
|
436 |
+
load_device = comfy.model_management.get_torch_device()
|
437 |
+
if supported_inference_dtypes is None:
|
438 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
439 |
+
else:
|
440 |
+
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
441 |
+
|
442 |
+
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
443 |
+
if manual_cast_dtype is not None:
|
444 |
+
controlnet_config["operations"] = comfy.ops.manual_cast
|
445 |
+
controlnet_config["dtype"] = unet_dtype
|
446 |
+
controlnet_config.pop("out_channels")
|
447 |
+
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
448 |
+
control_model = ControlNetPlusPlus(**controlnet_config)
|
449 |
+
|
450 |
+
if pth:
|
451 |
+
if 'difference' in controlnet_data:
|
452 |
+
if model is not None:
|
453 |
+
comfy.model_management.load_models_gpu([model])
|
454 |
+
model_sd = model.model_state_dict()
|
455 |
+
for x in controlnet_data:
|
456 |
+
c_m = "control_model."
|
457 |
+
if x.startswith(c_m):
|
458 |
+
sd_key = "diffusion_model.{}".format(x[len(c_m):])
|
459 |
+
if sd_key in model_sd:
|
460 |
+
cd = controlnet_data[x]
|
461 |
+
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
462 |
+
else:
|
463 |
+
logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
464 |
+
|
465 |
+
class WeightsLoader(torch.nn.Module):
|
466 |
+
pass
|
467 |
+
w = WeightsLoader()
|
468 |
+
w.control_model = control_model
|
469 |
+
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
470 |
+
else:
|
471 |
+
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
472 |
+
|
473 |
+
if len(missing) > 0:
|
474 |
+
logger.warning("missing ControlNet++ keys: {}".format(missing))
|
475 |
+
|
476 |
+
if len(unexpected) > 0:
|
477 |
+
logger.debug("unexpected ControlNet++ keys: {}".format(unexpected))
|
478 |
+
|
479 |
+
global_average_pooling = False
|
480 |
+
filename = os.path.splitext(ckpt_path)[0]
|
481 |
+
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
482 |
+
global_average_pooling = True
|
483 |
+
|
484 |
+
control = ControlNetPlusPlusAdvanced(control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
485 |
+
return control
|
ComfyUI-Advanced-ControlNet/adv_control/control_reference.py
ADDED
@@ -0,0 +1,1112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Union
|
2 |
+
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
import comfy.model_management
|
8 |
+
import comfy.sample
|
9 |
+
import comfy.model_patcher
|
10 |
+
import comfy.utils
|
11 |
+
from comfy.controlnet import ControlBase
|
12 |
+
from comfy.model_patcher import ModelPatcher
|
13 |
+
from comfy.ldm.modules.attention import BasicTransformerBlock
|
14 |
+
from comfy.ldm.modules.diffusionmodules import openaimodel
|
15 |
+
|
16 |
+
from .logger import logger
|
17 |
+
from .utils import (AdvancedControlBase, ControlWeights, TimestepKeyframeGroup, TimestepKeyframe, AbstractPreprocWrapper,
|
18 |
+
broadcast_image_to_extend, ORIG_PREVIOUS_CONTROLNET, CONTROL_INIT_BY_ACN)
|
19 |
+
|
20 |
+
|
21 |
+
REF_READ_ATTN_CONTROL_LIST = "ref_read_attn_control_list"
|
22 |
+
REF_WRITE_ATTN_CONTROL_LIST = "ref_write_attn_control_list"
|
23 |
+
REF_READ_ADAIN_CONTROL_LIST = "ref_read_adain_control_list"
|
24 |
+
REF_WRITE_ADAIN_CONTROL_LIST = "ref_write_adain_control_list"
|
25 |
+
|
26 |
+
REF_ATTN_CONTROL_LIST = "ref_attn_control_list"
|
27 |
+
REF_ADAIN_CONTROL_LIST = "ref_adain_control_list"
|
28 |
+
REF_CONTROL_LIST_ALL = "ref_control_list_all"
|
29 |
+
REF_CONTROL_INFO = "ref_control_info"
|
30 |
+
REF_ATTN_MACHINE_STATE = "ref_attn_machine_state"
|
31 |
+
REF_ADAIN_MACHINE_STATE = "ref_adain_machine_state"
|
32 |
+
REF_COND_IDXS = "ref_cond_idxs"
|
33 |
+
REF_UNCOND_IDXS = "ref_uncond_idxs"
|
34 |
+
|
35 |
+
CONTEXTREF_OPTIONS_CLASS = "contextref_options_class"
|
36 |
+
CONTEXTREF_CLEAN_FUNC = "contextref_clean_func"
|
37 |
+
CONTEXTREF_CONTROL_LIST_ALL = "contextref_control_list_all"
|
38 |
+
CONTEXTREF_MACHINE_STATE = "contextref_machine_state"
|
39 |
+
CONTEXTREF_TEMP_COND_IDX = "contextref_temp_cond_idx"
|
40 |
+
|
41 |
+
HIGHEST_VERSION_SUPPORT = 1
|
42 |
+
RETURNED_CONTEXTREF_VERSION = 1
|
43 |
+
|
44 |
+
|
45 |
+
class RefConst:
|
46 |
+
OPTS = "refcn_opts"
|
47 |
+
CREF_MODE = "contextref_mode"
|
48 |
+
|
49 |
+
|
50 |
+
class MachineState:
|
51 |
+
WRITE = "write"
|
52 |
+
READ = "read"
|
53 |
+
READ_WRITE = "read_write"
|
54 |
+
STYLEALIGN = "stylealign"
|
55 |
+
OFF = "off"
|
56 |
+
|
57 |
+
def is_read(state: str):
|
58 |
+
return state in [MachineState.READ, MachineState.READ_WRITE]
|
59 |
+
|
60 |
+
def is_write(state: str):
|
61 |
+
return state in [MachineState.WRITE, MachineState.READ_WRITE]
|
62 |
+
|
63 |
+
|
64 |
+
class ReferenceType:
|
65 |
+
ATTN = "reference_attn"
|
66 |
+
ADAIN = "reference_adain"
|
67 |
+
ATTN_ADAIN = "reference_attn+adain"
|
68 |
+
STYLE_ALIGN = "StyleAlign"
|
69 |
+
|
70 |
+
_LIST = [ATTN, ADAIN, ATTN_ADAIN]
|
71 |
+
_LIST_ATTN = [ATTN, ATTN_ADAIN]
|
72 |
+
_LIST_ADAIN = [ADAIN, ATTN_ADAIN]
|
73 |
+
|
74 |
+
@classmethod
|
75 |
+
def is_attn(cls, ref_type: str):
|
76 |
+
return ref_type in cls._LIST_ATTN
|
77 |
+
|
78 |
+
@classmethod
|
79 |
+
def is_adain(cls, ref_type: str):
|
80 |
+
return ref_type in cls._LIST_ADAIN
|
81 |
+
|
82 |
+
|
83 |
+
class ReferenceOptions:
|
84 |
+
def __init__(self, reference_type: str,
|
85 |
+
attn_style_fidelity: float, adain_style_fidelity: float,
|
86 |
+
attn_ref_weight: float, adain_ref_weight: float,
|
87 |
+
attn_strength: float=1.0, adain_strength: float=1.0,
|
88 |
+
ref_with_other_cns: bool=False):
|
89 |
+
self.reference_type = reference_type
|
90 |
+
# attn
|
91 |
+
self.original_attn_style_fidelity = attn_style_fidelity
|
92 |
+
self.attn_style_fidelity = attn_style_fidelity
|
93 |
+
self.attn_ref_weight = attn_ref_weight
|
94 |
+
self.attn_strength = attn_strength
|
95 |
+
# adain
|
96 |
+
self.original_adain_style_fidelity = adain_style_fidelity
|
97 |
+
self.adain_style_fidelity = adain_style_fidelity
|
98 |
+
self.adain_ref_weight = adain_ref_weight
|
99 |
+
self.adain_strength = adain_strength
|
100 |
+
# other
|
101 |
+
self.ref_with_other_cns = ref_with_other_cns
|
102 |
+
|
103 |
+
def clone(self):
|
104 |
+
return ReferenceOptions(reference_type=self.reference_type,
|
105 |
+
attn_style_fidelity=self.original_attn_style_fidelity, adain_style_fidelity=self.original_adain_style_fidelity,
|
106 |
+
attn_ref_weight=self.attn_ref_weight, adain_ref_weight=self.adain_ref_weight,
|
107 |
+
attn_strength=self.attn_strength, adain_strength=self.adain_strength,
|
108 |
+
ref_with_other_cns=self.ref_with_other_cns)
|
109 |
+
|
110 |
+
@staticmethod
|
111 |
+
def create_combo(reference_type: str, style_fidelity: float, ref_weight: float, ref_with_other_cns: bool=False):
|
112 |
+
return ReferenceOptions(reference_type=reference_type,
|
113 |
+
attn_style_fidelity=style_fidelity, adain_style_fidelity=style_fidelity,
|
114 |
+
attn_ref_weight=ref_weight, adain_ref_weight=ref_weight,
|
115 |
+
ref_with_other_cns=ref_with_other_cns)
|
116 |
+
|
117 |
+
@staticmethod
|
118 |
+
def create_from_kwargs(attn_style_fidelity=0.0, adain_style_fidelity=0.0,
|
119 |
+
attn_ref_weight=0.0, adain_ref_weight=0.0,
|
120 |
+
attn_strength=0.0, adain_strength=0.0, **kwargs):
|
121 |
+
has_attn = attn_strength > 0.0
|
122 |
+
has_adain = adain_strength > 0.0
|
123 |
+
if has_attn and has_adain:
|
124 |
+
reference_type = ReferenceType.ATTN_ADAIN
|
125 |
+
elif has_adain:
|
126 |
+
reference_type = ReferenceType.ADAIN
|
127 |
+
else:
|
128 |
+
reference_type = ReferenceType.ATTN
|
129 |
+
return ReferenceOptions(reference_type=reference_type,
|
130 |
+
attn_style_fidelity=float(attn_style_fidelity), adain_style_fidelity=float(adain_style_fidelity),
|
131 |
+
attn_ref_weight=float(attn_ref_weight), adain_ref_weight=float(adain_ref_weight),
|
132 |
+
attn_strength=float(attn_strength), adain_strength=float(adain_strength))
|
133 |
+
|
134 |
+
|
135 |
+
class ReferencePreprocWrapper(AbstractPreprocWrapper):
|
136 |
+
error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of Reference preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
|
137 |
+
def __init__(self, condhint: Tensor):
|
138 |
+
super().__init__(condhint)
|
139 |
+
|
140 |
+
|
141 |
+
class ReferenceAdvanced(ControlBase, AdvancedControlBase):
|
142 |
+
CHANNEL_TO_MULT = {320: 1, 640: 2, 1280: 4}
|
143 |
+
|
144 |
+
def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup):
|
145 |
+
super().__init__()
|
146 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), allow_condhint_latents=True)
|
147 |
+
# TODO: allow vae_optional to be used instead of preprocessor
|
148 |
+
#require_vae=True
|
149 |
+
self._ref_opts = ref_opts
|
150 |
+
self.order = 0
|
151 |
+
self.model_latent_format = None
|
152 |
+
self.model_sampling_current = None
|
153 |
+
self.should_apply_attn_effective_strength = False
|
154 |
+
self.should_apply_adain_effective_strength = False
|
155 |
+
self.should_apply_effective_masks = False
|
156 |
+
self.latent_shape = None
|
157 |
+
# ContextRef stuff
|
158 |
+
self.is_context_ref = False
|
159 |
+
self.contextref_cond_idx = -1
|
160 |
+
self.contextref_version = RETURNED_CONTEXTREF_VERSION
|
161 |
+
|
162 |
+
@property
|
163 |
+
def ref_opts(self):
|
164 |
+
if self._current_timestep_keyframe is not None and self._current_timestep_keyframe.has_control_weights():
|
165 |
+
return self._current_timestep_keyframe.control_weights.extras.get(RefConst.OPTS, self._ref_opts)
|
166 |
+
return self._ref_opts
|
167 |
+
|
168 |
+
def any_attn_strength_to_apply(self):
|
169 |
+
return self.should_apply_attn_effective_strength or self.should_apply_effective_masks
|
170 |
+
|
171 |
+
def any_adain_strength_to_apply(self):
|
172 |
+
return self.should_apply_adain_effective_strength or self.should_apply_effective_masks
|
173 |
+
|
174 |
+
def get_effective_strength(self):
|
175 |
+
effective_strength = self.strength
|
176 |
+
if self._current_timestep_keyframe is not None:
|
177 |
+
effective_strength = effective_strength * self._current_timestep_keyframe.strength
|
178 |
+
return effective_strength
|
179 |
+
|
180 |
+
def get_effective_attn_mask_or_float(self, x: Tensor, channels: int, is_mid: bool):
|
181 |
+
if not self.should_apply_effective_masks:
|
182 |
+
return self.get_effective_strength() * self.ref_opts.attn_strength
|
183 |
+
if is_mid:
|
184 |
+
div = 8
|
185 |
+
else:
|
186 |
+
div = self.CHANNEL_TO_MULT[channels]
|
187 |
+
real_mask = torch.ones([self.latent_shape[0], 1, self.latent_shape[2]//div, self.latent_shape[3]//div]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.attn_strength
|
188 |
+
self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
|
189 |
+
# mask is now shape [b, 1, h ,w]; need to turn into [b, h*w, 1]
|
190 |
+
b, c, h, w = real_mask.shape
|
191 |
+
real_mask = real_mask.permute(0, 2, 3, 1).reshape(b, h*w, c)
|
192 |
+
return real_mask
|
193 |
+
|
194 |
+
def get_effective_adain_mask_or_float(self, x: Tensor):
|
195 |
+
if not self.should_apply_effective_masks:
|
196 |
+
return self.get_effective_strength() * self.ref_opts.adain_strength
|
197 |
+
b, c, h, w = x.shape
|
198 |
+
real_mask = torch.ones([b, 1, h, w]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.adain_strength
|
199 |
+
self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
|
200 |
+
return real_mask
|
201 |
+
|
202 |
+
def get_contextref_mode_replace(self):
|
203 |
+
# used by ADE to get mode_replace for current keyframe
|
204 |
+
if self._current_timestep_keyframe.has_control_weights():
|
205 |
+
return self._current_timestep_keyframe.control_weights.extras.get(RefConst.CREF_MODE, None)
|
206 |
+
return None
|
207 |
+
|
208 |
+
def should_run(self):
|
209 |
+
running = super().should_run()
|
210 |
+
if not running:
|
211 |
+
return running
|
212 |
+
attn_run = False
|
213 |
+
adain_run = False
|
214 |
+
if ReferenceType.is_attn(self.ref_opts.reference_type):
|
215 |
+
# attn will run as long as neither weight or strength is zero
|
216 |
+
attn_run = not (math.isclose(self.ref_opts.attn_ref_weight, 0.0) or math.isclose(self.ref_opts.attn_strength, 0.0))
|
217 |
+
if ReferenceType.is_adain(self.ref_opts.reference_type):
|
218 |
+
# adain will run as long as neither weight or strength is zero
|
219 |
+
adain_run = not (math.isclose(self.ref_opts.adain_ref_weight, 0.0) or math.isclose(self.ref_opts.adain_strength, 0.0))
|
220 |
+
return attn_run or adain_run
|
221 |
+
|
222 |
+
def pre_run_advanced(self, model, percent_to_timestep_function):
|
223 |
+
AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function)
|
224 |
+
if isinstance(self.cond_hint_original, AbstractPreprocWrapper):
|
225 |
+
self.cond_hint_original = self.cond_hint_original.condhint
|
226 |
+
self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
|
227 |
+
self.model_sampling_current = model.model_sampling
|
228 |
+
# SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments;
|
229 |
+
# prepare all ref_opts accordingly
|
230 |
+
all_ref_opts = [self._ref_opts]
|
231 |
+
for kf in self.timestep_keyframes.keyframes:
|
232 |
+
if kf.has_control_weights() and RefConst.OPTS in kf.control_weights.extras:
|
233 |
+
all_ref_opts.append(kf.control_weights.extras[RefConst.OPTS])
|
234 |
+
for ropts in all_ref_opts:
|
235 |
+
if type(model).__name__ == "SDXL":
|
236 |
+
ropts.attn_style_fidelity = ropts.original_attn_style_fidelity ** 3.0
|
237 |
+
ropts.adain_style_fidelity = ropts.original_adain_style_fidelity ** 3.0
|
238 |
+
else:
|
239 |
+
ropts.attn_style_fidelity = ropts.original_attn_style_fidelity
|
240 |
+
ropts.adain_style_fidelity = ropts.original_adain_style_fidelity
|
241 |
+
|
242 |
+
def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
|
243 |
+
# normal ControlNet stuff
|
244 |
+
control_prev = None
|
245 |
+
if self.previous_controlnet is not None:
|
246 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
247 |
+
|
248 |
+
if self.timestep_range is not None:
|
249 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
250 |
+
return control_prev
|
251 |
+
|
252 |
+
dtype = x_noisy.dtype
|
253 |
+
# cond_hint_original only matters for RefCN, NOT ContextRef
|
254 |
+
if self.cond_hint_original is not None:
|
255 |
+
# prepare cond_hint - it is a latent, NOT an image
|
256 |
+
#if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] != self.cond_hint.shape[2] or x_noisy.shape[3] != self.cond_hint.shape[3]:
|
257 |
+
if self.cond_hint is not None:
|
258 |
+
del self.cond_hint
|
259 |
+
self.cond_hint = None
|
260 |
+
# if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
|
261 |
+
if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
|
262 |
+
self.cond_hint = comfy.utils.common_upscale(
|
263 |
+
self.cond_hint_original[self.sub_idxs],
|
264 |
+
x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
|
265 |
+
else:
|
266 |
+
self.cond_hint = comfy.utils.common_upscale(
|
267 |
+
self.cond_hint_original,
|
268 |
+
x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(x_noisy.device)
|
269 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
270 |
+
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
|
271 |
+
# noise cond_hint based on sigma (current step)
|
272 |
+
self.cond_hint = self.model_latent_format.process_in(self.cond_hint)
|
273 |
+
self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
|
274 |
+
timestep = self.model_sampling_current.timestep(t)
|
275 |
+
self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0))
|
276 |
+
self.should_apply_adain_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.adain_strength, 1.0))
|
277 |
+
# prepare mask - use direct_attn, so the mask dims will match source latents (and be smaller)
|
278 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, direct_attn=True)
|
279 |
+
self.should_apply_effective_masks = self.latent_keyframes is not None or self.mask_cond_hint is not None or self.tk_mask_cond_hint is not None
|
280 |
+
self.latent_shape = list(x_noisy.shape)
|
281 |
+
# done preparing; model patches will take care of everything now.
|
282 |
+
# return normal controlnet stuff
|
283 |
+
return control_prev
|
284 |
+
|
285 |
+
def cleanup_advanced(self):
|
286 |
+
super().cleanup_advanced()
|
287 |
+
del self.model_latent_format
|
288 |
+
self.model_latent_format = None
|
289 |
+
del self.model_sampling_current
|
290 |
+
self.model_sampling_current = None
|
291 |
+
self.should_apply_attn_effective_strength = False
|
292 |
+
self.should_apply_adain_effective_strength = False
|
293 |
+
self.should_apply_effective_masks = False
|
294 |
+
|
295 |
+
def copy(self):
|
296 |
+
c = ReferenceAdvanced(self.ref_opts, self.timestep_keyframes)
|
297 |
+
c.order = self.order
|
298 |
+
c.is_context_ref = self.is_context_ref
|
299 |
+
self.copy_to(c)
|
300 |
+
self.copy_to_advanced(c)
|
301 |
+
return c
|
302 |
+
|
303 |
+
# avoid deepcopy shenanigans by making deepcopy not do anything to the reference
|
304 |
+
# TODO: do the bookkeeping to do this in a proper way for all Adv-ControlNets
|
305 |
+
def __deepcopy__(self, memo):
|
306 |
+
return self
|
307 |
+
|
308 |
+
|
309 |
+
def handle_context_ref_setup(contextref_obj, transformer_options: dict, positive, negative):
|
310 |
+
transformer_options[CONTEXTREF_MACHINE_STATE] = MachineState.OFF
|
311 |
+
# verify version is compatible
|
312 |
+
if contextref_obj.version > HIGHEST_VERSION_SUPPORT:
|
313 |
+
raise Exception(f"AnimateDiff-Evolved's ContextRef v{contextref_obj.version} is not supported in currently-installed Advanced-ControlNet (only supports ContextRef up to v{HIGHEST_VERSION_SUPPORT}); " +
|
314 |
+
f"update your Advanced-ControlNet nodes for ContextRef to work.")
|
315 |
+
# init ReferenceOptions
|
316 |
+
cref_opt_dict = contextref_obj.tune.create_dict() # ContextRefTune obj from ADE
|
317 |
+
opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict)
|
318 |
+
# init TimestepKeyframes
|
319 |
+
cref_tks_list = contextref_obj.keyframe.create_list_of_dicts() # ContextRefKeyframeGroup obj from ADE
|
320 |
+
timestep_keyframes = _create_tks_from_dict_list(cref_tks_list)
|
321 |
+
# create ReferenceAdvanced
|
322 |
+
cref = ReferenceAdvanced(ref_opts=opts, timestep_keyframes=timestep_keyframes)
|
323 |
+
cref.strength = contextref_obj.strength # ContextRef obj from ADE
|
324 |
+
cref.set_cond_hint_mask(contextref_obj.mask)
|
325 |
+
cref.order = 99
|
326 |
+
cref.is_context_ref = True
|
327 |
+
context_ref_list = [cref]
|
328 |
+
transformer_options[CONTEXTREF_CONTROL_LIST_ALL] = context_ref_list
|
329 |
+
transformer_options[CONTEXTREF_OPTIONS_CLASS] = ReferenceOptions
|
330 |
+
_add_context_ref_to_conds([positive, negative], cref)
|
331 |
+
return context_ref_list
|
332 |
+
|
333 |
+
|
334 |
+
def _create_tks_from_dict_list(dlist: list[dict[str]]) -> TimestepKeyframeGroup:
|
335 |
+
tks = TimestepKeyframeGroup()
|
336 |
+
if dlist is None or len(dlist) == 0:
|
337 |
+
return tks
|
338 |
+
for d in dlist:
|
339 |
+
# scheduling
|
340 |
+
start_percent = d["start_percent"]
|
341 |
+
guarantee_steps = d["guarantee_steps"]
|
342 |
+
inherit_missing = d["inherit_missing"]
|
343 |
+
# values
|
344 |
+
strength = d["strength"]
|
345 |
+
mask = d["mask"]
|
346 |
+
tune = d["tune"]
|
347 |
+
mode = d["mode"]
|
348 |
+
weights = None
|
349 |
+
extras = {}
|
350 |
+
if tune is not None:
|
351 |
+
cref_opt_dict = tune.create_dict() # ContextRefTune obj from ADE
|
352 |
+
opts = ReferenceOptions.create_from_kwargs(**cref_opt_dict)
|
353 |
+
extras[RefConst.OPTS] = opts
|
354 |
+
if mode is not None:
|
355 |
+
extras[RefConst.CREF_MODE] = mode
|
356 |
+
weights = ControlWeights.default(extras=extras)
|
357 |
+
# create keyframe
|
358 |
+
tk = TimestepKeyframe(start_percent=start_percent, guarantee_steps=guarantee_steps, inherit_missing=inherit_missing,
|
359 |
+
strength=strength, mask_hint_orig=mask, control_weights=weights)
|
360 |
+
tks.add(tk)
|
361 |
+
return tks
|
362 |
+
|
363 |
+
|
364 |
+
def _add_context_ref_to_conds(conds: list[list[dict[str]]], context_ref: ReferenceAdvanced):
|
365 |
+
def _add_context_ref_to_existing_control(control: ControlBase, context_ref: ReferenceAdvanced):
|
366 |
+
curr_cn = control
|
367 |
+
while curr_cn is not None:
|
368 |
+
if type(curr_cn) == ReferenceAdvanced and curr_cn.is_context_ref:
|
369 |
+
break
|
370 |
+
if curr_cn.previous_controlnet is not None:
|
371 |
+
curr_cn = curr_cn.previous_controlnet
|
372 |
+
continue
|
373 |
+
orig_previous_controlnet = curr_cn.previous_controlnet
|
374 |
+
# NOTE: code is already in place to restore any ORIG_PREVIOUS_CONTROLNET props
|
375 |
+
setattr(curr_cn, ORIG_PREVIOUS_CONTROLNET, orig_previous_controlnet)
|
376 |
+
curr_cn.previous_controlnet = context_ref
|
377 |
+
curr_cn = orig_previous_controlnet
|
378 |
+
|
379 |
+
def _add_context_ref(actual_cond: dict[str], context_ref: ReferenceAdvanced):
|
380 |
+
# if controls already present on cond, add it to the last previous_controlnet
|
381 |
+
if "control" in actual_cond:
|
382 |
+
return _add_context_ref_to_existing_control(actual_cond["control"], context_ref)
|
383 |
+
# otherwise, need to add it to begin with, and should mark that it should be cleaned after
|
384 |
+
actual_cond["control"] = context_ref
|
385 |
+
actual_cond[CONTROL_INIT_BY_ACN] = True
|
386 |
+
|
387 |
+
# either add context_ref to end of existing cnet chain, or init 'control' key on actual cond
|
388 |
+
for cond in conds:
|
389 |
+
if cond is not None:
|
390 |
+
for sub_cond in cond:
|
391 |
+
actual_cond = sub_cond[1]
|
392 |
+
_add_context_ref(actual_cond, context_ref)
|
393 |
+
|
394 |
+
|
395 |
+
def ref_noise_latents(latents: Tensor, sigma: Tensor, noise: Tensor=None):
|
396 |
+
sigma = sigma.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
397 |
+
alpha_cumprod = 1 / ((sigma * sigma) + 1)
|
398 |
+
sqrt_alpha_prod = alpha_cumprod ** 0.5
|
399 |
+
sqrt_one_minus_alpha_prod = (1. - alpha_cumprod) ** 0.5
|
400 |
+
if noise is None:
|
401 |
+
# generator = torch.Generator(device="cuda")
|
402 |
+
# generator.manual_seed(0)
|
403 |
+
# noise = torch.empty_like(latents).normal_(generator=generator)
|
404 |
+
# generator = torch.Generator()
|
405 |
+
# generator.manual_seed(0)
|
406 |
+
# noise = torch.randn(latents.size(), generator=generator).to(latents.device)
|
407 |
+
noise = torch.randn_like(latents).to(latents.device)
|
408 |
+
return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise
|
409 |
+
|
410 |
+
|
411 |
+
def simple_noise_latents(latents: Tensor, sigma: float, noise: Tensor=None):
|
412 |
+
if noise is None:
|
413 |
+
noise = torch.rand_like(latents)
|
414 |
+
return latents + noise * sigma
|
415 |
+
|
416 |
+
|
417 |
+
class BankStylesBasicTransformerBlock:
|
418 |
+
def __init__(self):
|
419 |
+
# ref
|
420 |
+
self.bank = []
|
421 |
+
self.style_cfgs = []
|
422 |
+
self.cn_idx: list[int] = []
|
423 |
+
# contextref - list of lists as each cond/uncond stored separately
|
424 |
+
self.c_bank: list[list] = []
|
425 |
+
self.c_style_cfgs: list[list] = []
|
426 |
+
self.c_cn_idx: list[list[int]] = []
|
427 |
+
|
428 |
+
def get_bank(self, cref_idx, ignore_contextref, cdevice=None):
|
429 |
+
if ignore_contextref or cref_idx >= len(self.c_bank):
|
430 |
+
return self.bank
|
431 |
+
real_c_bank_list = self.c_bank[cref_idx]
|
432 |
+
if cdevice != None:
|
433 |
+
real_c_bank_list = real_c_bank_list.copy()
|
434 |
+
for i in range(len(real_c_bank_list)):
|
435 |
+
real_c_bank_list[i] = real_c_bank_list[i].to(cdevice)
|
436 |
+
return self.bank + real_c_bank_list
|
437 |
+
|
438 |
+
def get_avg_style_fidelity(self, cref_idx, ignore_contextref):
|
439 |
+
if ignore_contextref or cref_idx >= len(self.c_style_cfgs):
|
440 |
+
return sum(self.style_cfgs) / float(len(self.style_cfgs))
|
441 |
+
combined = self.style_cfgs + self.c_style_cfgs[cref_idx]
|
442 |
+
return sum(combined) / float(len(combined))
|
443 |
+
|
444 |
+
def get_cn_idxs(self, cref_idx, ignore_contxtref):
|
445 |
+
if ignore_contxtref or cref_idx >= len(self.c_cn_idx):
|
446 |
+
return self.cn_idx
|
447 |
+
return self.cn_idx + self.c_cn_idx[cref_idx]
|
448 |
+
|
449 |
+
def init_cref_for_idx(self, cref_idx: int):
|
450 |
+
# makes sure cref lists can accommodate cref_idx
|
451 |
+
if cref_idx < 0:
|
452 |
+
return
|
453 |
+
while cref_idx >= len(self.c_bank):
|
454 |
+
self.c_bank.append([])
|
455 |
+
self.c_style_cfgs.append([])
|
456 |
+
self.c_cn_idx.append([])
|
457 |
+
|
458 |
+
def clear_cref_for_idx(self, cref_idx: int):
|
459 |
+
if cref_idx < 0 or cref_idx >= len(self.c_bank):
|
460 |
+
return
|
461 |
+
self.c_bank[cref_idx] = []
|
462 |
+
self.c_style_cfgs[cref_idx] = []
|
463 |
+
self.c_cn_idx[cref_idx] = []
|
464 |
+
|
465 |
+
def clean_ref(self):
|
466 |
+
del self.bank
|
467 |
+
del self.style_cfgs
|
468 |
+
del self.cn_idx
|
469 |
+
self.bank = []
|
470 |
+
self.style_cfgs = []
|
471 |
+
self.cn_idx = []
|
472 |
+
|
473 |
+
def clean_contextref(self):
|
474 |
+
del self.c_bank
|
475 |
+
del self.c_style_cfgs
|
476 |
+
del self.c_cn_idx
|
477 |
+
self.c_bank = []
|
478 |
+
self.c_style_cfgs = []
|
479 |
+
self.c_cn_idx = []
|
480 |
+
|
481 |
+
def clean_all(self):
|
482 |
+
self.clean_ref()
|
483 |
+
self.clean_contextref()
|
484 |
+
|
485 |
+
|
486 |
+
class BankStylesTimestepEmbedSequential:
|
487 |
+
def __init__(self):
|
488 |
+
# ref
|
489 |
+
self.var_bank = []
|
490 |
+
self.mean_bank = []
|
491 |
+
self.style_cfgs = []
|
492 |
+
self.cn_idx: list[int] = []
|
493 |
+
# cref
|
494 |
+
self.c_var_bank: list[list] = []
|
495 |
+
self.c_mean_bank: list[list] = []
|
496 |
+
self.c_style_cfgs: list[list] = []
|
497 |
+
self.c_cn_idx: list[list[int]] = []
|
498 |
+
|
499 |
+
def get_var_bank(self, cref_idx, ignore_contextref):
|
500 |
+
if ignore_contextref or cref_idx >= len(self.c_var_bank):
|
501 |
+
return self.var_bank
|
502 |
+
return self.var_bank + self.c_var_bank[cref_idx]
|
503 |
+
|
504 |
+
def get_mean_bank(self, cref_idx, ignore_contextref):
|
505 |
+
if ignore_contextref or cref_idx >= len(self.c_mean_bank):
|
506 |
+
return self.mean_bank
|
507 |
+
return self.mean_bank + self.c_mean_bank[cref_idx]
|
508 |
+
|
509 |
+
def get_style_cfgs(self, cref_idx, ignore_contextref):
|
510 |
+
if ignore_contextref or cref_idx >= len(self.c_style_cfgs):
|
511 |
+
return self.style_cfgs
|
512 |
+
return self.style_cfgs + self.c_style_cfgs[cref_idx]
|
513 |
+
|
514 |
+
def get_cn_idxs(self, cref_idx, ignore_contextref):
|
515 |
+
if ignore_contextref or cref_idx >= len(self.c_cn_idx):
|
516 |
+
return self.cn_idx
|
517 |
+
return self.cn_idx + self.c_cn_idx[cref_idx]
|
518 |
+
|
519 |
+
def init_cref_for_idx(self, cref_idx: int):
|
520 |
+
# makes sure cref lists can accommodate cref_idx
|
521 |
+
if cref_idx < 0:
|
522 |
+
return
|
523 |
+
while cref_idx >= len(self.c_var_bank):
|
524 |
+
self.c_var_bank.append([])
|
525 |
+
self.c_mean_bank.append([])
|
526 |
+
self.c_style_cfgs.append([])
|
527 |
+
self.c_cn_idx.append([])
|
528 |
+
|
529 |
+
def clear_cref_for_idx(self, cref_idx: int):
|
530 |
+
if cref_idx < 0 or cref_idx >= len(self.c_var_bank):
|
531 |
+
return
|
532 |
+
self.c_var_bank[cref_idx] = []
|
533 |
+
self.c_mean_bank[cref_idx] = []
|
534 |
+
self.c_style_cfgs[cref_idx] = []
|
535 |
+
self.c_cn_idx[cref_idx] = []
|
536 |
+
|
537 |
+
def clean_ref(self):
|
538 |
+
del self.mean_bank
|
539 |
+
del self.var_bank
|
540 |
+
del self.style_cfgs
|
541 |
+
del self.cn_idx
|
542 |
+
self.mean_bank = []
|
543 |
+
self.var_bank = []
|
544 |
+
self.style_cfgs = []
|
545 |
+
self.cn_idx = []
|
546 |
+
|
547 |
+
def clean_contextref(self):
|
548 |
+
del self.c_var_bank
|
549 |
+
del self.c_mean_bank
|
550 |
+
del self.c_style_cfgs
|
551 |
+
del self.c_cn_idx
|
552 |
+
self.c_var_bank = []
|
553 |
+
self.c_mean_bank = []
|
554 |
+
self.c_style_cfgs = []
|
555 |
+
self.c_cn_idx = []
|
556 |
+
|
557 |
+
def clean_all(self):
|
558 |
+
self.clean_ref()
|
559 |
+
self.clean_contextref()
|
560 |
+
|
561 |
+
|
562 |
+
class InjectionBasicTransformerBlockHolder:
|
563 |
+
def __init__(self, block: BasicTransformerBlock, idx=None):
|
564 |
+
if hasattr(block, "_forward"): # backward compatibility
|
565 |
+
self.original_forward = block._forward
|
566 |
+
else:
|
567 |
+
self.original_forward = block.forward
|
568 |
+
self.idx = idx
|
569 |
+
self.attn_weight = 1.0
|
570 |
+
self.is_middle = False
|
571 |
+
self.bank_styles = BankStylesBasicTransformerBlock()
|
572 |
+
|
573 |
+
def restore(self, block: BasicTransformerBlock):
|
574 |
+
if hasattr(block, "_forward"): # backward compatibility
|
575 |
+
block._forward = self.original_forward
|
576 |
+
else:
|
577 |
+
block.forward = self.original_forward
|
578 |
+
|
579 |
+
def clean_ref(self):
|
580 |
+
self.bank_styles.clean_ref()
|
581 |
+
|
582 |
+
def clean_contextref(self):
|
583 |
+
self.bank_styles.clean_contextref()
|
584 |
+
|
585 |
+
def clean_all(self):
|
586 |
+
self.bank_styles.clean_all()
|
587 |
+
|
588 |
+
|
589 |
+
class InjectionTimestepEmbedSequentialHolder:
|
590 |
+
def __init__(self, block: openaimodel.TimestepEmbedSequential, idx=None, is_middle=False, is_input=False, is_output=False):
|
591 |
+
self.original_forward = block.forward
|
592 |
+
self.idx = idx
|
593 |
+
self.gn_weight = 1.0
|
594 |
+
self.is_middle = is_middle
|
595 |
+
self.is_input = is_input
|
596 |
+
self.is_output = is_output
|
597 |
+
self.bank_styles = BankStylesTimestepEmbedSequential()
|
598 |
+
|
599 |
+
def restore(self, block: openaimodel.TimestepEmbedSequential):
|
600 |
+
block.forward = self.original_forward
|
601 |
+
|
602 |
+
def clean_ref(self):
|
603 |
+
self.bank_styles.clean_ref()
|
604 |
+
|
605 |
+
def clean_contextref(self):
|
606 |
+
self.bank_styles.clean_contextref()
|
607 |
+
|
608 |
+
def clean_all(self):
|
609 |
+
self.bank_styles.clean_all()
|
610 |
+
|
611 |
+
|
612 |
+
class ReferenceInjections:
|
613 |
+
def __init__(self, attn_modules: list['RefBasicTransformerBlock']=None, gn_modules: list['RefTimestepEmbedSequential']=None):
|
614 |
+
self.attn_modules = attn_modules if attn_modules else []
|
615 |
+
self.gn_modules = gn_modules if gn_modules else []
|
616 |
+
self.diffusion_model_orig_forward: Callable = None
|
617 |
+
|
618 |
+
def clean_ref_module_mem(self):
|
619 |
+
for attn_module in self.attn_modules:
|
620 |
+
try:
|
621 |
+
attn_module.injection_holder.clean_ref()
|
622 |
+
except Exception:
|
623 |
+
pass
|
624 |
+
for gn_module in self.gn_modules:
|
625 |
+
try:
|
626 |
+
gn_module.injection_holder.clean_ref()
|
627 |
+
except Exception:
|
628 |
+
pass
|
629 |
+
|
630 |
+
def clean_contextref_module_mem(self):
|
631 |
+
for attn_module in self.attn_modules:
|
632 |
+
try:
|
633 |
+
attn_module.injection_holder.clean_contextref()
|
634 |
+
except Exception:
|
635 |
+
pass
|
636 |
+
for gn_module in self.gn_modules:
|
637 |
+
try:
|
638 |
+
gn_module.injection_holder.clean_contextref()
|
639 |
+
except Exception:
|
640 |
+
pass
|
641 |
+
|
642 |
+
def clean_all_module_mem(self):
|
643 |
+
for attn_module in self.attn_modules:
|
644 |
+
try:
|
645 |
+
attn_module.injection_holder.clean_all()
|
646 |
+
except Exception:
|
647 |
+
pass
|
648 |
+
for gn_module in self.gn_modules:
|
649 |
+
try:
|
650 |
+
gn_module.injection_holder.clean_all()
|
651 |
+
except Exception:
|
652 |
+
pass
|
653 |
+
|
654 |
+
def cleanup(self):
|
655 |
+
self.clean_all_module_mem()
|
656 |
+
del self.attn_modules
|
657 |
+
self.attn_modules = []
|
658 |
+
del self.gn_modules
|
659 |
+
self.gn_modules = []
|
660 |
+
self.diffusion_model_orig_forward = None
|
661 |
+
|
662 |
+
|
663 |
+
def factory_forward_inject_UNetModel(reference_injections: ReferenceInjections):
|
664 |
+
def forward_inject_UNetModel(self, x: Tensor, *args, **kwargs):
|
665 |
+
# get control and transformer_options from kwargs
|
666 |
+
real_args = list(args)
|
667 |
+
real_kwargs = list(kwargs.keys())
|
668 |
+
control = kwargs.get("control", None)
|
669 |
+
transformer_options: dict[str] = kwargs.get("transformer_options", {})
|
670 |
+
# NOTE: adds support for both ReferenceCN and ContextRef, so need to track them separately
|
671 |
+
# get ReferenceAdvanced objects
|
672 |
+
ref_controlnets: list[ReferenceAdvanced] = transformer_options.get(REF_CONTROL_LIST_ALL, [])
|
673 |
+
context_controlnets: list[ReferenceAdvanced] = transformer_options.get(CONTEXTREF_CONTROL_LIST_ALL, [])
|
674 |
+
# clean contextref stuff if OFF
|
675 |
+
if len(context_controlnets) > 0 and transformer_options[CONTEXTREF_MACHINE_STATE] == MachineState.OFF:
|
676 |
+
reference_injections.clean_contextref_module_mem()
|
677 |
+
context_controlnets = []
|
678 |
+
# discard any controlnets that should not run
|
679 |
+
ref_controlnets = [z for z in ref_controlnets if z.should_run()]
|
680 |
+
context_controlnets = [z for z in context_controlnets if z.should_run()]
|
681 |
+
# if nothing related to reference controlnets, do nothing special
|
682 |
+
if len(ref_controlnets) == 0 and len(context_controlnets) == 0:
|
683 |
+
return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs)
|
684 |
+
try:
|
685 |
+
# assign cond and uncond idxs
|
686 |
+
batched_number = len(transformer_options["cond_or_uncond"])
|
687 |
+
per_batch = x.shape[0] // batched_number
|
688 |
+
indiv_conds = []
|
689 |
+
for cond_type in transformer_options["cond_or_uncond"]:
|
690 |
+
indiv_conds.extend([cond_type] * per_batch)
|
691 |
+
transformer_options[REF_UNCOND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 1]
|
692 |
+
transformer_options[REF_COND_IDXS] = [i for i, z in enumerate(indiv_conds) if z == 0]
|
693 |
+
# check which controlnets do which thing
|
694 |
+
attn_controlnets = []
|
695 |
+
adain_controlnets = []
|
696 |
+
for control in ref_controlnets:
|
697 |
+
if ReferenceType.is_attn(control.ref_opts.reference_type):
|
698 |
+
attn_controlnets.append(control)
|
699 |
+
if ReferenceType.is_adain(control.ref_opts.reference_type):
|
700 |
+
adain_controlnets.append(control)
|
701 |
+
context_attn_controlnets = []
|
702 |
+
context_adain_controlnets = []
|
703 |
+
# for ease of access, store current contextref_cond_idx value
|
704 |
+
if len(context_controlnets) == 0:
|
705 |
+
transformer_options[CONTEXTREF_TEMP_COND_IDX] = -1
|
706 |
+
else:
|
707 |
+
transformer_options[CONTEXTREF_TEMP_COND_IDX] = context_controlnets[0].contextref_cond_idx
|
708 |
+
# logger.info(f"{transformer_options[CONTEXTREF_MACHINE_STATE]}: {transformer_options[CONTEXTREF_TEMP_COND_IDX]}")
|
709 |
+
|
710 |
+
for control in context_controlnets:
|
711 |
+
if ReferenceType.is_attn(control.ref_opts.reference_type):
|
712 |
+
context_attn_controlnets.append(control)
|
713 |
+
if ReferenceType.is_adain(control.ref_opts.reference_type):
|
714 |
+
context_adain_controlnets.append(control)
|
715 |
+
if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0:
|
716 |
+
# ComfyUI uses forward_timestep_embed with the TimestepEmbedSequential passed into it
|
717 |
+
orig_forward_timestep_embed = openaimodel.forward_timestep_embed
|
718 |
+
openaimodel.forward_timestep_embed = forward_timestep_embed_ref_inject_factory(orig_forward_timestep_embed)
|
719 |
+
|
720 |
+
# if RefCN to be used, handle running diffusion with ref cond hints
|
721 |
+
if len(ref_controlnets) > 0:
|
722 |
+
for control in ref_controlnets:
|
723 |
+
read_attn_list = []
|
724 |
+
write_attn_list = []
|
725 |
+
read_adain_list = []
|
726 |
+
write_adain_list = []
|
727 |
+
|
728 |
+
if ReferenceType.is_attn(control.ref_opts.reference_type):
|
729 |
+
write_attn_list.append(control)
|
730 |
+
if ReferenceType.is_adain(control.ref_opts.reference_type):
|
731 |
+
write_adain_list.append(control)
|
732 |
+
# apply lists
|
733 |
+
transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list
|
734 |
+
transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list
|
735 |
+
transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list
|
736 |
+
transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list
|
737 |
+
|
738 |
+
orig_kwargs = kwargs
|
739 |
+
# disable other controlnets for this run, if specified
|
740 |
+
if not control.ref_opts.ref_with_other_cns:
|
741 |
+
kwargs = kwargs.copy()
|
742 |
+
kwargs["control"] = None
|
743 |
+
reference_injections.diffusion_model_orig_forward(control.cond_hint.to(dtype=x.dtype).to(device=x.device), *args, **kwargs)
|
744 |
+
kwargs = orig_kwargs
|
745 |
+
# prepare running diffusion for real now
|
746 |
+
read_attn_list = []
|
747 |
+
write_attn_list = []
|
748 |
+
read_adain_list = []
|
749 |
+
write_adain_list = []
|
750 |
+
|
751 |
+
# add RefCNs to read lists
|
752 |
+
read_attn_list.extend(attn_controlnets)
|
753 |
+
read_adain_list.extend(adain_controlnets)
|
754 |
+
|
755 |
+
# do contextref stuff, if needed
|
756 |
+
if len(context_controlnets) > 0:
|
757 |
+
# clean contextref stuff if first WRITE
|
758 |
+
# if context_controlnets[0].contextref_cond_idx == 0 and is_write(transformer_options[CONTEXTREF_MACHINE_STATE]):
|
759 |
+
# reference_injections.clean_contextref_module_mem()
|
760 |
+
### add ContextRef to appropriate lists
|
761 |
+
# attn
|
762 |
+
if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]):
|
763 |
+
read_attn_list.extend(context_attn_controlnets)
|
764 |
+
if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]):
|
765 |
+
write_attn_list.extend(context_attn_controlnets)
|
766 |
+
# adain
|
767 |
+
if is_read(transformer_options[CONTEXTREF_MACHINE_STATE]):
|
768 |
+
read_adain_list.extend(context_adain_controlnets)
|
769 |
+
if is_write(transformer_options[CONTEXTREF_MACHINE_STATE]):
|
770 |
+
write_adain_list.extend(context_adain_controlnets)
|
771 |
+
# apply lists, containing both RefCN and ContextRef
|
772 |
+
transformer_options[REF_READ_ATTN_CONTROL_LIST] = read_attn_list
|
773 |
+
transformer_options[REF_WRITE_ATTN_CONTROL_LIST] = write_attn_list
|
774 |
+
transformer_options[REF_READ_ADAIN_CONTROL_LIST] = read_adain_list
|
775 |
+
transformer_options[REF_WRITE_ADAIN_CONTROL_LIST] = write_adain_list
|
776 |
+
# run diffusion for real
|
777 |
+
try:
|
778 |
+
return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs)
|
779 |
+
finally:
|
780 |
+
# increment current cond idx
|
781 |
+
if len(context_controlnets) > 0:
|
782 |
+
for cn in context_controlnets:
|
783 |
+
cn.contextref_cond_idx += 1
|
784 |
+
finally:
|
785 |
+
# make sure ref banks are cleared no matter what happens - otherwise, RIP VRAM
|
786 |
+
reference_injections.clean_ref_module_mem()
|
787 |
+
if len(adain_controlnets) > 0 or len(context_adain_controlnets) > 0:
|
788 |
+
openaimodel.forward_timestep_embed = orig_forward_timestep_embed
|
789 |
+
|
790 |
+
|
791 |
+
return forward_inject_UNetModel
|
792 |
+
|
793 |
+
|
794 |
+
# dummy class just to help IDE keep track of injected variables
|
795 |
+
class RefBasicTransformerBlock(BasicTransformerBlock):
|
796 |
+
injection_holder: InjectionBasicTransformerBlockHolder = None
|
797 |
+
|
798 |
+
def _forward_inject_BasicTransformerBlock(self: RefBasicTransformerBlock, x: Tensor, context: Tensor=None, transformer_options: dict[str]={}):
|
799 |
+
extra_options = {}
|
800 |
+
block = transformer_options.get("block", None)
|
801 |
+
block_index = transformer_options.get("block_index", 0)
|
802 |
+
transformer_patches = {}
|
803 |
+
transformer_patches_replace = {}
|
804 |
+
|
805 |
+
for k in transformer_options:
|
806 |
+
if k == "patches":
|
807 |
+
transformer_patches = transformer_options[k]
|
808 |
+
elif k == "patches_replace":
|
809 |
+
transformer_patches_replace = transformer_options[k]
|
810 |
+
else:
|
811 |
+
extra_options[k] = transformer_options[k]
|
812 |
+
|
813 |
+
extra_options["n_heads"] = self.n_heads
|
814 |
+
extra_options["dim_head"] = self.d_head
|
815 |
+
|
816 |
+
if self.ff_in:
|
817 |
+
x_skip = x
|
818 |
+
x = self.ff_in(self.norm_in(x))
|
819 |
+
if self.is_res:
|
820 |
+
x += x_skip
|
821 |
+
|
822 |
+
n: Tensor = self.norm1(x)
|
823 |
+
if self.disable_self_attn:
|
824 |
+
context_attn1 = context
|
825 |
+
else:
|
826 |
+
context_attn1 = None
|
827 |
+
value_attn1 = None
|
828 |
+
|
829 |
+
# Reference CN stuff
|
830 |
+
uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
|
831 |
+
#c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
|
832 |
+
# WRITE mode may have only 1 ReferenceAdvanced for RefCN at a time, other modes will have all ReferenceAdvanced
|
833 |
+
ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ATTN_CONTROL_LIST, [])
|
834 |
+
ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ATTN_CONTROL_LIST, [])
|
835 |
+
cref_cond_idx: int = transformer_options.get(CONTEXTREF_TEMP_COND_IDX, -1)
|
836 |
+
ignore_contextref_read = cref_cond_idx < 0 # if writing to bank, should NOT be read in the same execution
|
837 |
+
|
838 |
+
cached_n = None
|
839 |
+
cref_write_cns: list[ReferenceAdvanced] = []
|
840 |
+
# check if any WRITE cns are applicable; Reference CN WRITEs immediately, ContextREF WRITEs after READ completed
|
841 |
+
# if any refs to WRITE, save n and style_fidelity
|
842 |
+
for refcn in ref_write_cns:
|
843 |
+
if refcn.ref_opts.attn_ref_weight > self.injection_holder.attn_weight:
|
844 |
+
if cached_n is None:
|
845 |
+
cached_n = n.detach().clone()
|
846 |
+
# for ContextRef, make sure relevant lists are long enough to cond_idx
|
847 |
+
# store RefCN and ContextRef stuff separately
|
848 |
+
if refcn.is_context_ref:
|
849 |
+
cref_write_cns.append(refcn)
|
850 |
+
self.injection_holder.bank_styles.init_cref_for_idx(cref_cond_idx)
|
851 |
+
else: # Reference CN WRITE
|
852 |
+
self.injection_holder.bank_styles.bank.append(cached_n)
|
853 |
+
self.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.attn_style_fidelity)
|
854 |
+
self.injection_holder.bank_styles.cn_idx.append(refcn.order)
|
855 |
+
if len(cref_write_cns) == 0:
|
856 |
+
del cached_n
|
857 |
+
|
858 |
+
if "attn1_patch" in transformer_patches:
|
859 |
+
patch = transformer_patches["attn1_patch"]
|
860 |
+
if context_attn1 is None:
|
861 |
+
context_attn1 = n
|
862 |
+
value_attn1 = context_attn1
|
863 |
+
for p in patch:
|
864 |
+
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
|
865 |
+
|
866 |
+
if block is not None:
|
867 |
+
transformer_block = (block[0], block[1], block_index)
|
868 |
+
else:
|
869 |
+
transformer_block = None
|
870 |
+
attn1_replace_patch = transformer_patches_replace.get("attn1", {})
|
871 |
+
block_attn1 = transformer_block
|
872 |
+
if block_attn1 not in attn1_replace_patch:
|
873 |
+
block_attn1 = block
|
874 |
+
|
875 |
+
if block_attn1 in attn1_replace_patch:
|
876 |
+
if context_attn1 is None:
|
877 |
+
context_attn1 = n
|
878 |
+
value_attn1 = n
|
879 |
+
n = self.attn1.to_q(n)
|
880 |
+
# Reference CN READ - use attn1_replace_patch appropriately
|
881 |
+
if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_bank(cref_cond_idx, ignore_contextref_read)) > 0:
|
882 |
+
bank_styles = self.injection_holder.bank_styles
|
883 |
+
style_fidelity = bank_styles.get_avg_style_fidelity(cref_cond_idx, ignore_contextref_read)
|
884 |
+
real_bank = bank_styles.get_bank(cref_cond_idx, ignore_contextref_read, cdevice=n.device).copy()
|
885 |
+
real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read)
|
886 |
+
cn_idx = 0
|
887 |
+
for idx, order in enumerate(real_cn_idxs):
|
888 |
+
# make sure matching ref cn is selected
|
889 |
+
for i in range(cn_idx, len(ref_read_cns)):
|
890 |
+
if ref_read_cns[i].order == order:
|
891 |
+
cn_idx = i
|
892 |
+
break
|
893 |
+
assert order == ref_read_cns[cn_idx].order
|
894 |
+
if ref_read_cns[cn_idx].any_attn_strength_to_apply():
|
895 |
+
effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
|
896 |
+
real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength)
|
897 |
+
n_uc = self.attn1.to_out(attn1_replace_patch[block_attn1](
|
898 |
+
n,
|
899 |
+
self.attn1.to_k(torch.cat([context_attn1] + real_bank, dim=1)),
|
900 |
+
self.attn1.to_v(torch.cat([value_attn1] + real_bank, dim=1)),
|
901 |
+
extra_options))
|
902 |
+
n_c = n_uc.clone()
|
903 |
+
if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
|
904 |
+
n_c[uc_idx_mask] = self.attn1.to_out(attn1_replace_patch[block_attn1](
|
905 |
+
n[uc_idx_mask],
|
906 |
+
self.attn1.to_k(context_attn1[uc_idx_mask]),
|
907 |
+
self.attn1.to_v(value_attn1[uc_idx_mask]),
|
908 |
+
extra_options))
|
909 |
+
n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc
|
910 |
+
bank_styles.clean_ref()
|
911 |
+
else:
|
912 |
+
context_attn1 = self.attn1.to_k(context_attn1)
|
913 |
+
value_attn1 = self.attn1.to_v(value_attn1)
|
914 |
+
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
|
915 |
+
n = self.attn1.to_out(n)
|
916 |
+
else:
|
917 |
+
# Reference CN READ - no attn1_replace_patch
|
918 |
+
if len(ref_read_cns) > 0 and len(self.injection_holder.bank_styles.get_bank(cref_cond_idx, ignore_contextref_read)) > 0:
|
919 |
+
if context_attn1 is None:
|
920 |
+
context_attn1 = n
|
921 |
+
bank_styles = self.injection_holder.bank_styles
|
922 |
+
style_fidelity = bank_styles.get_avg_style_fidelity(cref_cond_idx, ignore_contextref_read)
|
923 |
+
real_bank = bank_styles.get_bank(cref_cond_idx, ignore_contextref_read, cdevice=n.device).copy()
|
924 |
+
real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read)
|
925 |
+
cn_idx = 0
|
926 |
+
for idx, order in enumerate(real_cn_idxs):
|
927 |
+
# make sure matching ref cn is selected
|
928 |
+
for i in range(cn_idx, len(ref_read_cns)):
|
929 |
+
if ref_read_cns[i].order == order:
|
930 |
+
cn_idx = i
|
931 |
+
break
|
932 |
+
assert order == ref_read_cns[cn_idx].order
|
933 |
+
if ref_read_cns[cn_idx].any_attn_strength_to_apply():
|
934 |
+
effective_strength = ref_read_cns[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
|
935 |
+
real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength)
|
936 |
+
n_uc: Tensor = self.attn1(
|
937 |
+
n,
|
938 |
+
context=torch.cat([context_attn1] + real_bank, dim=1),
|
939 |
+
value=torch.cat([value_attn1] + real_bank, dim=1) if value_attn1 is not None else value_attn1)
|
940 |
+
n_c = n_uc.clone()
|
941 |
+
if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
|
942 |
+
n_c[uc_idx_mask] = self.attn1(
|
943 |
+
n[uc_idx_mask],
|
944 |
+
context=context_attn1[uc_idx_mask],
|
945 |
+
value=value_attn1[uc_idx_mask] if value_attn1 is not None else value_attn1)
|
946 |
+
n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc
|
947 |
+
bank_styles.clean_ref()
|
948 |
+
else:
|
949 |
+
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
950 |
+
|
951 |
+
# ContextRef CN WRITE
|
952 |
+
if len(cref_write_cns) > 0:
|
953 |
+
# clear so that ContextRef CNs can properly 'replace' previous value at cond_idx
|
954 |
+
self.injection_holder.bank_styles.clear_cref_for_idx(cref_cond_idx)
|
955 |
+
for refcn in cref_write_cns:
|
956 |
+
# add a whole list to match expected type when combining
|
957 |
+
self.injection_holder.bank_styles.c_bank[cref_cond_idx].append(cached_n.to(comfy.model_management.unet_offload_device()))
|
958 |
+
self.injection_holder.bank_styles.c_style_cfgs[cref_cond_idx].append(refcn.ref_opts.attn_style_fidelity)
|
959 |
+
self.injection_holder.bank_styles.c_cn_idx[cref_cond_idx].append(refcn.order)
|
960 |
+
del cached_n
|
961 |
+
|
962 |
+
if "attn1_output_patch" in transformer_patches:
|
963 |
+
patch = transformer_patches["attn1_output_patch"]
|
964 |
+
for p in patch:
|
965 |
+
n = p(n, extra_options)
|
966 |
+
|
967 |
+
x += n
|
968 |
+
if "middle_patch" in transformer_patches:
|
969 |
+
patch = transformer_patches["middle_patch"]
|
970 |
+
for p in patch:
|
971 |
+
x = p(x, extra_options)
|
972 |
+
|
973 |
+
if self.attn2 is not None:
|
974 |
+
n = self.norm2(x)
|
975 |
+
if self.switch_temporal_ca_to_sa:
|
976 |
+
context_attn2 = n
|
977 |
+
else:
|
978 |
+
context_attn2 = context
|
979 |
+
value_attn2 = None
|
980 |
+
if "attn2_patch" in transformer_patches:
|
981 |
+
patch = transformer_patches["attn2_patch"]
|
982 |
+
value_attn2 = context_attn2
|
983 |
+
for p in patch:
|
984 |
+
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
985 |
+
|
986 |
+
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
|
987 |
+
block_attn2 = transformer_block
|
988 |
+
if block_attn2 not in attn2_replace_patch:
|
989 |
+
block_attn2 = block
|
990 |
+
|
991 |
+
if block_attn2 in attn2_replace_patch:
|
992 |
+
if value_attn2 is None:
|
993 |
+
value_attn2 = context_attn2
|
994 |
+
n = self.attn2.to_q(n)
|
995 |
+
context_attn2 = self.attn2.to_k(context_attn2)
|
996 |
+
value_attn2 = self.attn2.to_v(value_attn2)
|
997 |
+
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
998 |
+
n = self.attn2.to_out(n)
|
999 |
+
else:
|
1000 |
+
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
1001 |
+
|
1002 |
+
if "attn2_output_patch" in transformer_patches:
|
1003 |
+
patch = transformer_patches["attn2_output_patch"]
|
1004 |
+
for p in patch:
|
1005 |
+
n = p(n, extra_options)
|
1006 |
+
|
1007 |
+
x += n
|
1008 |
+
if self.is_res:
|
1009 |
+
x_skip = x
|
1010 |
+
x = self.ff(self.norm3(x))
|
1011 |
+
if self.is_res:
|
1012 |
+
x += x_skip
|
1013 |
+
|
1014 |
+
return x
|
1015 |
+
|
1016 |
+
|
1017 |
+
class RefTimestepEmbedSequential(openaimodel.TimestepEmbedSequential):
|
1018 |
+
injection_holder: InjectionTimestepEmbedSequentialHolder = None
|
1019 |
+
|
1020 |
+
def forward_timestep_embed_ref_inject_factory(orig_timestep_embed_inject_factory: Callable):
|
1021 |
+
def forward_timestep_embed_ref_inject(*args, **kwargs):
|
1022 |
+
ts: RefTimestepEmbedSequential = args[0]
|
1023 |
+
if not hasattr(ts, "injection_holder"):
|
1024 |
+
return orig_timestep_embed_inject_factory(*args, **kwargs)
|
1025 |
+
eps = 1e-6
|
1026 |
+
x: Tensor = orig_timestep_embed_inject_factory(*args, **kwargs)
|
1027 |
+
y: Tensor = None
|
1028 |
+
transformer_options: dict[str] = args[4]
|
1029 |
+
# Reference CN stuff
|
1030 |
+
uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
|
1031 |
+
#c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
|
1032 |
+
# WRITE mode will only have one ReferenceAdvanced, other modes will have all ReferenceAdvanced
|
1033 |
+
ref_write_cns: list[ReferenceAdvanced] = transformer_options.get(REF_WRITE_ADAIN_CONTROL_LIST, [])
|
1034 |
+
ref_read_cns: list[ReferenceAdvanced] = transformer_options.get(REF_READ_ADAIN_CONTROL_LIST, [])
|
1035 |
+
cref_cond_idx: int = transformer_options.get(CONTEXTREF_TEMP_COND_IDX, -1)
|
1036 |
+
ignore_contextref_read = cref_cond_idx < 0 # if writing to bank, should NOT be read in the same execution
|
1037 |
+
|
1038 |
+
cached_var = None
|
1039 |
+
cached_mean = None
|
1040 |
+
cref_write_cns: list[ReferenceAdvanced] = []
|
1041 |
+
# if any refs to WRITE, save var, mean, and style_cfg
|
1042 |
+
for refcn in ref_write_cns:
|
1043 |
+
if refcn.ref_opts.adain_ref_weight > ts.injection_holder.gn_weight:
|
1044 |
+
if cached_var is None:
|
1045 |
+
cached_var, cached_mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
1046 |
+
if refcn.is_context_ref:
|
1047 |
+
cref_write_cns.append(refcn)
|
1048 |
+
ts.injection_holder.bank_styles.init_cref_for_idx(cref_cond_idx)
|
1049 |
+
else:
|
1050 |
+
ts.injection_holder.bank_styles.var_bank.append(cached_var)
|
1051 |
+
ts.injection_holder.bank_styles.mean_bank.append(cached_mean)
|
1052 |
+
ts.injection_holder.bank_styles.style_cfgs.append(refcn.ref_opts.adain_style_fidelity)
|
1053 |
+
ts.injection_holder.bank_styles.cn_idx.append(refcn.order)
|
1054 |
+
if len(cref_write_cns) == 0:
|
1055 |
+
del cached_var
|
1056 |
+
del cached_mean
|
1057 |
+
|
1058 |
+
# if any refs to READ, do math with saved var, mean, and style_cfg
|
1059 |
+
if len(ref_read_cns) > 0:
|
1060 |
+
if len(ts.injection_holder.bank_styles.get_var_bank(cref_cond_idx, ignore_contextref_read)) > 0:
|
1061 |
+
bank_styles = ts.injection_holder.bank_styles
|
1062 |
+
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
1063 |
+
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
1064 |
+
y_uc = torch.zeros_like(x)
|
1065 |
+
cn_idx = 0
|
1066 |
+
real_style_cfgs = bank_styles.get_style_cfgs(cref_cond_idx, ignore_contextref_read)
|
1067 |
+
real_var_bank = bank_styles.get_var_bank(cref_cond_idx, ignore_contextref_read)
|
1068 |
+
real_mean_bank = bank_styles.get_mean_bank(cref_cond_idx, ignore_contextref_read)
|
1069 |
+
real_cn_idxs = bank_styles.get_cn_idxs(cref_cond_idx, ignore_contextref_read)
|
1070 |
+
for idx, order in enumerate(real_cn_idxs):
|
1071 |
+
# make sure matching ref cn is selected
|
1072 |
+
for i in range(cn_idx, len(ref_read_cns)):
|
1073 |
+
if ref_read_cns[i].order == order:
|
1074 |
+
cn_idx = i
|
1075 |
+
break
|
1076 |
+
assert order == ref_read_cns[cn_idx].order
|
1077 |
+
style_fidelity = real_style_cfgs[idx]
|
1078 |
+
var_acc = real_var_bank[idx]
|
1079 |
+
mean_acc = real_mean_bank[idx]
|
1080 |
+
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
1081 |
+
sub_y_uc = (((x - mean) / std) * std_acc) + mean_acc
|
1082 |
+
if ref_read_cns[cn_idx].any_adain_strength_to_apply():
|
1083 |
+
effective_strength = ref_read_cns[cn_idx].get_effective_adain_mask_or_float(x=x)
|
1084 |
+
sub_y_uc = sub_y_uc * effective_strength + x * (1-effective_strength)
|
1085 |
+
y_uc += sub_y_uc
|
1086 |
+
# get average, if more than one
|
1087 |
+
if len(real_cn_idxs) > 1:
|
1088 |
+
y_uc /= len(real_cn_idxs)
|
1089 |
+
y_c = y_uc.clone()
|
1090 |
+
if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
|
1091 |
+
y_c[uc_idx_mask] = x.to(y_c.dtype)[uc_idx_mask]
|
1092 |
+
y = style_fidelity * y_c + (1.0 - style_fidelity) * y_uc
|
1093 |
+
ts.injection_holder.bank_styles.clean_ref()
|
1094 |
+
|
1095 |
+
# ContextRef CN WRITE
|
1096 |
+
if len(cref_write_cns) > 0:
|
1097 |
+
# clear so that ContextRef CNs can properly 'replace' previous value at cond_idx
|
1098 |
+
ts.injection_holder.bank_styles.clear_cref_for_idx(cref_cond_idx)
|
1099 |
+
for refcn in cref_write_cns:
|
1100 |
+
# add a whole list to match expected type when combining
|
1101 |
+
ts.injection_holder.bank_styles.c_var_bank[cref_cond_idx].append(cached_var)
|
1102 |
+
ts.injection_holder.bank_styles.c_mean_bank[cref_cond_idx].append(cached_mean)
|
1103 |
+
ts.injection_holder.bank_styles.c_style_cfgs[cref_cond_idx].append(refcn.ref_opts.adain_style_fidelity)
|
1104 |
+
ts.injection_holder.bank_styles.c_cn_idx[cref_cond_idx].append(refcn.order)
|
1105 |
+
del cached_var
|
1106 |
+
del cached_mean
|
1107 |
+
|
1108 |
+
if y is None:
|
1109 |
+
y = x
|
1110 |
+
return y.to(x.dtype)
|
1111 |
+
|
1112 |
+
return forward_timestep_embed_ref_inject
|
ComfyUI-Advanced-ControlNet/adv_control/control_sparsectrl.py
ADDED
@@ -0,0 +1,1078 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#taken from: https://github.com/lllyasviel/ControlNet
|
2 |
+
#and modified
|
3 |
+
#and then taken from comfy/cldm/cldm.py and modified again
|
4 |
+
|
5 |
+
from abc import ABC, abstractmethod
|
6 |
+
import copy
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
from typing import Iterable, Union
|
10 |
+
import torch
|
11 |
+
import torch as th
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch import Tensor
|
14 |
+
from einops import rearrange, repeat
|
15 |
+
|
16 |
+
from comfy.ldm.modules.diffusionmodules.util import (
|
17 |
+
zero_module,
|
18 |
+
timestep_embedding,
|
19 |
+
)
|
20 |
+
|
21 |
+
from comfy.cli_args import args
|
22 |
+
from comfy.cldm.cldm import ControlNet as ControlNetCLDM
|
23 |
+
from comfy.ldm.modules.attention import SpatialTransformer
|
24 |
+
from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default
|
25 |
+
from comfy.ldm.modules.attention import FeedForward, SpatialTransformer
|
26 |
+
from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential
|
27 |
+
from comfy.model_patcher import ModelPatcher
|
28 |
+
import comfy.ops
|
29 |
+
import comfy.model_management
|
30 |
+
import comfy.utils
|
31 |
+
|
32 |
+
from .logger import logger
|
33 |
+
from .utils import (BIGMAX, AbstractPreprocWrapper, disable_weight_init_clean_groupnorm,
|
34 |
+
prepare_mask_batch, broadcast_image_to_extend, extend_to_batch_size)
|
35 |
+
|
36 |
+
|
37 |
+
# until xformers bug is fixed, do not use xformers for VersatileAttention! TODO: change this when fix is out
|
38 |
+
# logic for choosing optimized_attention method taken from comfy/ldm/modules/attention.py
|
39 |
+
# a fallback_attention_mm is selected to avoid CUDA configuration limitation with pytorch's scaled_dot_product
|
40 |
+
optimized_attention_mm = attention_basic
|
41 |
+
fallback_attention_mm = attention_basic
|
42 |
+
if comfy.model_management.xformers_enabled():
|
43 |
+
pass
|
44 |
+
#optimized_attention_mm = attention_xformers
|
45 |
+
if comfy.model_management.pytorch_attention_enabled():
|
46 |
+
optimized_attention_mm = attention_pytorch
|
47 |
+
if args.use_split_cross_attention:
|
48 |
+
fallback_attention_mm = attention_split
|
49 |
+
else:
|
50 |
+
fallback_attention_mm = attention_sub_quad
|
51 |
+
else:
|
52 |
+
if args.use_split_cross_attention:
|
53 |
+
optimized_attention_mm = attention_split
|
54 |
+
else:
|
55 |
+
optimized_attention_mm = attention_sub_quad
|
56 |
+
|
57 |
+
|
58 |
+
class SparseConst:
|
59 |
+
HINT_MULT = "sparse_hint_mult"
|
60 |
+
NONHINT_MULT = "sparse_nonhint_mult"
|
61 |
+
MASK_MULT = "sparse_mask_mult"
|
62 |
+
|
63 |
+
|
64 |
+
class SparseControlNet(ControlNetCLDM):
|
65 |
+
def __init__(self, *args,**kwargs):
|
66 |
+
super().__init__(*args, **kwargs)
|
67 |
+
hint_channels = kwargs.get("hint_channels")
|
68 |
+
operations: disable_weight_init_clean_groupnorm = kwargs.get("operations", disable_weight_init_clean_groupnorm)
|
69 |
+
device = kwargs.get("device", None)
|
70 |
+
self.use_simplified_conditioning_embedding = kwargs.get("use_simplified_conditioning_embedding", False)
|
71 |
+
if self.use_simplified_conditioning_embedding:
|
72 |
+
self.input_hint_block = TimestepEmbedSequential(
|
73 |
+
zero_module(operations.conv_nd(self.dims, hint_channels, self.model_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
74 |
+
)
|
75 |
+
self.motion_wrapper: SparseCtrlMotionWrapper = None
|
76 |
+
|
77 |
+
def set_actual_length(self, actual_length: int, full_length: int):
|
78 |
+
if self.motion_wrapper is not None:
|
79 |
+
self.motion_wrapper.set_video_length(video_length=actual_length, full_length=full_length)
|
80 |
+
|
81 |
+
def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
|
82 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
83 |
+
emb = self.time_embed(t_emb)
|
84 |
+
|
85 |
+
# SparseCtrl sets noisy input to zeros
|
86 |
+
x = torch.zeros_like(x)
|
87 |
+
guided_hint = self.input_hint_block(hint, emb, context)
|
88 |
+
|
89 |
+
out_output = []
|
90 |
+
out_middle = []
|
91 |
+
|
92 |
+
hs = []
|
93 |
+
if self.num_classes is not None:
|
94 |
+
assert y.shape[0] == x.shape[0]
|
95 |
+
emb = emb + self.label_emb(y)
|
96 |
+
|
97 |
+
h = x
|
98 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
99 |
+
if guided_hint is not None:
|
100 |
+
h = module(h, emb, context)
|
101 |
+
h += guided_hint
|
102 |
+
guided_hint = None
|
103 |
+
else:
|
104 |
+
h = module(h, emb, context)
|
105 |
+
out_output.append(zero_conv(h, emb, context))
|
106 |
+
|
107 |
+
h = self.middle_block(h, emb, context)
|
108 |
+
out_middle.append(self.middle_block_out(h, emb, context))
|
109 |
+
|
110 |
+
return {"middle": out_middle, "output": out_output}
|
111 |
+
|
112 |
+
|
113 |
+
class SparseModelPatcher(ModelPatcher):
|
114 |
+
def __init__(self, *args, **kwargs):
|
115 |
+
self.model: SparseControlNet
|
116 |
+
super().__init__(*args, **kwargs)
|
117 |
+
|
118 |
+
def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
|
119 |
+
to_return = super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs)
|
120 |
+
if lowvram_model_memory > 0:
|
121 |
+
self._patch_lowvram_extras(device_to=device_to)
|
122 |
+
self._handle_float8_pe_tensors()
|
123 |
+
return to_return
|
124 |
+
|
125 |
+
def _patch_lowvram_extras(self, device_to=None):
|
126 |
+
if self.model.motion_wrapper is not None:
|
127 |
+
# figure out the tensors (likely pe's) that should be cast to device besides just the named_modules
|
128 |
+
remaining_tensors = list(self.model.motion_wrapper.state_dict().keys())
|
129 |
+
named_modules = []
|
130 |
+
for n, _ in self.model.motion_wrapper.named_modules():
|
131 |
+
named_modules.append(n)
|
132 |
+
named_modules.append(f"{n}.weight")
|
133 |
+
named_modules.append(f"{n}.bias")
|
134 |
+
for name in named_modules:
|
135 |
+
if name in remaining_tensors:
|
136 |
+
remaining_tensors.remove(name)
|
137 |
+
|
138 |
+
for key in remaining_tensors:
|
139 |
+
self.patch_weight_to_device(key, device_to)
|
140 |
+
if device_to is not None:
|
141 |
+
comfy.utils.set_attr(self.model.motion_wrapper, key, comfy.utils.get_attr(self.model.motion_wrapper, key).to(device_to))
|
142 |
+
|
143 |
+
def _handle_float8_pe_tensors(self):
|
144 |
+
if self.model.motion_wrapper is not None:
|
145 |
+
remaining_tensors = list(self.model.motion_wrapper.state_dict().keys())
|
146 |
+
pe_tensors = [x for x in remaining_tensors if '.pe' in x]
|
147 |
+
is_first = True
|
148 |
+
for key in pe_tensors:
|
149 |
+
if is_first:
|
150 |
+
is_first = False
|
151 |
+
if comfy.utils.get_attr(self.model.motion_wrapper, key).dtype not in [torch.float8_e5m2, torch.float8_e4m3fn]:
|
152 |
+
break
|
153 |
+
comfy.utils.set_attr(self.model.motion_wrapper, key, comfy.utils.get_attr(self.model.motion_wrapper, key).half())
|
154 |
+
|
155 |
+
# NOTE: no longer called by ComfyUI, but here for backwards compatibility
|
156 |
+
def patch_model_lowvram(self, device_to=None, *args, **kwargs):
|
157 |
+
patched_model = super().patch_model_lowvram(device_to, *args, **kwargs)
|
158 |
+
self._patch_lowvram_extras(device_to=device_to)
|
159 |
+
return patched_model
|
160 |
+
|
161 |
+
def clone(self):
|
162 |
+
# normal ModelPatcher clone actions
|
163 |
+
n = SparseModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
164 |
+
n.patches = {}
|
165 |
+
for k in self.patches:
|
166 |
+
n.patches[k] = self.patches[k][:]
|
167 |
+
if hasattr(n, "patches_uuid"):
|
168 |
+
self.patches_uuid = n.patches_uuid
|
169 |
+
|
170 |
+
n.object_patches = self.object_patches.copy()
|
171 |
+
n.model_options = copy.deepcopy(self.model_options)
|
172 |
+
if hasattr(n, "model_keys"):
|
173 |
+
n.model_keys = self.model_keys
|
174 |
+
if hasattr(n, "backup"):
|
175 |
+
self.backup = n.backup
|
176 |
+
if hasattr(n, "object_patches_backup"):
|
177 |
+
self.object_patches_backup = n.object_patches_backup
|
178 |
+
|
179 |
+
|
180 |
+
class PreprocSparseRGBWrapper(AbstractPreprocWrapper):
|
181 |
+
error_msg = error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
|
182 |
+
def __init__(self, condhint: Tensor):
|
183 |
+
super().__init__(condhint)
|
184 |
+
|
185 |
+
|
186 |
+
class SparseContextAware:
|
187 |
+
NEAREST_HINT = "nearest_hint"
|
188 |
+
OFF = "off"
|
189 |
+
|
190 |
+
LIST = [NEAREST_HINT, OFF]
|
191 |
+
|
192 |
+
|
193 |
+
class SparseSettings:
|
194 |
+
def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False,
|
195 |
+
sparse_mask_mult=1.0, sparse_hint_mult=1.0, sparse_nonhint_mult=1.0, context_aware=SparseContextAware.NEAREST_HINT):
|
196 |
+
# account for Steerable-Motion workflow incompatibility;
|
197 |
+
# doing this to for my own peace of mind (not an issue with my code)
|
198 |
+
if type(sparse_method) == str:
|
199 |
+
logger.warn("Outdated Steerable-Motion workflow detected; attempting to auto-convert indexes input. If you experience an error here, consult Steerable-Motion github, NOT Advanced-ControlNet.")
|
200 |
+
sparse_method = SparseIndexMethod(get_idx_list_from_str(sparse_method))
|
201 |
+
self.sparse_method = sparse_method
|
202 |
+
self.use_motion = use_motion
|
203 |
+
self.motion_strength = motion_strength
|
204 |
+
self.motion_scale = motion_scale
|
205 |
+
self.merged = merged
|
206 |
+
self.sparse_mask_mult = float(sparse_mask_mult)
|
207 |
+
self.sparse_hint_mult = float(sparse_hint_mult)
|
208 |
+
self.sparse_nonhint_mult = float(sparse_nonhint_mult)
|
209 |
+
self.context_aware = context_aware
|
210 |
+
|
211 |
+
def is_context_aware(self):
|
212 |
+
return self.context_aware != SparseContextAware.OFF
|
213 |
+
|
214 |
+
@classmethod
|
215 |
+
def default(cls):
|
216 |
+
return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True)
|
217 |
+
|
218 |
+
|
219 |
+
class SparseMethod(ABC):
|
220 |
+
SPREAD = "spread"
|
221 |
+
INDEX = "index"
|
222 |
+
def __init__(self, method: str):
|
223 |
+
self.method = method
|
224 |
+
|
225 |
+
@abstractmethod
|
226 |
+
def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
|
227 |
+
pass
|
228 |
+
|
229 |
+
def get_indexes(self, hint_length: int, full_length: int, sub_idxs: list[int]=None) -> tuple[list[int], list[int]]:
|
230 |
+
returned_idxs = self._get_indexes(hint_length, full_length)
|
231 |
+
if sub_idxs is None:
|
232 |
+
return returned_idxs, None
|
233 |
+
# need to map full indexes to condhint indexes
|
234 |
+
index_mapping = {}
|
235 |
+
for i, value in enumerate(returned_idxs):
|
236 |
+
index_mapping[value] = i
|
237 |
+
def get_mapped_idxs(idxs: list[int]):
|
238 |
+
return [index_mapping[idx] for idx in idxs]
|
239 |
+
# check if returned_idxs fit within subidxs
|
240 |
+
fitting_idxs = []
|
241 |
+
for sub_idx in sub_idxs:
|
242 |
+
if sub_idx in returned_idxs:
|
243 |
+
fitting_idxs.append(sub_idx)
|
244 |
+
# if have any fitting_idxs, deal with it
|
245 |
+
if len(fitting_idxs) > 0:
|
246 |
+
return fitting_idxs, get_mapped_idxs(fitting_idxs)
|
247 |
+
|
248 |
+
# since no returned_idxs fit in sub_idxs, need to get the next-closest hint images based on strategy
|
249 |
+
def get_closest_idx(target_idx: int, idxs: list[int]):
|
250 |
+
min_idx = -1
|
251 |
+
min_dist = BIGMAX
|
252 |
+
for idx in idxs:
|
253 |
+
new_dist = abs(idx-target_idx)
|
254 |
+
if new_dist < min_dist:
|
255 |
+
min_idx = idx
|
256 |
+
min_dist = new_dist
|
257 |
+
if min_dist == 1:
|
258 |
+
return min_idx, min_dist
|
259 |
+
return min_idx, min_dist
|
260 |
+
start_closest_idx, start_dist = get_closest_idx(sub_idxs[0], returned_idxs)
|
261 |
+
end_closest_idx, end_dist = get_closest_idx(sub_idxs[-1], returned_idxs)
|
262 |
+
# if only one cond hint exists, do special behavior
|
263 |
+
if hint_length == 1:
|
264 |
+
# if same distance from start and end,
|
265 |
+
if start_dist == end_dist:
|
266 |
+
# find center index of sub_idxs
|
267 |
+
center_idx = sub_idxs[np.linspace(0, len(sub_idxs)-1, 3, endpoint=True, dtype=int)[1]]
|
268 |
+
return [center_idx], get_mapped_idxs([start_closest_idx])
|
269 |
+
# otherwise, return closest
|
270 |
+
if start_dist < end_dist:
|
271 |
+
return [sub_idxs[0]], get_mapped_idxs([start_closest_idx])
|
272 |
+
return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx])
|
273 |
+
# otherwise, select up to two closest images, or just 1, whichever one applies best
|
274 |
+
# if same distance from start and end, return two images to use
|
275 |
+
if start_dist == end_dist:
|
276 |
+
return [sub_idxs[0], sub_idxs[-1]], get_mapped_idxs([start_closest_idx, end_closest_idx])
|
277 |
+
# else, use just one
|
278 |
+
if start_dist < end_dist:
|
279 |
+
return [sub_idxs[0]], get_mapped_idxs([start_closest_idx])
|
280 |
+
return [sub_idxs[-1]], get_mapped_idxs([end_closest_idx])
|
281 |
+
|
282 |
+
|
283 |
+
class SparseSpreadMethod(SparseMethod):
|
284 |
+
UNIFORM = "uniform"
|
285 |
+
STARTING = "starting"
|
286 |
+
ENDING = "ending"
|
287 |
+
CENTER = "center"
|
288 |
+
|
289 |
+
LIST = [UNIFORM, STARTING, ENDING, CENTER]
|
290 |
+
|
291 |
+
def __init__(self, spread=UNIFORM):
|
292 |
+
super().__init__(self.SPREAD)
|
293 |
+
self.spread = spread
|
294 |
+
|
295 |
+
def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
|
296 |
+
# if hint_length >= full_length, limit hints to full_length
|
297 |
+
if hint_length >= full_length:
|
298 |
+
return list(range(full_length))
|
299 |
+
# handle special case of 1 hint image
|
300 |
+
if hint_length == 1:
|
301 |
+
if self.spread in [self.UNIFORM, self.STARTING]:
|
302 |
+
return [0]
|
303 |
+
elif self.spread == self.ENDING:
|
304 |
+
return [full_length-1]
|
305 |
+
elif self.spread == self.CENTER:
|
306 |
+
# return second (of three) values as the center
|
307 |
+
return [np.linspace(0, full_length-1, 3, endpoint=True, dtype=int)[1]]
|
308 |
+
else:
|
309 |
+
raise ValueError(f"Unrecognized spread: {self.spread}")
|
310 |
+
# otherwise, handle other cases
|
311 |
+
if self.spread == self.UNIFORM:
|
312 |
+
return list(np.linspace(0, full_length-1, hint_length, endpoint=True, dtype=int))
|
313 |
+
elif self.spread == self.STARTING:
|
314 |
+
# make split 1 larger, remove last element
|
315 |
+
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
|
316 |
+
elif self.spread == self.ENDING:
|
317 |
+
# make split 1 larger, remove first element
|
318 |
+
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[1:]
|
319 |
+
elif self.spread == self.CENTER:
|
320 |
+
# if hint length is not 3 greater than full length, do STARTING behavior
|
321 |
+
if full_length-hint_length < 3:
|
322 |
+
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
|
323 |
+
# otherwise, get linspace of 2 greater than needed, then cut off first and last
|
324 |
+
return list(np.linspace(0, full_length-1, hint_length+2, endpoint=True, dtype=int))[1:-1]
|
325 |
+
return ValueError(f"Unrecognized spread: {self.spread}")
|
326 |
+
|
327 |
+
|
328 |
+
class SparseIndexMethod(SparseMethod):
|
329 |
+
def __init__(self, idxs: list[int]):
|
330 |
+
super().__init__(self.INDEX)
|
331 |
+
self.idxs = idxs
|
332 |
+
|
333 |
+
def _get_indexes(self, hint_length: int, full_length: int) -> list[int]:
|
334 |
+
orig_hint_length = hint_length
|
335 |
+
if hint_length > full_length:
|
336 |
+
hint_length = full_length
|
337 |
+
# if idxs is less than hint_length, throw error
|
338 |
+
if len(self.idxs) < hint_length:
|
339 |
+
err_msg = f"There are not enough indexes ({len(self.idxs)}) provided to fit the usable {hint_length} input images."
|
340 |
+
if orig_hint_length != hint_length:
|
341 |
+
err_msg = f"{err_msg} (original input images: {orig_hint_length})"
|
342 |
+
raise ValueError(err_msg)
|
343 |
+
# cap idxs to hint_length
|
344 |
+
idxs = self.idxs[:hint_length]
|
345 |
+
new_idxs = []
|
346 |
+
real_idxs = set()
|
347 |
+
for idx in idxs:
|
348 |
+
if idx < 0:
|
349 |
+
real_idx = full_length+idx
|
350 |
+
if real_idx in real_idxs:
|
351 |
+
raise ValueError(f"Index '{idx}' maps to '{real_idx}' and is duplicate - indexes in Sparse Index Method must be unique.")
|
352 |
+
else:
|
353 |
+
real_idx = idx
|
354 |
+
if real_idx in real_idxs:
|
355 |
+
raise ValueError(f"Index '{idx}' is duplicate (or a negative index is equivalent) - indexes in Sparse Index Method must be unique.")
|
356 |
+
real_idxs.add(real_idx)
|
357 |
+
new_idxs.append(real_idx)
|
358 |
+
return new_idxs
|
359 |
+
|
360 |
+
|
361 |
+
def get_idx_list_from_str(indexes: str) -> list[int]:
|
362 |
+
idxs = []
|
363 |
+
unique_idxs = set()
|
364 |
+
# get indeces from string
|
365 |
+
str_idxs = [x.strip() for x in indexes.strip().split(",")]
|
366 |
+
for str_idx in str_idxs:
|
367 |
+
try:
|
368 |
+
idx = int(str_idx)
|
369 |
+
if idx in unique_idxs:
|
370 |
+
raise ValueError(f"'{idx}' is duplicated; indexes must be unique.")
|
371 |
+
idxs.append(idx)
|
372 |
+
unique_idxs.add(idx)
|
373 |
+
except ValueError:
|
374 |
+
raise ValueError(f"'{str_idx}' is not a valid integer index.")
|
375 |
+
if len(idxs) == 0:
|
376 |
+
raise ValueError(f"No indexes were listed in Sparse Index Method.")
|
377 |
+
return idxs
|
378 |
+
|
379 |
+
|
380 |
+
#########################################
|
381 |
+
# motion-related portion of controlnet
|
382 |
+
class BlockType:
|
383 |
+
UP = "up"
|
384 |
+
DOWN = "down"
|
385 |
+
MID = "mid"
|
386 |
+
|
387 |
+
def get_down_block_max(mm_state_dict: dict[str, Tensor]) -> int:
|
388 |
+
return get_block_max(mm_state_dict, "down_blocks")
|
389 |
+
|
390 |
+
def get_up_block_max(mm_state_dict: dict[str, Tensor]) -> int:
|
391 |
+
return get_block_max(mm_state_dict, "up_blocks")
|
392 |
+
|
393 |
+
def get_block_max(mm_state_dict: dict[str, Tensor], block_name: str) -> int:
|
394 |
+
# keep track of biggest down_block count in module
|
395 |
+
biggest_block = -1
|
396 |
+
for key in mm_state_dict.keys():
|
397 |
+
if block_name in key:
|
398 |
+
try:
|
399 |
+
block_int = key.split(".")[1]
|
400 |
+
block_num = int(block_int)
|
401 |
+
if block_num > biggest_block:
|
402 |
+
biggest_block = block_num
|
403 |
+
except ValueError:
|
404 |
+
pass
|
405 |
+
return biggest_block
|
406 |
+
|
407 |
+
def has_mid_block(mm_state_dict: dict[str, Tensor]):
|
408 |
+
# check if keys contain mid_block
|
409 |
+
for key in mm_state_dict.keys():
|
410 |
+
if key.startswith("mid_block."):
|
411 |
+
return True
|
412 |
+
return False
|
413 |
+
|
414 |
+
def get_position_encoding_max_len(mm_state_dict: dict[str, Tensor], mm_name: str=None) -> int:
|
415 |
+
# use pos_encoder.pe entries to determine max length - [1, {max_length}, {320|640|1280}]
|
416 |
+
for key in mm_state_dict.keys():
|
417 |
+
if key.endswith("pos_encoder.pe"):
|
418 |
+
return mm_state_dict[key].size(1) # get middle dim
|
419 |
+
raise ValueError(f"No pos_encoder.pe found in SparseCtrl state_dict - {mm_name} is not a valid SparseCtrl model!")
|
420 |
+
|
421 |
+
|
422 |
+
class SparseCtrlMotionWrapper(nn.Module):
|
423 |
+
def __init__(self, mm_state_dict: dict[str, Tensor], ops=disable_weight_init_clean_groupnorm):
|
424 |
+
super().__init__()
|
425 |
+
self.down_blocks: Iterable[MotionModule] = None
|
426 |
+
self.up_blocks: Iterable[MotionModule] = None
|
427 |
+
self.mid_block: MotionModule = None
|
428 |
+
self.encoding_max_len = get_position_encoding_max_len(mm_state_dict, "")
|
429 |
+
layer_channels = (320, 640, 1280, 1280)
|
430 |
+
if get_down_block_max(mm_state_dict) > -1:
|
431 |
+
self.down_blocks = nn.ModuleList([])
|
432 |
+
for c in layer_channels:
|
433 |
+
self.down_blocks.append(MotionModule(c, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.DOWN, ops=ops))
|
434 |
+
if get_up_block_max(mm_state_dict) > -1:
|
435 |
+
self.up_blocks = nn.ModuleList([])
|
436 |
+
for c in reversed(layer_channels):
|
437 |
+
self.up_blocks.append(MotionModule(c, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.UP, ops=ops))
|
438 |
+
if has_mid_block(mm_state_dict):
|
439 |
+
self.mid_block = MotionModule(1280, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.MID, ops=ops)
|
440 |
+
|
441 |
+
def inject(self, unet: SparseControlNet):
|
442 |
+
# inject input (down) blocks
|
443 |
+
self._inject(unet.input_blocks, self.down_blocks)
|
444 |
+
# inject mid block, if present
|
445 |
+
if self.mid_block is not None:
|
446 |
+
self._inject([unet.middle_block], [self.mid_block])
|
447 |
+
unet.motion_wrapper = self
|
448 |
+
|
449 |
+
def _inject(self, unet_blocks: nn.ModuleList, mm_blocks: nn.ModuleList):
|
450 |
+
# Rules for injection:
|
451 |
+
# For each component list in a unet block:
|
452 |
+
# if SpatialTransformer exists in list, place next block after last occurrence
|
453 |
+
# elif ResBlock exists in list, place next block after first occurrence
|
454 |
+
# else don't place block
|
455 |
+
injection_count = 0
|
456 |
+
unet_idx = 0
|
457 |
+
# details about blocks passed in
|
458 |
+
per_block = len(mm_blocks[0].motion_modules)
|
459 |
+
injection_goal = len(mm_blocks) * per_block
|
460 |
+
# only stop injecting when modules exhausted
|
461 |
+
while injection_count < injection_goal:
|
462 |
+
# figure out which VanillaTemporalModule from mm to inject
|
463 |
+
mm_blk_idx, mm_vtm_idx = injection_count // per_block, injection_count % per_block
|
464 |
+
# figure out layout of unet block components
|
465 |
+
st_idx = -1 # SpatialTransformer index
|
466 |
+
res_idx = -1 # first ResBlock index
|
467 |
+
# first, figure out indeces of relevant blocks
|
468 |
+
for idx, component in enumerate(unet_blocks[unet_idx]):
|
469 |
+
if type(component) == SpatialTransformer:
|
470 |
+
st_idx = idx
|
471 |
+
elif type(component).__name__ == "ResBlock" and res_idx < 0:
|
472 |
+
res_idx = idx
|
473 |
+
# if SpatialTransformer exists, inject right after
|
474 |
+
if st_idx >= 0:
|
475 |
+
unet_blocks[unet_idx].insert(st_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx])
|
476 |
+
injection_count += 1
|
477 |
+
# otherwise, if only ResBlock exists, inject right after
|
478 |
+
elif res_idx >= 0:
|
479 |
+
unet_blocks[unet_idx].insert(res_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx])
|
480 |
+
injection_count += 1
|
481 |
+
# increment unet_idx
|
482 |
+
unet_idx += 1
|
483 |
+
|
484 |
+
def eject(self, unet: SparseControlNet):
|
485 |
+
# remove from input blocks (downblocks)
|
486 |
+
self._eject(unet.input_blocks)
|
487 |
+
# remove from middle block (encapsulate in list to make compatible)
|
488 |
+
self._eject([unet.middle_block])
|
489 |
+
del unet.motion_wrapper
|
490 |
+
unet.motion_wrapper = None
|
491 |
+
|
492 |
+
def _eject(self, unet_blocks: nn.ModuleList):
|
493 |
+
# eject all VanillaTemporalModule objects from all blocks
|
494 |
+
for block in unet_blocks:
|
495 |
+
idx_to_pop = []
|
496 |
+
for idx, component in enumerate(block):
|
497 |
+
if type(component) == VanillaTemporalModule:
|
498 |
+
idx_to_pop.append(idx)
|
499 |
+
# pop in backwards order, as to not disturb what the indeces refer to
|
500 |
+
for idx in sorted(idx_to_pop, reverse=True):
|
501 |
+
block.pop(idx)
|
502 |
+
|
503 |
+
def set_video_length(self, video_length: int, full_length: int):
|
504 |
+
self.AD_video_length = video_length
|
505 |
+
if self.down_blocks is not None:
|
506 |
+
for block in self.down_blocks:
|
507 |
+
block.set_video_length(video_length, full_length)
|
508 |
+
if self.up_blocks is not None:
|
509 |
+
for block in self.up_blocks:
|
510 |
+
block.set_video_length(video_length, full_length)
|
511 |
+
if self.mid_block is not None:
|
512 |
+
self.mid_block.set_video_length(video_length, full_length)
|
513 |
+
|
514 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
515 |
+
if self.down_blocks is not None:
|
516 |
+
for block in self.down_blocks:
|
517 |
+
block.set_scale_multiplier(multiplier)
|
518 |
+
if self.up_blocks is not None:
|
519 |
+
for block in self.up_blocks:
|
520 |
+
block.set_scale_multiplier(multiplier)
|
521 |
+
if self.mid_block is not None:
|
522 |
+
self.mid_block.set_scale_multiplier(multiplier)
|
523 |
+
|
524 |
+
def set_strength(self, strength: float):
|
525 |
+
if self.down_blocks is not None:
|
526 |
+
for block in self.down_blocks:
|
527 |
+
block.set_strength(strength)
|
528 |
+
if self.up_blocks is not None:
|
529 |
+
for block in self.up_blocks:
|
530 |
+
block.set_strength(strength)
|
531 |
+
if self.mid_block is not None:
|
532 |
+
self.mid_block.set_strength(strength)
|
533 |
+
|
534 |
+
def reset_temp_vars(self):
|
535 |
+
if self.down_blocks is not None:
|
536 |
+
for block in self.down_blocks:
|
537 |
+
block.reset_temp_vars()
|
538 |
+
if self.up_blocks is not None:
|
539 |
+
for block in self.up_blocks:
|
540 |
+
block.reset_temp_vars()
|
541 |
+
if self.mid_block is not None:
|
542 |
+
self.mid_block.reset_temp_vars()
|
543 |
+
|
544 |
+
def reset_scale_multiplier(self):
|
545 |
+
self.set_scale_multiplier(None)
|
546 |
+
|
547 |
+
def reset(self):
|
548 |
+
self.reset_scale_multiplier()
|
549 |
+
self.reset_temp_vars()
|
550 |
+
|
551 |
+
|
552 |
+
class MotionModule(nn.Module):
|
553 |
+
def __init__(self, in_channels, temporal_position_encoding_max_len=24, block_type: str=BlockType.DOWN, ops=disable_weight_init_clean_groupnorm):
|
554 |
+
super().__init__()
|
555 |
+
if block_type == BlockType.MID:
|
556 |
+
# mid blocks contain only a single VanillaTemporalModule
|
557 |
+
self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList([get_motion_module(in_channels, temporal_position_encoding_max_len, ops=ops)])
|
558 |
+
else:
|
559 |
+
# down blocks contain two VanillaTemporalModules
|
560 |
+
self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList(
|
561 |
+
[
|
562 |
+
get_motion_module(in_channels, temporal_position_encoding_max_len, ops=ops),
|
563 |
+
get_motion_module(in_channels, temporal_position_encoding_max_len, ops=ops)
|
564 |
+
]
|
565 |
+
)
|
566 |
+
# up blocks contain one additional VanillaTemporalModule
|
567 |
+
if block_type == BlockType.UP:
|
568 |
+
self.motion_modules.append(get_motion_module(in_channels, temporal_position_encoding_max_len, ops=ops))
|
569 |
+
|
570 |
+
def set_video_length(self, video_length: int, full_length: int):
|
571 |
+
for motion_module in self.motion_modules:
|
572 |
+
motion_module.set_video_length(video_length, full_length)
|
573 |
+
|
574 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
575 |
+
for motion_module in self.motion_modules:
|
576 |
+
motion_module.set_scale_multiplier(multiplier)
|
577 |
+
|
578 |
+
def set_masks(self, masks: Tensor, min_val: float, max_val: float):
|
579 |
+
for motion_module in self.motion_modules:
|
580 |
+
motion_module.set_masks(masks, min_val, max_val)
|
581 |
+
|
582 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
583 |
+
for motion_module in self.motion_modules:
|
584 |
+
motion_module.set_sub_idxs(sub_idxs)
|
585 |
+
|
586 |
+
def set_strength(self, strength: float):
|
587 |
+
for motion_module in self.motion_modules:
|
588 |
+
motion_module.set_strength(strength)
|
589 |
+
|
590 |
+
def reset_temp_vars(self):
|
591 |
+
for motion_module in self.motion_modules:
|
592 |
+
motion_module.reset_temp_vars()
|
593 |
+
|
594 |
+
|
595 |
+
def get_motion_module(in_channels, temporal_position_encoding_max_len, ops=disable_weight_init_clean_groupnorm):
|
596 |
+
# unlike normal AD, there is only one attention block expected in SparseCtrl models
|
597 |
+
return VanillaTemporalModule(in_channels=in_channels, attention_block_types=("Temporal_Self",), temporal_position_encoding_max_len=temporal_position_encoding_max_len, ops=ops)
|
598 |
+
|
599 |
+
|
600 |
+
class VanillaTemporalModule(nn.Module):
|
601 |
+
def __init__(
|
602 |
+
self,
|
603 |
+
in_channels,
|
604 |
+
num_attention_heads=8,
|
605 |
+
num_transformer_block=1,
|
606 |
+
attention_block_types=("Temporal_Self", "Temporal_Self"),
|
607 |
+
cross_frame_attention_mode=None,
|
608 |
+
temporal_position_encoding=True,
|
609 |
+
temporal_position_encoding_max_len=24,
|
610 |
+
temporal_attention_dim_div=1,
|
611 |
+
zero_initialize=True,
|
612 |
+
ops=disable_weight_init_clean_groupnorm,
|
613 |
+
):
|
614 |
+
super().__init__()
|
615 |
+
self.strength = 1.0
|
616 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
617 |
+
in_channels=in_channels,
|
618 |
+
num_attention_heads=num_attention_heads,
|
619 |
+
attention_head_dim=in_channels
|
620 |
+
// num_attention_heads
|
621 |
+
// temporal_attention_dim_div,
|
622 |
+
num_layers=num_transformer_block,
|
623 |
+
attention_block_types=attention_block_types,
|
624 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
625 |
+
temporal_position_encoding=temporal_position_encoding,
|
626 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
627 |
+
ops=ops,
|
628 |
+
)
|
629 |
+
|
630 |
+
if zero_initialize:
|
631 |
+
self.temporal_transformer.proj_out = zero_module(
|
632 |
+
self.temporal_transformer.proj_out
|
633 |
+
)
|
634 |
+
|
635 |
+
def set_video_length(self, video_length: int, full_length: int):
|
636 |
+
self.temporal_transformer.set_video_length(video_length, full_length)
|
637 |
+
|
638 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
639 |
+
self.temporal_transformer.set_scale_multiplier(multiplier)
|
640 |
+
|
641 |
+
def set_masks(self, masks: Tensor, min_val: float, max_val: float):
|
642 |
+
self.temporal_transformer.set_masks(masks, min_val, max_val)
|
643 |
+
|
644 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
645 |
+
self.temporal_transformer.set_sub_idxs(sub_idxs)
|
646 |
+
|
647 |
+
def set_strength(self, strength: float):
|
648 |
+
self.strength = strength
|
649 |
+
|
650 |
+
def reset_temp_vars(self):
|
651 |
+
self.set_strength(1.0)
|
652 |
+
self.temporal_transformer.reset_temp_vars()
|
653 |
+
|
654 |
+
def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None):
|
655 |
+
if math.isclose(self.strength, 1.0):
|
656 |
+
return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)
|
657 |
+
elif math.isclose(self.strength, 0.0):
|
658 |
+
return input_tensor
|
659 |
+
# elif self.strength > 1.0:
|
660 |
+
# return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)*self.strength
|
661 |
+
else:
|
662 |
+
return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)*self.strength + input_tensor*(1.0-self.strength)
|
663 |
+
|
664 |
+
|
665 |
+
class TemporalTransformer3DModel(nn.Module):
|
666 |
+
def __init__(
|
667 |
+
self,
|
668 |
+
in_channels,
|
669 |
+
num_attention_heads,
|
670 |
+
attention_head_dim,
|
671 |
+
num_layers,
|
672 |
+
attention_block_types=(
|
673 |
+
"Temporal_Self",
|
674 |
+
"Temporal_Self",
|
675 |
+
),
|
676 |
+
dropout=0.0,
|
677 |
+
norm_num_groups=32,
|
678 |
+
cross_attention_dim=768,
|
679 |
+
activation_fn="geglu",
|
680 |
+
attention_bias=False,
|
681 |
+
upcast_attention=False,
|
682 |
+
cross_frame_attention_mode=None,
|
683 |
+
temporal_position_encoding=False,
|
684 |
+
temporal_position_encoding_max_len=24,
|
685 |
+
ops=disable_weight_init_clean_groupnorm,
|
686 |
+
):
|
687 |
+
super().__init__()
|
688 |
+
self.video_length = 16
|
689 |
+
self.full_length = 16
|
690 |
+
self.scale_min = 1.0
|
691 |
+
self.scale_max = 1.0
|
692 |
+
self.raw_scale_mask: Union[Tensor, None] = None
|
693 |
+
self.temp_scale_mask: Union[Tensor, None] = None
|
694 |
+
self.sub_idxs: Union[list[int], None] = None
|
695 |
+
self.prev_hidden_states_batch = 0
|
696 |
+
|
697 |
+
|
698 |
+
inner_dim = num_attention_heads * attention_head_dim
|
699 |
+
|
700 |
+
self.norm = ops.GroupNorm(
|
701 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
702 |
+
)
|
703 |
+
self.proj_in = ops.Linear(in_channels, inner_dim)
|
704 |
+
|
705 |
+
self.transformer_blocks: Iterable[TemporalTransformerBlock] = nn.ModuleList(
|
706 |
+
[
|
707 |
+
TemporalTransformerBlock(
|
708 |
+
dim=inner_dim,
|
709 |
+
num_attention_heads=num_attention_heads,
|
710 |
+
attention_head_dim=attention_head_dim,
|
711 |
+
attention_block_types=attention_block_types,
|
712 |
+
dropout=dropout,
|
713 |
+
norm_num_groups=norm_num_groups,
|
714 |
+
cross_attention_dim=cross_attention_dim,
|
715 |
+
activation_fn=activation_fn,
|
716 |
+
attention_bias=attention_bias,
|
717 |
+
upcast_attention=upcast_attention,
|
718 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
719 |
+
temporal_position_encoding=temporal_position_encoding,
|
720 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
721 |
+
ops=ops,
|
722 |
+
)
|
723 |
+
for d in range(num_layers)
|
724 |
+
]
|
725 |
+
)
|
726 |
+
self.proj_out = ops.Linear(inner_dim, in_channels)
|
727 |
+
|
728 |
+
def set_video_length(self, video_length: int, full_length: int):
|
729 |
+
self.video_length = video_length
|
730 |
+
self.full_length = full_length
|
731 |
+
|
732 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
733 |
+
for block in self.transformer_blocks:
|
734 |
+
block.set_scale_multiplier(multiplier)
|
735 |
+
|
736 |
+
def set_masks(self, masks: Tensor, min_val: float, max_val: float):
|
737 |
+
self.scale_min = min_val
|
738 |
+
self.scale_max = max_val
|
739 |
+
self.raw_scale_mask = masks
|
740 |
+
|
741 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
742 |
+
self.sub_idxs = sub_idxs
|
743 |
+
for block in self.transformer_blocks:
|
744 |
+
block.set_sub_idxs(sub_idxs)
|
745 |
+
|
746 |
+
def reset_temp_vars(self):
|
747 |
+
del self.temp_scale_mask
|
748 |
+
self.temp_scale_mask = None
|
749 |
+
self.prev_hidden_states_batch = 0
|
750 |
+
for block in self.transformer_blocks:
|
751 |
+
block.reset_temp_vars()
|
752 |
+
|
753 |
+
def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]:
|
754 |
+
# if no raw mask, return None
|
755 |
+
if self.raw_scale_mask is None:
|
756 |
+
return None
|
757 |
+
shape = hidden_states.shape
|
758 |
+
batch, channel, height, width = shape
|
759 |
+
# if temp mask already calculated, return it
|
760 |
+
if self.temp_scale_mask != None:
|
761 |
+
# check if hidden_states batch matches
|
762 |
+
if batch == self.prev_hidden_states_batch:
|
763 |
+
if self.sub_idxs is not None:
|
764 |
+
return self.temp_scale_mask[:, self.sub_idxs, :]
|
765 |
+
return self.temp_scale_mask
|
766 |
+
# if does not match, reset cached temp_scale_mask and recalculate it
|
767 |
+
del self.temp_scale_mask
|
768 |
+
self.temp_scale_mask = None
|
769 |
+
# otherwise, calculate temp mask
|
770 |
+
self.prev_hidden_states_batch = batch
|
771 |
+
mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width))
|
772 |
+
mask = extend_to_batch_size(mask, self.full_length)
|
773 |
+
# if mask not the same amount length as full length, make it match
|
774 |
+
if self.full_length != mask.shape[0]:
|
775 |
+
mask = broadcast_image_to_extend(mask, self.full_length, 1)
|
776 |
+
# reshape mask to attention K shape (h*w, latent_count, 1)
|
777 |
+
batch, channel, height, width = mask.shape
|
778 |
+
# first, perform same operations as on hidden_states,
|
779 |
+
# turning (b, c, h, w) -> (b, h*w, c)
|
780 |
+
mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel)
|
781 |
+
# then, make it the same shape as attention's k, (h*w, b, c)
|
782 |
+
mask = mask.permute(1, 0, 2)
|
783 |
+
# make masks match the expected length of h*w
|
784 |
+
batched_number = shape[0] // self.video_length
|
785 |
+
if batched_number > 1:
|
786 |
+
mask = torch.cat([mask] * batched_number, dim=0)
|
787 |
+
# cache mask and set to proper device
|
788 |
+
self.temp_scale_mask = mask
|
789 |
+
# move temp_scale_mask to proper dtype + device
|
790 |
+
self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device)
|
791 |
+
# return subset of masks, if needed
|
792 |
+
if self.sub_idxs is not None:
|
793 |
+
return self.temp_scale_mask[:, self.sub_idxs, :]
|
794 |
+
return self.temp_scale_mask
|
795 |
+
|
796 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
797 |
+
batch, channel, height, width = hidden_states.shape
|
798 |
+
residual = hidden_states
|
799 |
+
scale_mask = self.get_scale_mask(hidden_states)
|
800 |
+
# add some casts for fp8 purposes - does not affect speed otherwise
|
801 |
+
hidden_states = self.norm(hidden_states).to(hidden_states.dtype)
|
802 |
+
inner_dim = hidden_states.shape[1]
|
803 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
804 |
+
batch, height * width, inner_dim
|
805 |
+
)
|
806 |
+
hidden_states = self.proj_in(hidden_states).to(hidden_states.dtype)
|
807 |
+
|
808 |
+
# Transformer Blocks
|
809 |
+
for block in self.transformer_blocks:
|
810 |
+
hidden_states = block(
|
811 |
+
hidden_states,
|
812 |
+
encoder_hidden_states=encoder_hidden_states,
|
813 |
+
attention_mask=attention_mask,
|
814 |
+
video_length=self.video_length,
|
815 |
+
scale_mask=scale_mask
|
816 |
+
)
|
817 |
+
|
818 |
+
# output
|
819 |
+
hidden_states = self.proj_out(hidden_states)
|
820 |
+
hidden_states = (
|
821 |
+
hidden_states.reshape(batch, height, width, inner_dim)
|
822 |
+
.permute(0, 3, 1, 2)
|
823 |
+
.contiguous()
|
824 |
+
)
|
825 |
+
|
826 |
+
output = hidden_states + residual
|
827 |
+
|
828 |
+
return output
|
829 |
+
|
830 |
+
|
831 |
+
class TemporalTransformerBlock(nn.Module):
|
832 |
+
def __init__(
|
833 |
+
self,
|
834 |
+
dim,
|
835 |
+
num_attention_heads,
|
836 |
+
attention_head_dim,
|
837 |
+
attention_block_types=(
|
838 |
+
"Temporal_Self",
|
839 |
+
"Temporal_Self",
|
840 |
+
),
|
841 |
+
dropout=0.0,
|
842 |
+
norm_num_groups=32,
|
843 |
+
cross_attention_dim=768,
|
844 |
+
activation_fn="geglu",
|
845 |
+
attention_bias=False,
|
846 |
+
upcast_attention=False,
|
847 |
+
cross_frame_attention_mode=None,
|
848 |
+
temporal_position_encoding=False,
|
849 |
+
temporal_position_encoding_max_len=24,
|
850 |
+
ops=disable_weight_init_clean_groupnorm,
|
851 |
+
):
|
852 |
+
super().__init__()
|
853 |
+
|
854 |
+
attention_blocks = []
|
855 |
+
norms = []
|
856 |
+
|
857 |
+
for block_name in attention_block_types:
|
858 |
+
attention_blocks.append(
|
859 |
+
VersatileAttention(
|
860 |
+
attention_mode=block_name.split("_")[0],
|
861 |
+
context_dim=cross_attention_dim # called context_dim for ComfyUI impl
|
862 |
+
if block_name.endswith("_Cross")
|
863 |
+
else None,
|
864 |
+
query_dim=dim,
|
865 |
+
heads=num_attention_heads,
|
866 |
+
dim_head=attention_head_dim,
|
867 |
+
dropout=dropout,
|
868 |
+
#bias=attention_bias, # remove for Comfy CrossAttention
|
869 |
+
#upcast_attention=upcast_attention, # remove for Comfy CrossAttention
|
870 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
871 |
+
temporal_position_encoding=temporal_position_encoding,
|
872 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
873 |
+
ops=ops,
|
874 |
+
)
|
875 |
+
)
|
876 |
+
norms.append(ops.LayerNorm(dim))
|
877 |
+
|
878 |
+
self.attention_blocks: Iterable[VersatileAttention] = nn.ModuleList(attention_blocks)
|
879 |
+
self.norms = nn.ModuleList(norms)
|
880 |
+
|
881 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"), operations=ops)
|
882 |
+
self.ff_norm = ops.LayerNorm(dim)
|
883 |
+
|
884 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
885 |
+
for block in self.attention_blocks:
|
886 |
+
block.set_scale_multiplier(multiplier)
|
887 |
+
|
888 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
889 |
+
for block in self.attention_blocks:
|
890 |
+
block.set_sub_idxs(sub_idxs)
|
891 |
+
|
892 |
+
def reset_temp_vars(self):
|
893 |
+
for block in self.attention_blocks:
|
894 |
+
block.reset_temp_vars()
|
895 |
+
|
896 |
+
def forward(
|
897 |
+
self,
|
898 |
+
hidden_states,
|
899 |
+
encoder_hidden_states=None,
|
900 |
+
attention_mask=None,
|
901 |
+
video_length=None,
|
902 |
+
scale_mask=None
|
903 |
+
):
|
904 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
905 |
+
norm_hidden_states = norm(hidden_states).to(hidden_states.dtype)
|
906 |
+
hidden_states = (
|
907 |
+
attention_block(
|
908 |
+
norm_hidden_states,
|
909 |
+
encoder_hidden_states=encoder_hidden_states
|
910 |
+
if attention_block.is_cross_attention
|
911 |
+
else None,
|
912 |
+
attention_mask=attention_mask,
|
913 |
+
video_length=video_length,
|
914 |
+
scale_mask=scale_mask
|
915 |
+
)
|
916 |
+
+ hidden_states
|
917 |
+
)
|
918 |
+
|
919 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
920 |
+
|
921 |
+
output = hidden_states
|
922 |
+
return output
|
923 |
+
|
924 |
+
|
925 |
+
class PositionalEncoding(nn.Module):
|
926 |
+
def __init__(self, d_model, dropout=0.0, max_len=24):
|
927 |
+
super().__init__()
|
928 |
+
self.dropout = nn.Dropout(p=dropout)
|
929 |
+
position = torch.arange(max_len).unsqueeze(1)
|
930 |
+
div_term = torch.exp(
|
931 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
932 |
+
)
|
933 |
+
pe = torch.zeros(1, max_len, d_model)
|
934 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
935 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
936 |
+
self.register_buffer("pe", pe)
|
937 |
+
self.sub_idxs = None
|
938 |
+
|
939 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
940 |
+
self.sub_idxs = sub_idxs
|
941 |
+
|
942 |
+
def forward(self, x):
|
943 |
+
#if self.sub_idxs is not None:
|
944 |
+
# x = x + self.pe[:, self.sub_idxs]
|
945 |
+
#else:
|
946 |
+
x = x + self.pe[:, : x.size(1)]
|
947 |
+
return self.dropout(x)
|
948 |
+
|
949 |
+
|
950 |
+
class CrossAttentionMMSparse(nn.Module):
|
951 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None,
|
952 |
+
operations=disable_weight_init_clean_groupnorm):
|
953 |
+
super().__init__()
|
954 |
+
inner_dim = dim_head * heads
|
955 |
+
context_dim = default(context_dim, query_dim)
|
956 |
+
|
957 |
+
self.actual_attention = optimized_attention_mm
|
958 |
+
self.heads = heads
|
959 |
+
self.dim_head = dim_head
|
960 |
+
self.scale = None
|
961 |
+
|
962 |
+
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
963 |
+
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
964 |
+
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
965 |
+
|
966 |
+
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
967 |
+
|
968 |
+
def reset_attention_type(self):
|
969 |
+
self.actual_attention = optimized_attention_mm
|
970 |
+
|
971 |
+
def forward(self, x, context=None, value=None, mask=None, scale_mask=None):
|
972 |
+
q = self.to_q(x)
|
973 |
+
context = default(context, x)
|
974 |
+
k: Tensor = self.to_k(context)
|
975 |
+
if value is not None:
|
976 |
+
v = self.to_v(value)
|
977 |
+
del value
|
978 |
+
else:
|
979 |
+
v = self.to_v(context)
|
980 |
+
|
981 |
+
# apply custom scale by multiplying k by scale factor
|
982 |
+
if self.scale is not None:
|
983 |
+
k *= self.scale
|
984 |
+
|
985 |
+
# apply scale mask, if present
|
986 |
+
if scale_mask is not None:
|
987 |
+
k *= scale_mask
|
988 |
+
|
989 |
+
try:
|
990 |
+
out = self.actual_attention(q, k, v, self.heads, mask)
|
991 |
+
except RuntimeError as e:
|
992 |
+
if str(e).startswith("CUDA error: invalid configuration argument"):
|
993 |
+
self.actual_attention = fallback_attention_mm
|
994 |
+
out = self.actual_attention(q, k, v, self.heads, mask)
|
995 |
+
else:
|
996 |
+
raise
|
997 |
+
return self.to_out(out)
|
998 |
+
|
999 |
+
|
1000 |
+
class VersatileAttention(CrossAttentionMMSparse):
|
1001 |
+
def __init__(
|
1002 |
+
self,
|
1003 |
+
attention_mode=None,
|
1004 |
+
cross_frame_attention_mode=None,
|
1005 |
+
temporal_position_encoding=False,
|
1006 |
+
temporal_position_encoding_max_len=24,
|
1007 |
+
ops=disable_weight_init_clean_groupnorm,
|
1008 |
+
*args,
|
1009 |
+
**kwargs,
|
1010 |
+
):
|
1011 |
+
super().__init__(operations=ops, *args, **kwargs)
|
1012 |
+
assert attention_mode == "Temporal"
|
1013 |
+
|
1014 |
+
self.attention_mode = attention_mode
|
1015 |
+
self.is_cross_attention = kwargs["context_dim"] is not None
|
1016 |
+
|
1017 |
+
self.pos_encoder = (
|
1018 |
+
PositionalEncoding(
|
1019 |
+
kwargs["query_dim"],
|
1020 |
+
dropout=0.0,
|
1021 |
+
max_len=temporal_position_encoding_max_len,
|
1022 |
+
)
|
1023 |
+
if (temporal_position_encoding and attention_mode == "Temporal")
|
1024 |
+
else None
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
def extra_repr(self):
|
1028 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
1029 |
+
|
1030 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
1031 |
+
if multiplier is None or math.isclose(multiplier, 1.0):
|
1032 |
+
self.scale = None
|
1033 |
+
else:
|
1034 |
+
self.scale = multiplier
|
1035 |
+
|
1036 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
1037 |
+
if self.pos_encoder != None:
|
1038 |
+
self.pos_encoder.set_sub_idxs(sub_idxs)
|
1039 |
+
|
1040 |
+
def reset_temp_vars(self):
|
1041 |
+
self.reset_attention_type()
|
1042 |
+
|
1043 |
+
def forward(
|
1044 |
+
self,
|
1045 |
+
hidden_states: Tensor,
|
1046 |
+
encoder_hidden_states=None,
|
1047 |
+
attention_mask=None,
|
1048 |
+
video_length=None,
|
1049 |
+
scale_mask=None,
|
1050 |
+
):
|
1051 |
+
if self.attention_mode != "Temporal":
|
1052 |
+
raise NotImplementedError
|
1053 |
+
|
1054 |
+
d = hidden_states.shape[1]
|
1055 |
+
hidden_states = rearrange(
|
1056 |
+
hidden_states, "(b f) d c -> (b d) f c", f=video_length
|
1057 |
+
)
|
1058 |
+
|
1059 |
+
if self.pos_encoder is not None:
|
1060 |
+
hidden_states = self.pos_encoder(hidden_states).to(hidden_states.dtype)
|
1061 |
+
|
1062 |
+
encoder_hidden_states = (
|
1063 |
+
repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
|
1064 |
+
if encoder_hidden_states is not None
|
1065 |
+
else encoder_hidden_states
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
hidden_states = super().forward(
|
1069 |
+
hidden_states,
|
1070 |
+
encoder_hidden_states,
|
1071 |
+
value=None,
|
1072 |
+
mask=attention_mask,
|
1073 |
+
scale_mask=scale_mask,
|
1074 |
+
)
|
1075 |
+
|
1076 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
1077 |
+
|
1078 |
+
return hidden_states
|
ComfyUI-Advanced-ControlNet/adv_control/control_svd.py
ADDED
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
import comfy.model_detection
|
6 |
+
from comfy.utils import UNET_MAP_BASIC, UNET_MAP_RESNET, UNET_MAP_ATTENTIONS, TRANSFORMER_BLOCKS
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
from comfy.ldm.modules.diffusionmodules.util import (
|
12 |
+
zero_module,
|
13 |
+
timestep_embedding,
|
14 |
+
)
|
15 |
+
|
16 |
+
from comfy.ldm.modules.attention import SpatialVideoTransformer
|
17 |
+
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, VideoResBlock, Downsample
|
18 |
+
from comfy.ldm.util import exists
|
19 |
+
import comfy.ops
|
20 |
+
|
21 |
+
|
22 |
+
class SVDControlNet(nn.Module):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
image_size,
|
26 |
+
in_channels,
|
27 |
+
model_channels,
|
28 |
+
hint_channels,
|
29 |
+
num_res_blocks,
|
30 |
+
dropout=0,
|
31 |
+
channel_mult=(1, 2, 4, 8),
|
32 |
+
conv_resample=True,
|
33 |
+
dims=2,
|
34 |
+
num_classes=None,
|
35 |
+
use_checkpoint=False,
|
36 |
+
dtype=torch.float32,
|
37 |
+
num_heads=-1,
|
38 |
+
num_head_channels=-1,
|
39 |
+
num_heads_upsample=-1,
|
40 |
+
use_scale_shift_norm=False,
|
41 |
+
resblock_updown=False,
|
42 |
+
use_new_attention_order=False,
|
43 |
+
use_spatial_transformer=False, # custom transformer support
|
44 |
+
transformer_depth=1, # custom transformer support
|
45 |
+
context_dim=None, # custom transformer support
|
46 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
47 |
+
legacy=True,
|
48 |
+
disable_self_attentions=None,
|
49 |
+
num_attention_blocks=None,
|
50 |
+
disable_middle_self_attn=False,
|
51 |
+
use_linear_in_transformer=False,
|
52 |
+
adm_in_channels=None,
|
53 |
+
transformer_depth_middle=None,
|
54 |
+
transformer_depth_output=None,
|
55 |
+
use_spatial_context=False,
|
56 |
+
extra_ff_mix_layer=False,
|
57 |
+
merge_strategy="fixed",
|
58 |
+
merge_factor=0.5,
|
59 |
+
video_kernel_size=3,
|
60 |
+
device=None,
|
61 |
+
operations=comfy.ops.disable_weight_init,
|
62 |
+
**kwargs,
|
63 |
+
):
|
64 |
+
super().__init__()
|
65 |
+
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
66 |
+
if use_spatial_transformer:
|
67 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
68 |
+
|
69 |
+
if context_dim is not None:
|
70 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
71 |
+
# from omegaconf.listconfig import ListConfig
|
72 |
+
# if type(context_dim) == ListConfig:
|
73 |
+
# context_dim = list(context_dim)
|
74 |
+
|
75 |
+
if num_heads_upsample == -1:
|
76 |
+
num_heads_upsample = num_heads
|
77 |
+
|
78 |
+
if num_heads == -1:
|
79 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
80 |
+
|
81 |
+
if num_head_channels == -1:
|
82 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
83 |
+
|
84 |
+
self.dims = dims
|
85 |
+
self.image_size = image_size
|
86 |
+
self.in_channels = in_channels
|
87 |
+
self.model_channels = model_channels
|
88 |
+
|
89 |
+
if isinstance(num_res_blocks, int):
|
90 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
91 |
+
else:
|
92 |
+
if len(num_res_blocks) != len(channel_mult):
|
93 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
94 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
95 |
+
self.num_res_blocks = num_res_blocks
|
96 |
+
|
97 |
+
if disable_self_attentions is not None:
|
98 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
99 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
100 |
+
if num_attention_blocks is not None:
|
101 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
102 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
103 |
+
|
104 |
+
transformer_depth = transformer_depth[:]
|
105 |
+
|
106 |
+
self.dropout = dropout
|
107 |
+
self.channel_mult = channel_mult
|
108 |
+
self.conv_resample = conv_resample
|
109 |
+
self.num_classes = num_classes
|
110 |
+
self.use_checkpoint = use_checkpoint
|
111 |
+
self.dtype = dtype
|
112 |
+
self.num_heads = num_heads
|
113 |
+
self.num_head_channels = num_head_channels
|
114 |
+
self.num_heads_upsample = num_heads_upsample
|
115 |
+
self.predict_codebook_ids = n_embed is not None
|
116 |
+
|
117 |
+
time_embed_dim = model_channels * 4
|
118 |
+
self.time_embed = nn.Sequential(
|
119 |
+
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
120 |
+
nn.SiLU(),
|
121 |
+
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
122 |
+
)
|
123 |
+
|
124 |
+
if self.num_classes is not None:
|
125 |
+
if isinstance(self.num_classes, int):
|
126 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
127 |
+
elif self.num_classes == "continuous":
|
128 |
+
print("setting up linear c_adm embedding layer")
|
129 |
+
self.label_emb = nn.Linear(1, time_embed_dim)
|
130 |
+
elif self.num_classes == "sequential":
|
131 |
+
assert adm_in_channels is not None
|
132 |
+
self.label_emb = nn.Sequential(
|
133 |
+
nn.Sequential(
|
134 |
+
operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
|
135 |
+
nn.SiLU(),
|
136 |
+
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
137 |
+
)
|
138 |
+
)
|
139 |
+
else:
|
140 |
+
raise ValueError()
|
141 |
+
|
142 |
+
self.input_blocks = nn.ModuleList(
|
143 |
+
[
|
144 |
+
TimestepEmbedSequential(
|
145 |
+
operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
146 |
+
)
|
147 |
+
]
|
148 |
+
)
|
149 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
|
150 |
+
|
151 |
+
self.input_hint_block = TimestepEmbedSequential(
|
152 |
+
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
|
153 |
+
nn.SiLU(),
|
154 |
+
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
|
155 |
+
nn.SiLU(),
|
156 |
+
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
157 |
+
nn.SiLU(),
|
158 |
+
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
|
159 |
+
nn.SiLU(),
|
160 |
+
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
161 |
+
nn.SiLU(),
|
162 |
+
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
|
163 |
+
nn.SiLU(),
|
164 |
+
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
165 |
+
nn.SiLU(),
|
166 |
+
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
167 |
+
)
|
168 |
+
|
169 |
+
self._feature_size = model_channels
|
170 |
+
input_block_chans = [model_channels]
|
171 |
+
ch = model_channels
|
172 |
+
ds = 1
|
173 |
+
for level, mult in enumerate(channel_mult):
|
174 |
+
for nr in range(self.num_res_blocks[level]):
|
175 |
+
layers = [
|
176 |
+
VideoResBlock(
|
177 |
+
ch,
|
178 |
+
time_embed_dim,
|
179 |
+
dropout,
|
180 |
+
out_channels=mult * model_channels,
|
181 |
+
dims=dims,
|
182 |
+
use_checkpoint=use_checkpoint,
|
183 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
184 |
+
dtype=self.dtype,
|
185 |
+
device=device,
|
186 |
+
operations=operations,
|
187 |
+
video_kernel_size=video_kernel_size,
|
188 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
189 |
+
)
|
190 |
+
]
|
191 |
+
ch = mult * model_channels
|
192 |
+
num_transformers = transformer_depth.pop(0)
|
193 |
+
if num_transformers > 0:
|
194 |
+
if num_head_channels == -1:
|
195 |
+
dim_head = ch // num_heads
|
196 |
+
else:
|
197 |
+
num_heads = ch // num_head_channels
|
198 |
+
dim_head = num_head_channels
|
199 |
+
if legacy:
|
200 |
+
#num_heads = 1
|
201 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
202 |
+
if exists(disable_self_attentions):
|
203 |
+
disabled_sa = disable_self_attentions[level]
|
204 |
+
else:
|
205 |
+
disabled_sa = False
|
206 |
+
|
207 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
208 |
+
layers.append(
|
209 |
+
SpatialVideoTransformer(
|
210 |
+
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
211 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
212 |
+
checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations,
|
213 |
+
use_spatial_context=use_spatial_context, ff_in=extra_ff_mix_layer,
|
214 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
215 |
+
)
|
216 |
+
)
|
217 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
218 |
+
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
219 |
+
self._feature_size += ch
|
220 |
+
input_block_chans.append(ch)
|
221 |
+
if level != len(channel_mult) - 1:
|
222 |
+
out_ch = ch
|
223 |
+
self.input_blocks.append(
|
224 |
+
TimestepEmbedSequential(
|
225 |
+
VideoResBlock(
|
226 |
+
ch,
|
227 |
+
time_embed_dim,
|
228 |
+
dropout,
|
229 |
+
out_channels=out_ch,
|
230 |
+
dims=dims,
|
231 |
+
use_checkpoint=use_checkpoint,
|
232 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
233 |
+
down=True,
|
234 |
+
dtype=self.dtype,
|
235 |
+
device=device,
|
236 |
+
operations=operations,
|
237 |
+
video_kernel_size=video_kernel_size,
|
238 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
239 |
+
)
|
240 |
+
if resblock_updown
|
241 |
+
else Downsample(
|
242 |
+
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
|
243 |
+
)
|
244 |
+
)
|
245 |
+
)
|
246 |
+
ch = out_ch
|
247 |
+
input_block_chans.append(ch)
|
248 |
+
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
249 |
+
ds *= 2
|
250 |
+
self._feature_size += ch
|
251 |
+
|
252 |
+
if num_head_channels == -1:
|
253 |
+
dim_head = ch // num_heads
|
254 |
+
else:
|
255 |
+
num_heads = ch // num_head_channels
|
256 |
+
dim_head = num_head_channels
|
257 |
+
if legacy:
|
258 |
+
#num_heads = 1
|
259 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
260 |
+
mid_block = [
|
261 |
+
VideoResBlock(
|
262 |
+
ch,
|
263 |
+
time_embed_dim,
|
264 |
+
dropout,
|
265 |
+
dims=dims,
|
266 |
+
use_checkpoint=use_checkpoint,
|
267 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
268 |
+
dtype=self.dtype,
|
269 |
+
device=device,
|
270 |
+
operations=operations,
|
271 |
+
video_kernel_size=video_kernel_size,
|
272 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
273 |
+
)]
|
274 |
+
if transformer_depth_middle >= 0:
|
275 |
+
mid_block += [SpatialVideoTransformer( # always uses a self-attn
|
276 |
+
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
277 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
278 |
+
checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations,
|
279 |
+
use_spatial_context=use_spatial_context, ff_in=extra_ff_mix_layer,
|
280 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
281 |
+
),
|
282 |
+
VideoResBlock(
|
283 |
+
ch,
|
284 |
+
time_embed_dim,
|
285 |
+
dropout,
|
286 |
+
dims=dims,
|
287 |
+
use_checkpoint=use_checkpoint,
|
288 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
289 |
+
dtype=self.dtype,
|
290 |
+
device=device,
|
291 |
+
operations=operations,
|
292 |
+
video_kernel_size=video_kernel_size,
|
293 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
294 |
+
)]
|
295 |
+
self.middle_block = TimestepEmbedSequential(*mid_block)
|
296 |
+
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
297 |
+
self._feature_size += ch
|
298 |
+
|
299 |
+
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
|
300 |
+
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
301 |
+
|
302 |
+
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
303 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
304 |
+
emb = self.time_embed(t_emb)
|
305 |
+
|
306 |
+
cond = kwargs["cond"]
|
307 |
+
num_video_frames = cond["num_video_frames"]
|
308 |
+
image_only_indicator = cond.get("image_only_indicator", None)
|
309 |
+
time_context = cond.get("time_context", None)
|
310 |
+
del cond
|
311 |
+
|
312 |
+
guided_hint = self.input_hint_block(hint, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
313 |
+
|
314 |
+
out_output = []
|
315 |
+
out_middle = []
|
316 |
+
|
317 |
+
hs = []
|
318 |
+
if self.num_classes is not None:
|
319 |
+
assert y.shape[0] == x.shape[0]
|
320 |
+
emb = emb + self.label_emb(y)
|
321 |
+
|
322 |
+
h = x
|
323 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
324 |
+
if guided_hint is not None:
|
325 |
+
h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
326 |
+
h += guided_hint
|
327 |
+
guided_hint = None
|
328 |
+
else:
|
329 |
+
h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
330 |
+
out_output.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
|
331 |
+
|
332 |
+
h = self.middle_block(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
333 |
+
out_middle.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
|
334 |
+
|
335 |
+
return {"middle": out_middle, "output": out_output}
|
336 |
+
|
337 |
+
|
338 |
+
TEMPORAL_TRANSFORMER_BLOCKS = {
|
339 |
+
"norm_in.weight",
|
340 |
+
"norm_in.bias",
|
341 |
+
"ff_in.net.0.proj.weight",
|
342 |
+
"ff_in.net.0.proj.bias",
|
343 |
+
"ff_in.net.2.weight",
|
344 |
+
"ff_in.net.2.bias",
|
345 |
+
}
|
346 |
+
TEMPORAL_TRANSFORMER_BLOCKS.update(TRANSFORMER_BLOCKS)
|
347 |
+
|
348 |
+
|
349 |
+
TEMPORAL_UNET_MAP_ATTENTIONS = {
|
350 |
+
"time_mixer.mix_factor",
|
351 |
+
}
|
352 |
+
TEMPORAL_UNET_MAP_ATTENTIONS.update(UNET_MAP_ATTENTIONS)
|
353 |
+
|
354 |
+
|
355 |
+
TEMPORAL_TRANSFORMER_MAP = {
|
356 |
+
"time_pos_embed.0.weight": "time_pos_embed.linear_1.weight",
|
357 |
+
"time_pos_embed.0.bias": "time_pos_embed.linear_1.bias",
|
358 |
+
"time_pos_embed.2.weight": "time_pos_embed.linear_2.weight",
|
359 |
+
"time_pos_embed.2.bias": "time_pos_embed.linear_2.bias",
|
360 |
+
}
|
361 |
+
|
362 |
+
|
363 |
+
TEMPORAL_RESNET = {
|
364 |
+
"time_mixer.mix_factor",
|
365 |
+
}
|
366 |
+
|
367 |
+
|
368 |
+
def svd_unet_config_from_diffusers_unet(state_dict: dict[str, Tensor], dtype):
|
369 |
+
match = {}
|
370 |
+
transformer_depth = []
|
371 |
+
|
372 |
+
attn_res = 1
|
373 |
+
down_blocks = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}")
|
374 |
+
for i in range(down_blocks):
|
375 |
+
attn_blocks = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
|
376 |
+
for ab in range(attn_blocks):
|
377 |
+
transformer_count = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
378 |
+
transformer_depth.append(transformer_count)
|
379 |
+
if transformer_count > 0:
|
380 |
+
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
|
381 |
+
|
382 |
+
attn_res *= 2
|
383 |
+
if attn_blocks == 0:
|
384 |
+
transformer_depth.append(0)
|
385 |
+
transformer_depth.append(0)
|
386 |
+
|
387 |
+
match["transformer_depth"] = transformer_depth
|
388 |
+
|
389 |
+
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
|
390 |
+
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
|
391 |
+
match["adm_in_channels"] = None
|
392 |
+
if "class_embedding.linear_1.weight" in state_dict:
|
393 |
+
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
|
394 |
+
elif "add_embedding.linear_1.weight" in state_dict:
|
395 |
+
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
396 |
+
|
397 |
+
# based on unet_config of SVD
|
398 |
+
SVD = {
|
399 |
+
'use_checkpoint': False,
|
400 |
+
'image_size': 32,
|
401 |
+
'use_spatial_transformer': True,
|
402 |
+
'legacy': False,
|
403 |
+
'num_classes': 'sequential',
|
404 |
+
'adm_in_channels': 768,
|
405 |
+
'dtype': dtype,
|
406 |
+
'in_channels': 8,
|
407 |
+
'out_channels': 4,
|
408 |
+
'model_channels': 320,
|
409 |
+
'num_res_blocks': [2, 2, 2, 2],
|
410 |
+
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
411 |
+
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
412 |
+
'channel_mult': [1, 2, 4, 4],
|
413 |
+
'transformer_depth_middle': 1,
|
414 |
+
'use_linear_in_transformer': True,
|
415 |
+
'context_dim': 1024,
|
416 |
+
'extra_ff_mix_layer': True,
|
417 |
+
'use_spatial_context': True,
|
418 |
+
'merge_strategy': 'learned_with_images',
|
419 |
+
'merge_factor': 0.0,
|
420 |
+
'video_kernel_size': [3, 1, 1],
|
421 |
+
'use_temporal_attention': True,
|
422 |
+
'use_temporal_resblock': True,
|
423 |
+
'num_heads': -1,
|
424 |
+
'num_head_channels': 64,
|
425 |
+
}
|
426 |
+
|
427 |
+
supported_models = [SVD]
|
428 |
+
|
429 |
+
for unet_config in supported_models:
|
430 |
+
matches = True
|
431 |
+
for k in match:
|
432 |
+
if match[k] != unet_config[k]:
|
433 |
+
matches = False
|
434 |
+
break
|
435 |
+
if matches:
|
436 |
+
return comfy.model_detection.convert_config(unet_config)
|
437 |
+
return None
|
438 |
+
|
439 |
+
|
440 |
+
def svd_unet_to_diffusers(unet_config):
|
441 |
+
num_res_blocks = unet_config["num_res_blocks"]
|
442 |
+
channel_mult = unet_config["channel_mult"]
|
443 |
+
transformer_depth = unet_config["transformer_depth"][:]
|
444 |
+
transformer_depth_output = unet_config["transformer_depth_output"][:]
|
445 |
+
num_blocks = len(channel_mult)
|
446 |
+
|
447 |
+
transformers_mid = unet_config.get("transformer_depth_middle", None)
|
448 |
+
|
449 |
+
diffusers_unet_map = {}
|
450 |
+
for x in range(num_blocks):
|
451 |
+
n = 1 + (num_res_blocks[x] + 1) * x
|
452 |
+
for i in range(num_res_blocks[x]):
|
453 |
+
for b in TEMPORAL_RESNET:
|
454 |
+
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, b)] = "input_blocks.{}.0.{}".format(n, b)
|
455 |
+
for b in UNET_MAP_RESNET:
|
456 |
+
diffusers_unet_map["down_blocks.{}.resnets.{}.spatial_res_block.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
457 |
+
diffusers_unet_map["down_blocks.{}.resnets.{}.temporal_res_block.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.time_stack.{}".format(n, b)
|
458 |
+
#diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
459 |
+
num_transformers = transformer_depth.pop(0)
|
460 |
+
if num_transformers > 0:
|
461 |
+
for b in TEMPORAL_UNET_MAP_ATTENTIONS:
|
462 |
+
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
|
463 |
+
for b in TEMPORAL_TRANSFORMER_MAP:
|
464 |
+
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, TEMPORAL_TRANSFORMER_MAP[b])] = "input_blocks.{}.1.{}".format(n, b)
|
465 |
+
for t in range(num_transformers):
|
466 |
+
for b in TRANSFORMER_BLOCKS:
|
467 |
+
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
468 |
+
for b in TEMPORAL_TRANSFORMER_BLOCKS:
|
469 |
+
diffusers_unet_map["down_blocks.{}.attentions.{}.temporal_transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.time_stack.{}.{}".format(n, t, b)
|
470 |
+
n += 1
|
471 |
+
for k in ["weight", "bias"]:
|
472 |
+
diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
|
473 |
+
|
474 |
+
i = 0
|
475 |
+
for b in TEMPORAL_UNET_MAP_ATTENTIONS:
|
476 |
+
diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
|
477 |
+
for b in TEMPORAL_TRANSFORMER_MAP:
|
478 |
+
diffusers_unet_map["mid_block.attentions.{}.{}".format(i, TEMPORAL_TRANSFORMER_MAP[b])] = "middle_block.1.{}".format(b)
|
479 |
+
for t in range(transformers_mid):
|
480 |
+
for b in TRANSFORMER_BLOCKS:
|
481 |
+
diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
|
482 |
+
for b in TEMPORAL_TRANSFORMER_BLOCKS:
|
483 |
+
diffusers_unet_map["mid_block.attentions.{}.temporal_transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.time_stack.{}.{}".format(t, b)
|
484 |
+
|
485 |
+
for i, n in enumerate([0, 2]):
|
486 |
+
for b in TEMPORAL_RESNET:
|
487 |
+
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, b)] = "middle_block.{}.{}".format(n, b)
|
488 |
+
for b in UNET_MAP_RESNET:
|
489 |
+
diffusers_unet_map["mid_block.resnets.{}.spatial_res_block.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
490 |
+
diffusers_unet_map["mid_block.resnets.{}.temporal_res_block.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.time_stack.{}".format(n, b)
|
491 |
+
#diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
492 |
+
|
493 |
+
num_res_blocks = list(reversed(num_res_blocks))
|
494 |
+
for x in range(num_blocks):
|
495 |
+
n = (num_res_blocks[x] + 1) * x
|
496 |
+
l = num_res_blocks[x] + 1
|
497 |
+
for i in range(l):
|
498 |
+
c = 0
|
499 |
+
for b in UNET_MAP_RESNET:
|
500 |
+
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
|
501 |
+
c += 1
|
502 |
+
num_transformers = transformer_depth_output.pop()
|
503 |
+
if num_transformers > 0:
|
504 |
+
c += 1
|
505 |
+
for b in UNET_MAP_ATTENTIONS:
|
506 |
+
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
|
507 |
+
for t in range(num_transformers):
|
508 |
+
for b in TRANSFORMER_BLOCKS:
|
509 |
+
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
510 |
+
if i == l - 1:
|
511 |
+
for k in ["weight", "bias"]:
|
512 |
+
diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
|
513 |
+
n += 1
|
514 |
+
|
515 |
+
for k in UNET_MAP_BASIC:
|
516 |
+
diffusers_unet_map[k[1]] = k[0]
|
517 |
+
|
518 |
+
return diffusers_unet_map
|
ComfyUI-Advanced-ControlNet/adv_control/documentation.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .logger import logger
|
2 |
+
|
3 |
+
def image(src):
|
4 |
+
return f'<img src={src} style="width: 0px; min-width: 100%">'
|
5 |
+
def video(src):
|
6 |
+
return f'<video src={src} autoplay muted loop controls controlslist="nodownload noremoteplayback noplaybackrate" style="width: 0px; min-width: 100%" class="VHS_loopedvideo">'
|
7 |
+
def short_desc(desc):
|
8 |
+
return f'<div id=VHS_shortdesc style="font-size: .8em">{desc}</div>'
|
9 |
+
|
10 |
+
descriptions = {
|
11 |
+
}
|
12 |
+
|
13 |
+
sizes = ['1.4','1.2','1']
|
14 |
+
def as_html(entry, depth=0):
|
15 |
+
if isinstance(entry, dict):
|
16 |
+
size = 0.8 if depth < 2 else 1
|
17 |
+
html = ''
|
18 |
+
for k in entry:
|
19 |
+
if k == "collapsed":
|
20 |
+
continue
|
21 |
+
collapse_single = k.endswith("_collapsed")
|
22 |
+
if collapse_single:
|
23 |
+
name = k[:-len("_collapsed")]
|
24 |
+
else:
|
25 |
+
name = k
|
26 |
+
collapse_flag = ' VHS_precollapse' if entry.get("collapsed", False) or collapse_single else ''
|
27 |
+
html += f'<div vhs_title=\"{name}\" style=\"display: flex; font-size: {size}em\" class=\"VHS_collapse{collapse_flag}\"><div style=\"color: #AAA; height: 1.5em;\">[<span style=\"font-family: monospace\">-</span>]</div><div style=\"width: 100%\">{name}: {as_html(entry[k], depth=depth+1)}</div></div>'
|
28 |
+
return html
|
29 |
+
if isinstance(entry, list):
|
30 |
+
html = ''
|
31 |
+
for i in entry:
|
32 |
+
html += f'<div>{as_html(i, depth=depth)}</div>'
|
33 |
+
return html
|
34 |
+
return str(entry)
|
35 |
+
|
36 |
+
def format_descriptions(nodes):
|
37 |
+
for k in descriptions:
|
38 |
+
if k.endswith("_collapsed"):
|
39 |
+
k = k[:-len("_collapsed")]
|
40 |
+
nodes[k].DESCRIPTION = as_html(descriptions[k])
|
41 |
+
# undocumented_nodes = []
|
42 |
+
# for k in nodes:
|
43 |
+
# if not hasattr(nodes[k], "DESCRIPTION"):
|
44 |
+
# undocumented_nodes.append(k)
|
45 |
+
# if len(undocumented_nodes) > 0:
|
46 |
+
# logger.info(f"Undocumented nodes: {undocumented_nodes}")
|
47 |
+
|
ComfyUI-Advanced-ControlNet/adv_control/logger.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import copy
|
3 |
+
import logging
|
4 |
+
|
5 |
+
|
6 |
+
class ColoredFormatter(logging.Formatter):
|
7 |
+
COLORS = {
|
8 |
+
"DEBUG": "\033[0;36m", # CYAN
|
9 |
+
"INFO": "\033[0;32m", # GREEN
|
10 |
+
"WARNING": "\033[0;33m", # YELLOW
|
11 |
+
"ERROR": "\033[0;31m", # RED
|
12 |
+
"CRITICAL": "\033[0;37;41m", # WHITE ON RED
|
13 |
+
"RESET": "\033[0m", # RESET COLOR
|
14 |
+
}
|
15 |
+
|
16 |
+
def format(self, record):
|
17 |
+
colored_record = copy.copy(record)
|
18 |
+
levelname = colored_record.levelname
|
19 |
+
seq = self.COLORS.get(levelname, self.COLORS["RESET"])
|
20 |
+
colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
|
21 |
+
return super().format(colored_record)
|
22 |
+
|
23 |
+
|
24 |
+
# Create a new logger
|
25 |
+
logger = logging.getLogger("Advanced-ControlNet")
|
26 |
+
logger.propagate = False
|
27 |
+
|
28 |
+
# Add handler if we don't have one.
|
29 |
+
if not logger.handlers:
|
30 |
+
handler = logging.StreamHandler(sys.stdout)
|
31 |
+
handler.setFormatter(ColoredFormatter("[%(name)s] - %(levelname)s - %(message)s"))
|
32 |
+
logger.addHandler(handler)
|
33 |
+
|
34 |
+
# Configure logger
|
35 |
+
loglevel = logging.INFO
|
36 |
+
logger.setLevel(loglevel)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from torch import Tensor
|
3 |
+
|
4 |
+
import folder_paths
|
5 |
+
import comfy.sample
|
6 |
+
from comfy.model_patcher import ModelPatcher
|
7 |
+
|
8 |
+
from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet, is_sd3_advanced_controlnet
|
9 |
+
from .utils import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, AbstractPreprocWrapper, BIGMAX
|
10 |
+
from .nodes_weight import (DefaultWeights, ScaledSoftMaskedUniversalWeights, ScaledSoftUniversalWeights,
|
11 |
+
SoftControlNetWeightsSD15, CustomControlNetWeightsSD15, CustomControlNetWeightsFlux,
|
12 |
+
SoftT2IAdapterWeights, CustomT2IAdapterWeights)
|
13 |
+
from .nodes_keyframes import (LatentKeyframeGroupNode, LatentKeyframeInterpolationNode, LatentKeyframeBatchedGroupNode, LatentKeyframeNode,
|
14 |
+
TimestepKeyframeNode, TimestepKeyframeInterpolationNode, TimestepKeyframeFromStrengthListNode)
|
15 |
+
from .nodes_sparsectrl import SparseCtrlMergedLoaderAdvanced, SparseCtrlLoaderAdvanced, SparseIndexMethodNode, SparseSpreadMethodNode, RgbSparseCtrlPreprocessor, SparseWeightExtras
|
16 |
+
from .nodes_reference import ReferenceControlNetNode, ReferenceControlFinetune, ReferencePreprocessorNode
|
17 |
+
from .nodes_plusplus import PlusPlusLoaderAdvanced, PlusPlusLoaderSingle, PlusPlusInputNode
|
18 |
+
from .nodes_loosecontrol import ControlNetLoaderWithLoraAdvanced
|
19 |
+
from .nodes_deprecated import (LoadImagesFromDirectory, ScaledSoftUniversalWeightsDeprecated,
|
20 |
+
SoftControlNetWeightsDeprecated, CustomControlNetWeightsDeprecated,
|
21 |
+
SoftT2IAdapterWeightsDeprecated, CustomT2IAdapterWeightsDeprecated)
|
22 |
+
from .logger import logger
|
23 |
+
|
24 |
+
from .sampling import acn_sample_factory
|
25 |
+
# inject sample functions
|
26 |
+
comfy.sample.sample = acn_sample_factory(comfy.sample.sample)
|
27 |
+
comfy.sample.sample_custom = acn_sample_factory(comfy.sample.sample_custom, is_custom=True)
|
28 |
+
|
29 |
+
|
30 |
+
class ControlNetLoaderAdvanced:
|
31 |
+
@classmethod
|
32 |
+
def INPUT_TYPES(s):
|
33 |
+
return {
|
34 |
+
"required": {
|
35 |
+
"control_net_name": (folder_paths.get_filename_list("controlnet"), ),
|
36 |
+
},
|
37 |
+
"optional": {
|
38 |
+
"tk_optional": ("TIMESTEP_KEYFRAME", ),
|
39 |
+
}
|
40 |
+
}
|
41 |
+
|
42 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
43 |
+
FUNCTION = "load_controlnet"
|
44 |
+
|
45 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
|
46 |
+
|
47 |
+
def load_controlnet(self, control_net_name,
|
48 |
+
tk_optional: TimestepKeyframeGroup=None,
|
49 |
+
timestep_keyframe: TimestepKeyframeGroup=None,
|
50 |
+
):
|
51 |
+
if timestep_keyframe is not None: # backwards compatibility
|
52 |
+
tk_optional = timestep_keyframe
|
53 |
+
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
54 |
+
controlnet = load_controlnet(controlnet_path, tk_optional)
|
55 |
+
return (controlnet,)
|
56 |
+
|
57 |
+
|
58 |
+
class DiffControlNetLoaderAdvanced:
|
59 |
+
@classmethod
|
60 |
+
def INPUT_TYPES(s):
|
61 |
+
return {
|
62 |
+
"required": {
|
63 |
+
"model": ("MODEL",),
|
64 |
+
"control_net_name": (folder_paths.get_filename_list("controlnet"), )
|
65 |
+
},
|
66 |
+
"optional": {
|
67 |
+
"tk_optional": ("TIMESTEP_KEYFRAME", ),
|
68 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 160}),
|
69 |
+
}
|
70 |
+
}
|
71 |
+
|
72 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
73 |
+
FUNCTION = "load_controlnet"
|
74 |
+
|
75 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
|
76 |
+
|
77 |
+
def load_controlnet(self, control_net_name, model,
|
78 |
+
tk_optional: TimestepKeyframeGroup=None,
|
79 |
+
timestep_keyframe: TimestepKeyframeGroup=None
|
80 |
+
):
|
81 |
+
if timestep_keyframe is not None: # backwards compatibility
|
82 |
+
tk_optional = timestep_keyframe
|
83 |
+
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
84 |
+
controlnet = load_controlnet(controlnet_path, tk_optional, model)
|
85 |
+
if is_advanced_controlnet(controlnet):
|
86 |
+
controlnet.verify_all_weights()
|
87 |
+
return (controlnet,)
|
88 |
+
|
89 |
+
|
90 |
+
class AdvancedControlNetApply:
|
91 |
+
@classmethod
|
92 |
+
def INPUT_TYPES(s):
|
93 |
+
return {
|
94 |
+
"required": {
|
95 |
+
"positive": ("CONDITIONING", ),
|
96 |
+
"negative": ("CONDITIONING", ),
|
97 |
+
"control_net": ("CONTROL_NET", ),
|
98 |
+
"image": ("IMAGE", ),
|
99 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
100 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
101 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
102 |
+
},
|
103 |
+
"optional": {
|
104 |
+
"mask_optional": ("MASK", ),
|
105 |
+
"timestep_kf": ("TIMESTEP_KEYFRAME", ),
|
106 |
+
"latent_kf_override": ("LATENT_KEYFRAME", ),
|
107 |
+
"weights_override": ("CONTROL_NET_WEIGHTS", ),
|
108 |
+
"model_optional": ("MODEL",),
|
109 |
+
"vae_optional": ("VAE",),
|
110 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
111 |
+
}
|
112 |
+
}
|
113 |
+
|
114 |
+
RETURN_TYPES = ("CONDITIONING","CONDITIONING","MODEL",)
|
115 |
+
RETURN_NAMES = ("positive", "negative", "model_opt")
|
116 |
+
FUNCTION = "apply_controlnet"
|
117 |
+
|
118 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
|
119 |
+
|
120 |
+
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent,
|
121 |
+
mask_optional: Tensor=None, model_optional: ModelPatcher=None, vae_optional=None,
|
122 |
+
timestep_kf: TimestepKeyframeGroup=None, latent_kf_override: LatentKeyframeGroup=None,
|
123 |
+
weights_override: ControlWeights=None, control_apply_to_uncond=False):
|
124 |
+
if strength == 0:
|
125 |
+
return (positive, negative, model_optional)
|
126 |
+
if model_optional:
|
127 |
+
model_optional = model_optional.clone()
|
128 |
+
|
129 |
+
control_hint = image.movedim(-1,1)
|
130 |
+
cnets = {}
|
131 |
+
|
132 |
+
out = []
|
133 |
+
for conditioning in [positive, negative]:
|
134 |
+
c = []
|
135 |
+
if conditioning is not None:
|
136 |
+
for t in conditioning:
|
137 |
+
d = t[1].copy()
|
138 |
+
|
139 |
+
prev_cnet = d.get('control', None)
|
140 |
+
if prev_cnet in cnets:
|
141 |
+
c_net = cnets[prev_cnet]
|
142 |
+
else:
|
143 |
+
# copy, convert to advanced if needed, and set cond
|
144 |
+
c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent), vae_optional)
|
145 |
+
if is_advanced_controlnet(c_net):
|
146 |
+
# disarm node check
|
147 |
+
c_net.disarm()
|
148 |
+
# if model required, verify model is passed in, and if so patch it
|
149 |
+
if c_net.require_model:
|
150 |
+
if not model_optional:
|
151 |
+
raise Exception(f"Type '{type(c_net).__name__}' requires model_optional input, but got None.")
|
152 |
+
c_net.patch_model(model=model_optional)
|
153 |
+
# if vae required, verify vae is passed in
|
154 |
+
if c_net.require_vae:
|
155 |
+
# if controlnet can accept preprocced condhint latents and is the case, ignore vae requirement
|
156 |
+
if c_net.allow_condhint_latents and isinstance(control_hint, AbstractPreprocWrapper):
|
157 |
+
pass
|
158 |
+
elif not vae_optional:
|
159 |
+
# make sure SD3 ControlNet will get a special message instead of generic type mention
|
160 |
+
if is_sd3_advanced_controlnet:
|
161 |
+
raise Exception(f"SD3 ControlNet requires vae_optional input, but got None.")
|
162 |
+
else:
|
163 |
+
raise Exception(f"Type '{type(c_net).__name__}' requires vae_optional input, but got None.")
|
164 |
+
# apply optional parameters and overrides, if provided
|
165 |
+
if timestep_kf is not None:
|
166 |
+
c_net.set_timestep_keyframes(timestep_kf)
|
167 |
+
if latent_kf_override is not None:
|
168 |
+
c_net.latent_keyframe_override = latent_kf_override
|
169 |
+
if weights_override is not None:
|
170 |
+
c_net.weights_override = weights_override
|
171 |
+
# verify weights are compatible
|
172 |
+
c_net.verify_all_weights()
|
173 |
+
# set cond hint mask
|
174 |
+
if mask_optional is not None:
|
175 |
+
mask_optional = mask_optional.clone()
|
176 |
+
# if not in the form of a batch, make it so
|
177 |
+
if len(mask_optional.shape) < 3:
|
178 |
+
mask_optional = mask_optional.unsqueeze(0)
|
179 |
+
c_net.set_cond_hint_mask(mask_optional)
|
180 |
+
c_net.set_previous_controlnet(prev_cnet)
|
181 |
+
cnets[prev_cnet] = c_net
|
182 |
+
|
183 |
+
d['control'] = c_net
|
184 |
+
d['control_apply_to_uncond'] = control_apply_to_uncond
|
185 |
+
n = [t[0], d]
|
186 |
+
c.append(n)
|
187 |
+
out.append(c)
|
188 |
+
return (out[0], out[1], model_optional)
|
189 |
+
|
190 |
+
|
191 |
+
class AdvancedControlNetApplySingle:
|
192 |
+
@classmethod
|
193 |
+
def INPUT_TYPES(s):
|
194 |
+
return {
|
195 |
+
"required": {
|
196 |
+
"conditioning": ("CONDITIONING", ),
|
197 |
+
"control_net": ("CONTROL_NET", ),
|
198 |
+
"image": ("IMAGE", ),
|
199 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
200 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
201 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
202 |
+
},
|
203 |
+
"optional": {
|
204 |
+
"mask_optional": ("MASK", ),
|
205 |
+
"timestep_kf": ("TIMESTEP_KEYFRAME", ),
|
206 |
+
"latent_kf_override": ("LATENT_KEYFRAME", ),
|
207 |
+
"weights_override": ("CONTROL_NET_WEIGHTS", ),
|
208 |
+
"model_optional": ("MODEL",),
|
209 |
+
"vae_optional": ("VAE",),
|
210 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
211 |
+
}
|
212 |
+
}
|
213 |
+
|
214 |
+
RETURN_TYPES = ("CONDITIONING","MODEL",)
|
215 |
+
RETURN_NAMES = ("CONDITIONING", "model_opt")
|
216 |
+
FUNCTION = "apply_controlnet"
|
217 |
+
|
218 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
|
219 |
+
|
220 |
+
def apply_controlnet(self, conditioning, control_net, image, strength, start_percent, end_percent,
|
221 |
+
mask_optional: Tensor=None, model_optional: ModelPatcher=None, vae_optional=None,
|
222 |
+
timestep_kf: TimestepKeyframeGroup=None, latent_kf_override: LatentKeyframeGroup=None,
|
223 |
+
weights_override: ControlWeights=None):
|
224 |
+
values = AdvancedControlNetApply.apply_controlnet(self, positive=conditioning, negative=None, control_net=control_net, image=image,
|
225 |
+
strength=strength, start_percent=start_percent, end_percent=end_percent,
|
226 |
+
mask_optional=mask_optional, model_optional=model_optional, vae_optional=vae_optional,
|
227 |
+
timestep_kf=timestep_kf, latent_kf_override=latent_kf_override, weights_override=weights_override,
|
228 |
+
control_apply_to_uncond=True)
|
229 |
+
return (values[0], values[2])
|
230 |
+
|
231 |
+
|
232 |
+
# NODE MAPPING
|
233 |
+
NODE_CLASS_MAPPINGS = {
|
234 |
+
# Keyframes
|
235 |
+
"TimestepKeyframe": TimestepKeyframeNode,
|
236 |
+
"ACN_TimestepKeyframeInterpolation": TimestepKeyframeInterpolationNode,
|
237 |
+
"ACN_TimestepKeyframeFromStrengthList": TimestepKeyframeFromStrengthListNode,
|
238 |
+
"LatentKeyframe": LatentKeyframeNode,
|
239 |
+
"LatentKeyframeTiming": LatentKeyframeInterpolationNode,
|
240 |
+
"LatentKeyframeBatchedGroup": LatentKeyframeBatchedGroupNode,
|
241 |
+
"LatentKeyframeGroup": LatentKeyframeGroupNode,
|
242 |
+
# Conditioning
|
243 |
+
"ACN_AdvancedControlNetApply": AdvancedControlNetApply,
|
244 |
+
"ACN_AdvancedControlNetApplySingle": AdvancedControlNetApplySingle,
|
245 |
+
# Loaders
|
246 |
+
"ControlNetLoaderAdvanced": ControlNetLoaderAdvanced,
|
247 |
+
"DiffControlNetLoaderAdvanced": DiffControlNetLoaderAdvanced,
|
248 |
+
# Weights
|
249 |
+
"ACN_ScaledSoftControlNetWeights": ScaledSoftUniversalWeights,
|
250 |
+
"ScaledSoftMaskedUniversalWeights": ScaledSoftMaskedUniversalWeights,
|
251 |
+
"ACN_SoftControlNetWeightsSD15": SoftControlNetWeightsSD15,
|
252 |
+
"ACN_CustomControlNetWeightsSD15": CustomControlNetWeightsSD15,
|
253 |
+
"ACN_CustomControlNetWeightsFlux": CustomControlNetWeightsFlux,
|
254 |
+
"ACN_SoftT2IAdapterWeights": SoftT2IAdapterWeights,
|
255 |
+
"ACN_CustomT2IAdapterWeights": CustomT2IAdapterWeights,
|
256 |
+
"ACN_DefaultUniversalWeights": DefaultWeights,
|
257 |
+
# SparseCtrl
|
258 |
+
"ACN_SparseCtrlRGBPreprocessor": RgbSparseCtrlPreprocessor,
|
259 |
+
"ACN_SparseCtrlLoaderAdvanced": SparseCtrlLoaderAdvanced,
|
260 |
+
"ACN_SparseCtrlMergedLoaderAdvanced": SparseCtrlMergedLoaderAdvanced,
|
261 |
+
"ACN_SparseCtrlIndexMethodNode": SparseIndexMethodNode,
|
262 |
+
"ACN_SparseCtrlSpreadMethodNode": SparseSpreadMethodNode,
|
263 |
+
"ACN_SparseCtrlWeightExtras": SparseWeightExtras,
|
264 |
+
# ControlNet++
|
265 |
+
"ACN_ControlNet++LoaderSingle": PlusPlusLoaderSingle,
|
266 |
+
"ACN_ControlNet++LoaderAdvanced": PlusPlusLoaderAdvanced,
|
267 |
+
"ACN_ControlNet++InputNode": PlusPlusInputNode,
|
268 |
+
# Reference
|
269 |
+
"ACN_ReferencePreprocessor": ReferencePreprocessorNode,
|
270 |
+
"ACN_ReferenceControlNet": ReferenceControlNetNode,
|
271 |
+
"ACN_ReferenceControlNetFinetune": ReferenceControlFinetune,
|
272 |
+
# LOOSEControl
|
273 |
+
#"ACN_ControlNetLoaderWithLoraAdvanced": ControlNetLoaderWithLoraAdvanced,
|
274 |
+
# Deprecated
|
275 |
+
"LoadImagesFromDirectory": LoadImagesFromDirectory,
|
276 |
+
"ScaledSoftControlNetWeights": ScaledSoftUniversalWeightsDeprecated,
|
277 |
+
"SoftControlNetWeights": SoftControlNetWeightsDeprecated,
|
278 |
+
"CustomControlNetWeights": CustomControlNetWeightsDeprecated,
|
279 |
+
"SoftT2IAdapterWeights": SoftT2IAdapterWeightsDeprecated,
|
280 |
+
"CustomT2IAdapterWeights": CustomT2IAdapterWeightsDeprecated,
|
281 |
+
}
|
282 |
+
|
283 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
284 |
+
# Keyframes
|
285 |
+
"TimestepKeyframe": "Timestep Keyframe 🛂🅐🅒🅝",
|
286 |
+
"ACN_TimestepKeyframeInterpolation": "Timestep Keyframe Interp. 🛂🅐🅒🅝",
|
287 |
+
"ACN_TimestepKeyframeFromStrengthList": "Timestep Keyframe From List 🛂🅐🅒🅝",
|
288 |
+
"LatentKeyframe": "Latent Keyframe 🛂🅐🅒🅝",
|
289 |
+
"LatentKeyframeTiming": "Latent Keyframe Interp. 🛂🅐🅒🅝",
|
290 |
+
"LatentKeyframeBatchedGroup": "Latent Keyframe From List 🛂🅐🅒🅝",
|
291 |
+
"LatentKeyframeGroup": "Latent Keyframe Group 🛂🅐🅒🅝",
|
292 |
+
# Conditioning
|
293 |
+
"ACN_AdvancedControlNetApply": "Apply Advanced ControlNet 🛂🅐🅒🅝",
|
294 |
+
"ACN_AdvancedControlNetApplySingle": "Apply Advanced ControlNet(1) 🛂🅐🅒🅝",
|
295 |
+
# Loaders
|
296 |
+
"ControlNetLoaderAdvanced": "Load Advanced ControlNet Model 🛂🅐🅒🅝",
|
297 |
+
"DiffControlNetLoaderAdvanced": "Load Advanced ControlNet Model (diff) 🛂🅐🅒🅝",
|
298 |
+
# Weights
|
299 |
+
"ACN_ScaledSoftControlNetWeights": "Scaled Soft Weights 🛂🅐🅒🅝",
|
300 |
+
"ScaledSoftMaskedUniversalWeights": "Scaled Soft Masked Weights 🛂🅐🅒🅝",
|
301 |
+
"ACN_SoftControlNetWeightsSD15": "ControlNet Soft Weights [SD1.5] 🛂🅐🅒🅝",
|
302 |
+
"ACN_CustomControlNetWeightsSD15": "ControlNet Custom Weights [SD1.5] 🛂🅐🅒🅝",
|
303 |
+
"ACN_CustomControlNetWeightsFlux": "ControlNet Custom Weights [Flux] 🛂🅐🅒🅝",
|
304 |
+
"ACN_SoftT2IAdapterWeights": "T2IAdapter Soft Weights 🛂🅐🅒🅝",
|
305 |
+
"ACN_CustomT2IAdapterWeights": "T2IAdapter Custom Weights 🛂🅐🅒🅝",
|
306 |
+
"ACN_DefaultUniversalWeights": "Default Weights 🛂🅐🅒🅝",
|
307 |
+
# SparseCtrl
|
308 |
+
"ACN_SparseCtrlRGBPreprocessor": "RGB SparseCtrl 🛂🅐🅒🅝",
|
309 |
+
"ACN_SparseCtrlLoaderAdvanced": "Load SparseCtrl Model 🛂🅐🅒🅝",
|
310 |
+
"ACN_SparseCtrlMergedLoaderAdvanced": "🧪Load Merged SparseCtrl Model 🛂🅐🅒🅝",
|
311 |
+
"ACN_SparseCtrlIndexMethodNode": "SparseCtrl Index Method 🛂🅐🅒🅝",
|
312 |
+
"ACN_SparseCtrlSpreadMethodNode": "SparseCtrl Spread Method 🛂🅐🅒🅝",
|
313 |
+
"ACN_SparseCtrlWeightExtras": "SparseCtrl Weight Extras 🛂🅐🅒🅝",
|
314 |
+
# ControlNet++
|
315 |
+
"ACN_ControlNet++LoaderSingle": "Load ControlNet++ Model (Single) 🛂🅐🅒🅝",
|
316 |
+
"ACN_ControlNet++LoaderAdvanced": "Load ControlNet++ Model (Multi) 🛂🅐🅒🅝",
|
317 |
+
"ACN_ControlNet++InputNode": "ControlNet++ Input 🛂🅐🅒🅝",
|
318 |
+
# Reference
|
319 |
+
"ACN_ReferencePreprocessor": "Reference Preproccessor 🛂🅐🅒🅝",
|
320 |
+
"ACN_ReferenceControlNet": "Reference ControlNet 🛂🅐🅒🅝",
|
321 |
+
"ACN_ReferenceControlNetFinetune": "Reference ControlNet (Finetune) 🛂🅐🅒🅝",
|
322 |
+
# LOOSEControl
|
323 |
+
#"ACN_ControlNetLoaderWithLoraAdvanced": "Load Adv. ControlNet Model w/ LoRA 🛂🅐🅒🅝",
|
324 |
+
# Deprecated
|
325 |
+
"LoadImagesFromDirectory": "🚫Load Images [DEPRECATED] 🛂🅐🅒🅝",
|
326 |
+
"ScaledSoftControlNetWeights": "Scaled Soft Weights 🛂🅐🅒🅝",
|
327 |
+
"SoftControlNetWeights": "ControlNet Soft Weights 🛂🅐🅒🅝",
|
328 |
+
"CustomControlNetWeights": "ControlNet Custom Weights 🛂🅐🅒🅝",
|
329 |
+
"SoftT2IAdapterWeights": "T2IAdapter Soft Weights 🛂🅐🅒🅝",
|
330 |
+
"CustomT2IAdapterWeights": "T2IAdapter Custom Weights 🛂🅐🅒🅝",
|
331 |
+
}
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_deprecated.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image, ImageOps
|
7 |
+
from .utils import BIGMAX, ControlWeights, TimestepKeyframeGroup, TimestepKeyframe, get_properly_arranged_t2i_weights
|
8 |
+
from .logger import logger
|
9 |
+
|
10 |
+
|
11 |
+
class LoadImagesFromDirectory:
|
12 |
+
@classmethod
|
13 |
+
def INPUT_TYPES(s):
|
14 |
+
return {
|
15 |
+
"required": {
|
16 |
+
"directory": ("STRING", {"default": ""}),
|
17 |
+
},
|
18 |
+
"optional": {
|
19 |
+
"image_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
|
20 |
+
"start_index": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
|
21 |
+
}
|
22 |
+
}
|
23 |
+
|
24 |
+
RETURN_TYPES = ("IMAGE", "MASK", "INT")
|
25 |
+
FUNCTION = "load_images"
|
26 |
+
|
27 |
+
CATEGORY = ""
|
28 |
+
|
29 |
+
def load_images(self, directory: str, image_load_cap: int = 0, start_index: int = 0):
|
30 |
+
if not os.path.isdir(directory):
|
31 |
+
raise FileNotFoundError(f"Directory '{directory} cannot be found.'")
|
32 |
+
dir_files = os.listdir(directory)
|
33 |
+
if len(dir_files) == 0:
|
34 |
+
raise FileNotFoundError(f"No files in directory '{directory}'.")
|
35 |
+
|
36 |
+
dir_files = sorted(dir_files)
|
37 |
+
dir_files = [os.path.join(directory, x) for x in dir_files]
|
38 |
+
# start at start_index
|
39 |
+
dir_files = dir_files[start_index:]
|
40 |
+
|
41 |
+
images = []
|
42 |
+
masks = []
|
43 |
+
|
44 |
+
limit_images = False
|
45 |
+
if image_load_cap > 0:
|
46 |
+
limit_images = True
|
47 |
+
image_count = 0
|
48 |
+
|
49 |
+
for image_path in dir_files:
|
50 |
+
if os.path.isdir(image_path):
|
51 |
+
continue
|
52 |
+
if limit_images and image_count >= image_load_cap:
|
53 |
+
break
|
54 |
+
i = Image.open(image_path)
|
55 |
+
i = ImageOps.exif_transpose(i)
|
56 |
+
image = i.convert("RGB")
|
57 |
+
image = np.array(image).astype(np.float32) / 255.0
|
58 |
+
image = torch.from_numpy(image)[None,]
|
59 |
+
if 'A' in i.getbands():
|
60 |
+
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
61 |
+
mask = 1. - torch.from_numpy(mask)
|
62 |
+
else:
|
63 |
+
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
64 |
+
images.append(image)
|
65 |
+
masks.append(mask)
|
66 |
+
image_count += 1
|
67 |
+
|
68 |
+
if len(images) == 0:
|
69 |
+
raise FileNotFoundError(f"No images could be loaded from directory '{directory}'.")
|
70 |
+
|
71 |
+
return (torch.cat(images, dim=0), torch.stack(masks, dim=0), image_count)
|
72 |
+
|
73 |
+
|
74 |
+
class ScaledSoftUniversalWeightsDeprecated:
|
75 |
+
@classmethod
|
76 |
+
def INPUT_TYPES(s):
|
77 |
+
return {
|
78 |
+
"required": {
|
79 |
+
"base_multiplier": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 1.0, "step": 0.001}, ),
|
80 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
81 |
+
},
|
82 |
+
"optional": {
|
83 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
84 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
85 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
86 |
+
}
|
87 |
+
}
|
88 |
+
|
89 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
90 |
+
RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT")
|
91 |
+
FUNCTION = "load_weights"
|
92 |
+
|
93 |
+
CATEGORY = ""
|
94 |
+
|
95 |
+
def load_weights(self, base_multiplier, flip_weights, uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
96 |
+
weights = ControlWeights.universal(base_multiplier=base_multiplier, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
97 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
98 |
+
|
99 |
+
|
100 |
+
class SoftControlNetWeightsDeprecated:
|
101 |
+
@classmethod
|
102 |
+
def INPUT_TYPES(s):
|
103 |
+
return {
|
104 |
+
"required": {
|
105 |
+
"weight_00": ("FLOAT", {"default": 0.09941396206337118, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
106 |
+
"weight_01": ("FLOAT", {"default": 0.12050177219802567, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
107 |
+
"weight_02": ("FLOAT", {"default": 0.14606275417942507, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
108 |
+
"weight_03": ("FLOAT", {"default": 0.17704576264172736, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
109 |
+
"weight_04": ("FLOAT", {"default": 0.214600924414215, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
110 |
+
"weight_05": ("FLOAT", {"default": 0.26012233262329093, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
111 |
+
"weight_06": ("FLOAT", {"default": 0.3152997971191405, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
112 |
+
"weight_07": ("FLOAT", {"default": 0.3821815722656249, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
113 |
+
"weight_08": ("FLOAT", {"default": 0.4632503906249999, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
114 |
+
"weight_09": ("FLOAT", {"default": 0.561515625, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
115 |
+
"weight_10": ("FLOAT", {"default": 0.6806249999999999, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
116 |
+
"weight_11": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
117 |
+
"weight_12": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
118 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
119 |
+
},
|
120 |
+
"optional": {
|
121 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
122 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
123 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
124 |
+
}
|
125 |
+
}
|
126 |
+
|
127 |
+
DEPRECATED = True
|
128 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
129 |
+
RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT")
|
130 |
+
FUNCTION = "load_weights"
|
131 |
+
|
132 |
+
CATEGORY = ""
|
133 |
+
|
134 |
+
def load_weights(self, weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
|
135 |
+
weight_07, weight_08, weight_09, weight_10, weight_11, weight_12, flip_weights,
|
136 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
137 |
+
weights_output = [weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
|
138 |
+
weight_07, weight_08, weight_09, weight_10, weight_11]
|
139 |
+
weights_middle = [weight_12]
|
140 |
+
weights = ControlWeights.controlnet(weights_output=weights_output, weights_middle=weights_middle, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
141 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
142 |
+
|
143 |
+
|
144 |
+
class CustomControlNetWeightsDeprecated:
|
145 |
+
@classmethod
|
146 |
+
def INPUT_TYPES(s):
|
147 |
+
return {
|
148 |
+
"required": {
|
149 |
+
"weight_00": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
150 |
+
"weight_01": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
151 |
+
"weight_02": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
152 |
+
"weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
153 |
+
"weight_04": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
154 |
+
"weight_05": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
155 |
+
"weight_06": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
156 |
+
"weight_07": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
157 |
+
"weight_08": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
158 |
+
"weight_09": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
159 |
+
"weight_10": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
160 |
+
"weight_11": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
161 |
+
"weight_12": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
162 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
163 |
+
},
|
164 |
+
"optional": {
|
165 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
166 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
167 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
DEPRECATED = True
|
172 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
173 |
+
RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT")
|
174 |
+
FUNCTION = "load_weights"
|
175 |
+
|
176 |
+
CATEGORY = ""
|
177 |
+
|
178 |
+
def load_weights(self, weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
|
179 |
+
weight_07, weight_08, weight_09, weight_10, weight_11, weight_12, flip_weights,
|
180 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
181 |
+
weights_output = [weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
|
182 |
+
weight_07, weight_08, weight_09, weight_10, weight_11]
|
183 |
+
weights_middle = [weight_12]
|
184 |
+
weights = ControlWeights.controlnet(weights_output=weights_output, weights_middle=weights_middle, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
185 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
186 |
+
|
187 |
+
|
188 |
+
class SoftT2IAdapterWeightsDeprecated:
|
189 |
+
@classmethod
|
190 |
+
def INPUT_TYPES(s):
|
191 |
+
return {
|
192 |
+
"required": {
|
193 |
+
"weight_00": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
194 |
+
"weight_01": ("FLOAT", {"default": 0.62, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
195 |
+
"weight_02": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
196 |
+
"weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
197 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
198 |
+
},
|
199 |
+
"optional": {
|
200 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
201 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
202 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
203 |
+
}
|
204 |
+
}
|
205 |
+
|
206 |
+
DEPRECATED = True
|
207 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
208 |
+
RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT")
|
209 |
+
FUNCTION = "load_weights"
|
210 |
+
|
211 |
+
CATEGORY = ""
|
212 |
+
|
213 |
+
def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
|
214 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
215 |
+
weights = [weight_00, weight_01, weight_02, weight_03]
|
216 |
+
weights = get_properly_arranged_t2i_weights(weights)
|
217 |
+
weights = ControlWeights.t2iadapter(weights_input=weights, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
218 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
219 |
+
|
220 |
+
|
221 |
+
class CustomT2IAdapterWeightsDeprecated:
|
222 |
+
@classmethod
|
223 |
+
def INPUT_TYPES(s):
|
224 |
+
return {
|
225 |
+
"required": {
|
226 |
+
"weight_00": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
227 |
+
"weight_01": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
228 |
+
"weight_02": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
229 |
+
"weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
230 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
231 |
+
},
|
232 |
+
"optional": {
|
233 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
234 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
235 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
236 |
+
}
|
237 |
+
}
|
238 |
+
|
239 |
+
DEPRECATED = True
|
240 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
241 |
+
RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT")
|
242 |
+
FUNCTION = "load_weights"
|
243 |
+
|
244 |
+
CATEGORY = ""
|
245 |
+
|
246 |
+
def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
|
247 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
248 |
+
weights = [weight_00, weight_01, weight_02, weight_03]
|
249 |
+
weights = get_properly_arranged_t2i_weights(weights)
|
250 |
+
weights = ControlWeights.t2iadapter(weights_input=weights, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
251 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_keyframes.py
ADDED
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
import numpy as np
|
3 |
+
from collections.abc import Iterable
|
4 |
+
|
5 |
+
from .utils import ControlWeights, TimestepKeyframe, TimestepKeyframeGroup, LatentKeyframe, LatentKeyframeGroup, BIGMIN, BIGMAX
|
6 |
+
from .utils import StrengthInterpolation as SI
|
7 |
+
from .logger import logger
|
8 |
+
|
9 |
+
|
10 |
+
class TimestepKeyframeNode:
|
11 |
+
OUTDATED_DUMMY = -39
|
12 |
+
|
13 |
+
@classmethod
|
14 |
+
def INPUT_TYPES(s):
|
15 |
+
return {
|
16 |
+
"required": {
|
17 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
|
18 |
+
},
|
19 |
+
"optional": {
|
20 |
+
"prev_timestep_kf": ("TIMESTEP_KEYFRAME", ),
|
21 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
22 |
+
"cn_weights": ("CONTROL_NET_WEIGHTS", ),
|
23 |
+
"latent_keyframe": ("LATENT_KEYFRAME", ),
|
24 |
+
"null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
25 |
+
"inherit_missing": ("BOOLEAN", {"default": True}, ),
|
26 |
+
"guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}),
|
27 |
+
"mask_optional": ("MASK", ),
|
28 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
29 |
+
}
|
30 |
+
}
|
31 |
+
|
32 |
+
RETURN_NAMES = ("TIMESTEP_KF", )
|
33 |
+
RETURN_TYPES = ("TIMESTEP_KEYFRAME", )
|
34 |
+
FUNCTION = "load_keyframe"
|
35 |
+
|
36 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
37 |
+
|
38 |
+
def load_keyframe(self,
|
39 |
+
start_percent: float,
|
40 |
+
strength: float=1.0,
|
41 |
+
cn_weights: ControlWeights=None, control_net_weights: ControlWeights=None, # old name
|
42 |
+
latent_keyframe: LatentKeyframeGroup=None,
|
43 |
+
prev_timestep_kf: TimestepKeyframeGroup=None, prev_timestep_keyframe: TimestepKeyframeGroup=None, # old name
|
44 |
+
null_latent_kf_strength: float=0.0,
|
45 |
+
inherit_missing=True,
|
46 |
+
guarantee_steps=OUTDATED_DUMMY,
|
47 |
+
guarantee_usage=True, # old input
|
48 |
+
mask_optional=None,):
|
49 |
+
# if using outdated dummy value, means node on workflow is outdated and should appropriately convert behavior
|
50 |
+
if guarantee_steps == self.OUTDATED_DUMMY:
|
51 |
+
guarantee_steps = int(guarantee_usage)
|
52 |
+
control_net_weights = control_net_weights if control_net_weights else cn_weights
|
53 |
+
prev_timestep_keyframe = prev_timestep_keyframe if prev_timestep_keyframe else prev_timestep_kf
|
54 |
+
if not prev_timestep_keyframe:
|
55 |
+
prev_timestep_keyframe = TimestepKeyframeGroup()
|
56 |
+
else:
|
57 |
+
prev_timestep_keyframe = prev_timestep_keyframe.clone()
|
58 |
+
keyframe = TimestepKeyframe(start_percent=start_percent, strength=strength, null_latent_kf_strength=null_latent_kf_strength,
|
59 |
+
control_weights=control_net_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing,
|
60 |
+
guarantee_steps=guarantee_steps, mask_hint_orig=mask_optional)
|
61 |
+
prev_timestep_keyframe.add(keyframe)
|
62 |
+
return (prev_timestep_keyframe,)
|
63 |
+
|
64 |
+
|
65 |
+
class TimestepKeyframeInterpolationNode:
|
66 |
+
@classmethod
|
67 |
+
def INPUT_TYPES(s):
|
68 |
+
return {
|
69 |
+
"required": {
|
70 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001},),
|
71 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
72 |
+
"strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},),
|
73 |
+
"strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},),
|
74 |
+
"interpolation": (SI._LIST, ),
|
75 |
+
"intervals": ("INT", {"default": 50, "min": 2, "max": 100, "step": 1}),
|
76 |
+
},
|
77 |
+
"optional": {
|
78 |
+
"prev_timestep_kf": ("TIMESTEP_KEYFRAME", ),
|
79 |
+
"cn_weights": ("CONTROL_NET_WEIGHTS", ),
|
80 |
+
"latent_keyframe": ("LATENT_KEYFRAME", ),
|
81 |
+
"null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001},),
|
82 |
+
"inherit_missing": ("BOOLEAN", {"default": True},),
|
83 |
+
"mask_optional": ("MASK", ),
|
84 |
+
"print_keyframes": ("BOOLEAN", {"default": False}),
|
85 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 50}),
|
86 |
+
}
|
87 |
+
}
|
88 |
+
|
89 |
+
RETURN_NAMES = ("TIMESTEP_KF", )
|
90 |
+
RETURN_TYPES = ("TIMESTEP_KEYFRAME", )
|
91 |
+
FUNCTION = "load_keyframe"
|
92 |
+
|
93 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
94 |
+
|
95 |
+
def load_keyframe(self,
|
96 |
+
start_percent: float, end_percent: float,
|
97 |
+
strength_start: float, strength_end: float, interpolation: str, intervals: int,
|
98 |
+
cn_weights: ControlWeights=None,
|
99 |
+
latent_keyframe: LatentKeyframeGroup=None,
|
100 |
+
prev_timestep_kf: TimestepKeyframeGroup=None,
|
101 |
+
null_latent_kf_strength: float=0.0,
|
102 |
+
inherit_missing=True,
|
103 |
+
guarantee_steps=1,
|
104 |
+
mask_optional=None, print_keyframes=False):
|
105 |
+
if not prev_timestep_kf:
|
106 |
+
prev_timestep_kf = TimestepKeyframeGroup()
|
107 |
+
else:
|
108 |
+
prev_timestep_kf = prev_timestep_kf.clone()
|
109 |
+
|
110 |
+
percents = SI.get_weights(num_from=start_percent, num_to=end_percent, length=intervals, method=SI.LINEAR)
|
111 |
+
strengths = SI.get_weights(num_from=strength_start, num_to=strength_end, length=intervals, method=interpolation)
|
112 |
+
|
113 |
+
is_first = True
|
114 |
+
for percent, strength in zip(percents, strengths):
|
115 |
+
guarantee_steps = 0
|
116 |
+
if is_first:
|
117 |
+
guarantee_steps = 1
|
118 |
+
is_first = False
|
119 |
+
prev_timestep_kf.add(TimestepKeyframe(start_percent=percent, strength=strength, null_latent_kf_strength=null_latent_kf_strength,
|
120 |
+
control_weights=cn_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing,
|
121 |
+
guarantee_steps=guarantee_steps, mask_hint_orig=mask_optional))
|
122 |
+
if print_keyframes:
|
123 |
+
logger.info(f"TimestepKeyframe - start_percent:{percent} = {strength}")
|
124 |
+
return (prev_timestep_kf,)
|
125 |
+
|
126 |
+
|
127 |
+
class TimestepKeyframeFromStrengthListNode:
|
128 |
+
@classmethod
|
129 |
+
def INPUT_TYPES(s):
|
130 |
+
return {
|
131 |
+
"required": {
|
132 |
+
"float_strengths": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
|
133 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001},),
|
134 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
135 |
+
},
|
136 |
+
"optional": {
|
137 |
+
"prev_timestep_kf": ("TIMESTEP_KEYFRAME", ),
|
138 |
+
"cn_weights": ("CONTROL_NET_WEIGHTS", ),
|
139 |
+
"latent_keyframe": ("LATENT_KEYFRAME", ),
|
140 |
+
"null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001},),
|
141 |
+
"inherit_missing": ("BOOLEAN", {"default": True},),
|
142 |
+
"mask_optional": ("MASK", ),
|
143 |
+
"print_keyframes": ("BOOLEAN", {"default": False}),
|
144 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
145 |
+
}
|
146 |
+
}
|
147 |
+
|
148 |
+
RETURN_NAMES = ("TIMESTEP_KF", )
|
149 |
+
RETURN_TYPES = ("TIMESTEP_KEYFRAME", )
|
150 |
+
FUNCTION = "load_keyframe"
|
151 |
+
|
152 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
153 |
+
|
154 |
+
def load_keyframe(self,
|
155 |
+
start_percent: float, end_percent: float,
|
156 |
+
float_strengths: float,
|
157 |
+
cn_weights: ControlWeights=None,
|
158 |
+
latent_keyframe: LatentKeyframeGroup=None,
|
159 |
+
prev_timestep_kf: TimestepKeyframeGroup=None,
|
160 |
+
null_latent_kf_strength: float=0.0,
|
161 |
+
inherit_missing=True,
|
162 |
+
guarantee_steps=1,
|
163 |
+
mask_optional=None, print_keyframes=False):
|
164 |
+
if not prev_timestep_kf:
|
165 |
+
prev_timestep_kf = TimestepKeyframeGroup()
|
166 |
+
else:
|
167 |
+
prev_timestep_kf = prev_timestep_kf.clone()
|
168 |
+
|
169 |
+
if type(float_strengths) in (float, int):
|
170 |
+
float_strengths = [float(float_strengths)]
|
171 |
+
elif isinstance(float_strengths, Iterable):
|
172 |
+
pass
|
173 |
+
else:
|
174 |
+
raise Exception(f"strengths_float must be either an iterable input or a float, but was {type(float_strengths).__repr__}.")
|
175 |
+
percents = SI.get_weights(num_from=start_percent, num_to=end_percent, length=len(float_strengths), method=SI.LINEAR)
|
176 |
+
|
177 |
+
is_first = True
|
178 |
+
for percent, strength in zip(percents, float_strengths):
|
179 |
+
guarantee_steps = 0
|
180 |
+
if is_first:
|
181 |
+
guarantee_steps = 1
|
182 |
+
is_first = False
|
183 |
+
prev_timestep_kf.add(TimestepKeyframe(start_percent=percent, strength=strength, null_latent_kf_strength=null_latent_kf_strength,
|
184 |
+
control_weights=cn_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing,
|
185 |
+
guarantee_steps=guarantee_steps, mask_hint_orig=mask_optional))
|
186 |
+
if print_keyframes:
|
187 |
+
logger.info(f"TimestepKeyframe - start_percent:{percent} = {strength}")
|
188 |
+
return (prev_timestep_kf,)
|
189 |
+
|
190 |
+
|
191 |
+
class LatentKeyframeNode:
|
192 |
+
@classmethod
|
193 |
+
def INPUT_TYPES(s):
|
194 |
+
return {
|
195 |
+
"required": {
|
196 |
+
"batch_index": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
|
197 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
198 |
+
},
|
199 |
+
"optional": {
|
200 |
+
"prev_latent_kf": ("LATENT_KEYFRAME", ),
|
201 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
202 |
+
}
|
203 |
+
}
|
204 |
+
|
205 |
+
RETURN_NAMES = ("LATENT_KF", )
|
206 |
+
RETURN_TYPES = ("LATENT_KEYFRAME", )
|
207 |
+
FUNCTION = "load_keyframe"
|
208 |
+
|
209 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
210 |
+
|
211 |
+
def load_keyframe(self,
|
212 |
+
batch_index: int,
|
213 |
+
strength: float,
|
214 |
+
prev_latent_kf: LatentKeyframeGroup=None,
|
215 |
+
prev_latent_keyframe: LatentKeyframeGroup=None, # old name
|
216 |
+
):
|
217 |
+
prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
|
218 |
+
if not prev_latent_keyframe:
|
219 |
+
prev_latent_keyframe = LatentKeyframeGroup()
|
220 |
+
else:
|
221 |
+
prev_latent_keyframe = prev_latent_keyframe.clone()
|
222 |
+
keyframe = LatentKeyframe(batch_index, strength)
|
223 |
+
prev_latent_keyframe.add(keyframe)
|
224 |
+
return (prev_latent_keyframe,)
|
225 |
+
|
226 |
+
|
227 |
+
class LatentKeyframeGroupNode:
|
228 |
+
@classmethod
|
229 |
+
def INPUT_TYPES(s):
|
230 |
+
return {
|
231 |
+
"required": {
|
232 |
+
"index_strengths": ("STRING", {"multiline": True, "default": ""}),
|
233 |
+
},
|
234 |
+
"optional": {
|
235 |
+
"prev_latent_kf": ("LATENT_KEYFRAME", ),
|
236 |
+
"latent_optional": ("LATENT", ),
|
237 |
+
"print_keyframes": ("BOOLEAN", {"default": False}),
|
238 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 35}),
|
239 |
+
}
|
240 |
+
}
|
241 |
+
|
242 |
+
RETURN_NAMES = ("LATENT_KF", )
|
243 |
+
RETURN_TYPES = ("LATENT_KEYFRAME", )
|
244 |
+
FUNCTION = "load_keyframes"
|
245 |
+
|
246 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
247 |
+
|
248 |
+
def validate_index(self, index: int, latent_count: int = 0, is_range: bool = False, allow_negative = False) -> int:
|
249 |
+
# if part of range, do nothing
|
250 |
+
if is_range:
|
251 |
+
return index
|
252 |
+
# otherwise, validate index
|
253 |
+
# validate not out of range - only when latent_count is passed in
|
254 |
+
if latent_count > 0 and index > latent_count-1:
|
255 |
+
raise IndexError(f"Index '{index}' out of range for the total {latent_count} latents.")
|
256 |
+
# if negative, validate not out of range
|
257 |
+
if index < 0:
|
258 |
+
if not allow_negative:
|
259 |
+
raise IndexError(f"Negative indeces not allowed, but was {index}.")
|
260 |
+
conv_index = latent_count+index
|
261 |
+
if conv_index < 0:
|
262 |
+
raise IndexError(f"Index '{index}', converted to '{conv_index}' out of range for the total {latent_count} latents.")
|
263 |
+
index = conv_index
|
264 |
+
return index
|
265 |
+
|
266 |
+
def convert_to_index_int(self, raw_index: str, latent_count: int = 0, is_range: bool = False, allow_negative = False) -> int:
|
267 |
+
try:
|
268 |
+
return self.validate_index(int(raw_index), latent_count=latent_count, is_range=is_range, allow_negative=allow_negative)
|
269 |
+
except ValueError as e:
|
270 |
+
raise ValueError(f"index '{raw_index}' must be an integer.", e)
|
271 |
+
|
272 |
+
def convert_to_latent_keyframes(self, latent_indeces: str, latent_count: int) -> set[LatentKeyframe]:
|
273 |
+
if not latent_indeces:
|
274 |
+
return set()
|
275 |
+
int_latent_indeces = [i for i in range(0, latent_count)]
|
276 |
+
allow_negative = latent_count > 0
|
277 |
+
chosen_indeces = set()
|
278 |
+
# parse string - allow positive ints, negative ints, and ranges separated by ':'
|
279 |
+
groups = latent_indeces.split(",")
|
280 |
+
groups = [g.strip() for g in groups]
|
281 |
+
for g in groups:
|
282 |
+
# parse strengths - default to 1.0 if no strength given
|
283 |
+
strength = 1.0
|
284 |
+
if '=' in g:
|
285 |
+
g, strength_str = g.split("=", 1)
|
286 |
+
g = g.strip()
|
287 |
+
try:
|
288 |
+
strength = float(strength_str.strip())
|
289 |
+
except ValueError as e:
|
290 |
+
raise ValueError(f"strength '{strength_str}' must be a float.", e)
|
291 |
+
if strength < 0:
|
292 |
+
raise ValueError(f"Strength '{strength}' cannot be negative.")
|
293 |
+
# parse range of indeces (e.g. 2:16)
|
294 |
+
if ':' in g:
|
295 |
+
index_range = g.split(":", 1)
|
296 |
+
index_range = [r.strip() for r in index_range]
|
297 |
+
start_index = self.convert_to_index_int(index_range[0], latent_count=latent_count, is_range=True, allow_negative=allow_negative)
|
298 |
+
end_index = self.convert_to_index_int(index_range[1], latent_count=latent_count, is_range=True, allow_negative=allow_negative)
|
299 |
+
# if latents were passed in, base indeces on known latent count
|
300 |
+
if len(int_latent_indeces) > 0:
|
301 |
+
for i in int_latent_indeces[start_index:end_index]:
|
302 |
+
chosen_indeces.add(LatentKeyframe(i, strength))
|
303 |
+
# otherwise, assume indeces are valid
|
304 |
+
else:
|
305 |
+
for i in range(start_index, end_index):
|
306 |
+
chosen_indeces.add(LatentKeyframe(i, strength))
|
307 |
+
# parse individual indeces
|
308 |
+
else:
|
309 |
+
chosen_indeces.add(LatentKeyframe(self.convert_to_index_int(g, latent_count=latent_count, allow_negative=allow_negative), strength))
|
310 |
+
return chosen_indeces
|
311 |
+
|
312 |
+
def load_keyframes(self,
|
313 |
+
index_strengths: str,
|
314 |
+
prev_latent_kf: LatentKeyframeGroup=None,
|
315 |
+
prev_latent_keyframe: LatentKeyframeGroup=None, # old name
|
316 |
+
latent_image_opt=None,
|
317 |
+
print_keyframes=False):
|
318 |
+
prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
|
319 |
+
if not prev_latent_keyframe:
|
320 |
+
prev_latent_keyframe = LatentKeyframeGroup()
|
321 |
+
else:
|
322 |
+
prev_latent_keyframe = prev_latent_keyframe.clone()
|
323 |
+
curr_latent_keyframe = LatentKeyframeGroup()
|
324 |
+
|
325 |
+
latent_count = -1
|
326 |
+
if latent_image_opt:
|
327 |
+
latent_count = latent_image_opt['samples'].size()[0]
|
328 |
+
latent_keyframes = self.convert_to_latent_keyframes(index_strengths, latent_count=latent_count)
|
329 |
+
|
330 |
+
for latent_keyframe in latent_keyframes:
|
331 |
+
curr_latent_keyframe.add(latent_keyframe)
|
332 |
+
|
333 |
+
if print_keyframes:
|
334 |
+
for keyframe in curr_latent_keyframe.keyframes:
|
335 |
+
logger.info(f"LatentKeyframe {keyframe.batch_index}={keyframe.strength}")
|
336 |
+
|
337 |
+
# replace values with prev_latent_keyframes
|
338 |
+
for latent_keyframe in prev_latent_keyframe.keyframes:
|
339 |
+
curr_latent_keyframe.add(latent_keyframe)
|
340 |
+
|
341 |
+
return (curr_latent_keyframe,)
|
342 |
+
|
343 |
+
|
344 |
+
class LatentKeyframeInterpolationNode:
|
345 |
+
@classmethod
|
346 |
+
def INPUT_TYPES(s):
|
347 |
+
return {
|
348 |
+
"required": {
|
349 |
+
"batch_index_from": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
|
350 |
+
"batch_index_to_excl": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
|
351 |
+
"strength_from": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
352 |
+
"strength_to": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
353 |
+
"interpolation": (SI._LIST, ),
|
354 |
+
},
|
355 |
+
"optional": {
|
356 |
+
"prev_latent_kf": ("LATENT_KEYFRAME", ),
|
357 |
+
"print_keyframes": ("BOOLEAN", {"default": False}),
|
358 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 50}),
|
359 |
+
}
|
360 |
+
}
|
361 |
+
|
362 |
+
RETURN_NAMES = ("LATENT_KF", )
|
363 |
+
RETURN_TYPES = ("LATENT_KEYFRAME", )
|
364 |
+
FUNCTION = "load_keyframe"
|
365 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
366 |
+
|
367 |
+
def load_keyframe(self,
|
368 |
+
batch_index_from: int,
|
369 |
+
strength_from: float,
|
370 |
+
batch_index_to_excl: int,
|
371 |
+
strength_to: float,
|
372 |
+
interpolation: str,
|
373 |
+
prev_latent_kf: LatentKeyframeGroup=None,
|
374 |
+
prev_latent_keyframe: LatentKeyframeGroup=None, # old name
|
375 |
+
print_keyframes=False):
|
376 |
+
|
377 |
+
if (batch_index_from > batch_index_to_excl):
|
378 |
+
raise ValueError("batch_index_from must be less than or equal to batch_index_to.")
|
379 |
+
|
380 |
+
if (batch_index_from < 0 and batch_index_to_excl >= 0):
|
381 |
+
raise ValueError("batch_index_from and batch_index_to must be either both positive or both negative.")
|
382 |
+
|
383 |
+
prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
|
384 |
+
if not prev_latent_keyframe:
|
385 |
+
prev_latent_keyframe = LatentKeyframeGroup()
|
386 |
+
else:
|
387 |
+
prev_latent_keyframe = prev_latent_keyframe.clone()
|
388 |
+
curr_latent_keyframe = LatentKeyframeGroup()
|
389 |
+
|
390 |
+
steps = batch_index_to_excl - batch_index_from
|
391 |
+
diff = strength_to - strength_from
|
392 |
+
if interpolation == SI.LINEAR:
|
393 |
+
weights = np.linspace(strength_from, strength_to, steps)
|
394 |
+
elif interpolation == SI.EASE_IN:
|
395 |
+
index = np.linspace(0, 1, steps)
|
396 |
+
weights = diff * np.power(index, 2) + strength_from
|
397 |
+
elif interpolation == SI.EASE_OUT:
|
398 |
+
index = np.linspace(0, 1, steps)
|
399 |
+
weights = diff * (1 - np.power(1 - index, 2)) + strength_from
|
400 |
+
elif interpolation == SI.EASE_IN_OUT:
|
401 |
+
index = np.linspace(0, 1, steps)
|
402 |
+
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + strength_from
|
403 |
+
|
404 |
+
for i in range(steps):
|
405 |
+
keyframe = LatentKeyframe(batch_index_from + i, float(weights[i]))
|
406 |
+
curr_latent_keyframe.add(keyframe)
|
407 |
+
|
408 |
+
if print_keyframes:
|
409 |
+
for keyframe in curr_latent_keyframe.keyframes:
|
410 |
+
logger.info(f"LatentKeyframe {keyframe.batch_index}={keyframe.strength}")
|
411 |
+
|
412 |
+
# replace values with prev_latent_keyframes
|
413 |
+
for latent_keyframe in prev_latent_keyframe.keyframes:
|
414 |
+
curr_latent_keyframe.add(latent_keyframe)
|
415 |
+
|
416 |
+
return (curr_latent_keyframe,)
|
417 |
+
|
418 |
+
|
419 |
+
class LatentKeyframeBatchedGroupNode:
|
420 |
+
@classmethod
|
421 |
+
def INPUT_TYPES(s):
|
422 |
+
return {
|
423 |
+
"required": {
|
424 |
+
"float_strengths": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
|
425 |
+
},
|
426 |
+
"optional": {
|
427 |
+
"prev_latent_kf": ("LATENT_KEYFRAME", ),
|
428 |
+
"print_keyframes": ("BOOLEAN", {"default": False}),
|
429 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
430 |
+
}
|
431 |
+
}
|
432 |
+
|
433 |
+
RETURN_NAMES = ("LATENT_KF", )
|
434 |
+
RETURN_TYPES = ("LATENT_KEYFRAME", )
|
435 |
+
FUNCTION = "load_keyframe"
|
436 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
437 |
+
|
438 |
+
def load_keyframe(self, float_strengths: Union[float, list[float]],
|
439 |
+
prev_latent_kf: LatentKeyframeGroup=None,
|
440 |
+
prev_latent_keyframe: LatentKeyframeGroup=None, # old name
|
441 |
+
print_keyframes=False):
|
442 |
+
prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
|
443 |
+
if not prev_latent_keyframe:
|
444 |
+
prev_latent_keyframe = LatentKeyframeGroup()
|
445 |
+
else:
|
446 |
+
prev_latent_keyframe = prev_latent_keyframe.clone()
|
447 |
+
curr_latent_keyframe = LatentKeyframeGroup()
|
448 |
+
|
449 |
+
# if received a normal float input, do nothing
|
450 |
+
if type(float_strengths) in (float, int):
|
451 |
+
logger.info("No batched float_strengths passed into Latent Keyframe Batch Group node; will not create any new keyframes.")
|
452 |
+
# if iterable, attempt to create LatentKeyframes with chosen strengths
|
453 |
+
elif isinstance(float_strengths, Iterable):
|
454 |
+
for idx, strength in enumerate(float_strengths):
|
455 |
+
keyframe = LatentKeyframe(idx, strength)
|
456 |
+
curr_latent_keyframe.add(keyframe)
|
457 |
+
else:
|
458 |
+
raise ValueError(f"Expected strengths to be an iterable input, but was {type(float_strengths).__repr__}.")
|
459 |
+
|
460 |
+
if print_keyframes:
|
461 |
+
for keyframe in curr_latent_keyframe.keyframes:
|
462 |
+
logger.info(f"LatentKeyframe {keyframe.batch_index}={keyframe.strength}")
|
463 |
+
|
464 |
+
# replace values with prev_latent_keyframes
|
465 |
+
for latent_keyframe in prev_latent_keyframe.keyframes:
|
466 |
+
curr_latent_keyframe.add(latent_keyframe)
|
467 |
+
|
468 |
+
return (curr_latent_keyframe,)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_loosecontrol.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import folder_paths
|
2 |
+
import comfy.utils
|
3 |
+
import comfy.model_detection
|
4 |
+
import comfy.model_management
|
5 |
+
import comfy.lora
|
6 |
+
from comfy.model_patcher import ModelPatcher
|
7 |
+
|
8 |
+
from .utils import TimestepKeyframeGroup
|
9 |
+
from .control import ControlNetAdvanced, load_controlnet
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
def convert_cn_lora_from_diffusers(cn_model: ModelPatcher, lora_path: str):
|
15 |
+
lora_data = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
16 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
17 |
+
for key, value in lora_data.items():
|
18 |
+
lora_data[key] = value.to(unet_dtype)
|
19 |
+
diffusers_keys = comfy.utils.unet_to_diffusers(cn_model.model.state_dict())
|
20 |
+
|
21 |
+
#lora_data = comfy.model_detection.unet_config_from_diffusers_unet(lora_data, dtype=unet_dtype)
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
#key_map = comfy.lora.model_lora_keys_unet(cn_model.model, key_map)
|
26 |
+
lora_data = comfy.lora.load_lora(lora_data, to_load=diffusers_keys)
|
27 |
+
|
28 |
+
# TODO: detect if diffusers for sure? not sure if needed at this time, since cn loras are
|
29 |
+
# only used currently for LOOSEControl, and those are all in diffusers format
|
30 |
+
#unet_dtype = comfy.model_management.unet_dtype()
|
31 |
+
#lora_data = comfy.model_detection.unet_config_from_diffusers_unet(lora_data, unet_dtype)
|
32 |
+
return lora_data
|
33 |
+
|
34 |
+
|
35 |
+
class ControlNetLoaderWithLoraAdvanced:
|
36 |
+
@classmethod
|
37 |
+
def INPUT_TYPES(s):
|
38 |
+
return {
|
39 |
+
"required": {
|
40 |
+
"control_net_name": (folder_paths.get_filename_list("controlnet"), ),
|
41 |
+
"cn_lora_name": (folder_paths.get_filename_list("controlnet"), ),
|
42 |
+
"cn_lora_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
43 |
+
},
|
44 |
+
"optional": {
|
45 |
+
"timestep_keyframe": ("TIMESTEP_KEYFRAME", ),
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
50 |
+
FUNCTION = "load_controlnet"
|
51 |
+
|
52 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/LOOSEControl"
|
53 |
+
|
54 |
+
def load_controlnet(self, control_net_name, cn_lora_name, cn_lora_strength: float,
|
55 |
+
timestep_keyframe: TimestepKeyframeGroup=None
|
56 |
+
):
|
57 |
+
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
58 |
+
controlnet: ControlNetAdvanced = load_controlnet(controlnet_path, timestep_keyframe)
|
59 |
+
if not isinstance(controlnet, ControlNetAdvanced):
|
60 |
+
raise ValueError("Type {} is not compatible with CN LoRA features at this time.")
|
61 |
+
# now, try to load CN LoRA
|
62 |
+
lora_path = folder_paths.get_full_path("controlnet", cn_lora_name)
|
63 |
+
lora_data = convert_cn_lora_from_diffusers(cn_model=controlnet.control_model_wrapped, lora_path=lora_path)
|
64 |
+
# apply patches to wrapped control_model
|
65 |
+
controlnet.control_model_wrapped.add_patches(lora_data, strength_patch=cn_lora_strength)
|
66 |
+
# all done
|
67 |
+
return (controlnet,)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_plusplus.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor
|
2 |
+
import math
|
3 |
+
|
4 |
+
import folder_paths
|
5 |
+
|
6 |
+
from .control_plusplus import load_controlnetplusplus, PlusPlusType, PlusPlusInput, PlusPlusInputGroup, PlusPlusImageWrapper
|
7 |
+
from .utils import BIGMAX
|
8 |
+
|
9 |
+
|
10 |
+
class PlusPlusLoaderAdvanced:
|
11 |
+
@classmethod
|
12 |
+
def INPUT_TYPES(s):
|
13 |
+
return {
|
14 |
+
"required": {
|
15 |
+
"plus_input": ("PLUS_INPUT", ),
|
16 |
+
"name": (folder_paths.get_filename_list("controlnet"), ),
|
17 |
+
}
|
18 |
+
}
|
19 |
+
|
20 |
+
RETURN_TYPES = ("CONTROL_NET", "IMAGE",)
|
21 |
+
FUNCTION = "load_controlnet_plusplus"
|
22 |
+
|
23 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/ControlNet++"
|
24 |
+
|
25 |
+
def load_controlnet_plusplus(self, plus_input: PlusPlusInputGroup, name: str):
|
26 |
+
controlnet_path = folder_paths.get_full_path("controlnet", name)
|
27 |
+
controlnet = load_controlnetplusplus(controlnet_path)
|
28 |
+
controlnet.verify_control_type(name, plus_input)
|
29 |
+
return (controlnet, PlusPlusImageWrapper(plus_input),)
|
30 |
+
|
31 |
+
|
32 |
+
class PlusPlusLoaderSingle:
|
33 |
+
@classmethod
|
34 |
+
def INPUT_TYPES(s):
|
35 |
+
return {
|
36 |
+
"required": {
|
37 |
+
"name": (folder_paths.get_filename_list("controlnet"), ),
|
38 |
+
"control_type": (PlusPlusType._LIST_WITH_NONE, {"default": PlusPlusType.NONE}, ),
|
39 |
+
}
|
40 |
+
}
|
41 |
+
|
42 |
+
RETURN_TYPES = ("CONTROL_NET",)
|
43 |
+
FUNCTION = "load_controlnet_plusplus"
|
44 |
+
|
45 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/ControlNet++"
|
46 |
+
|
47 |
+
def load_controlnet_plusplus(self, name: str, control_type: str):
|
48 |
+
controlnet_path = folder_paths.get_full_path("controlnet", name)
|
49 |
+
controlnet = load_controlnetplusplus(controlnet_path)
|
50 |
+
controlnet.single_control_type = control_type
|
51 |
+
controlnet.verify_control_type(name)
|
52 |
+
return (controlnet,)
|
53 |
+
|
54 |
+
|
55 |
+
class PlusPlusInputNode:
|
56 |
+
@classmethod
|
57 |
+
def INPUT_TYPES(s):
|
58 |
+
return {
|
59 |
+
"required": {
|
60 |
+
"image": ("IMAGE",),
|
61 |
+
"control_type": (PlusPlusType._LIST,),
|
62 |
+
},
|
63 |
+
"optional": {
|
64 |
+
"prev_plus_input": ("PLUS_INPUT",),
|
65 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
66 |
+
#"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": BIGMAX, "step": 0.01}),
|
67 |
+
}
|
68 |
+
}
|
69 |
+
|
70 |
+
RETURN_TYPES = ("PLUS_INPUT", )
|
71 |
+
FUNCTION = "wrap_images"
|
72 |
+
|
73 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/ControlNet++"
|
74 |
+
|
75 |
+
def wrap_images(self, image: Tensor, control_type: str, strength=1.0, prev_plus_input: PlusPlusInputGroup=None):
|
76 |
+
if prev_plus_input is None:
|
77 |
+
prev_plus_input = PlusPlusInputGroup()
|
78 |
+
prev_plus_input = prev_plus_input.clone()
|
79 |
+
|
80 |
+
if math.isclose(strength, 0.0):
|
81 |
+
strength = 0.0000001
|
82 |
+
pp_input = PlusPlusInput(image, control_type, strength)
|
83 |
+
prev_plus_input.add(pp_input)
|
84 |
+
|
85 |
+
return (prev_plus_input,)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_reference.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor
|
2 |
+
|
3 |
+
from nodes import VAEEncode
|
4 |
+
import comfy.utils
|
5 |
+
from comfy.sd import VAE
|
6 |
+
|
7 |
+
from .control_reference import ReferenceAdvanced, ReferenceOptions, ReferenceType, ReferencePreprocWrapper
|
8 |
+
|
9 |
+
|
10 |
+
# node for ReferenceCN
|
11 |
+
class ReferenceControlNetNode:
|
12 |
+
@classmethod
|
13 |
+
def INPUT_TYPES(s):
|
14 |
+
return {
|
15 |
+
"required": {
|
16 |
+
"reference_type": (ReferenceType._LIST,),
|
17 |
+
"style_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
18 |
+
"ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
19 |
+
},
|
20 |
+
}
|
21 |
+
|
22 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
23 |
+
FUNCTION = "load_controlnet"
|
24 |
+
|
25 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/Reference"
|
26 |
+
|
27 |
+
def load_controlnet(self, reference_type: str, style_fidelity: float, ref_weight: float):
|
28 |
+
ref_opts = ReferenceOptions.create_combo(reference_type=reference_type, style_fidelity=style_fidelity, ref_weight=ref_weight)
|
29 |
+
controlnet = ReferenceAdvanced(ref_opts=ref_opts, timestep_keyframes=None)
|
30 |
+
return (controlnet,)
|
31 |
+
|
32 |
+
|
33 |
+
class ReferenceControlFinetune:
|
34 |
+
@classmethod
|
35 |
+
def INPUT_TYPES(s):
|
36 |
+
return {
|
37 |
+
"required": {
|
38 |
+
"attn_style_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
39 |
+
"attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
40 |
+
"attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
41 |
+
"adain_style_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
42 |
+
"adain_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
43 |
+
"adain_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
44 |
+
},
|
45 |
+
}
|
46 |
+
|
47 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
48 |
+
FUNCTION = "load_controlnet"
|
49 |
+
|
50 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/Reference"
|
51 |
+
|
52 |
+
def load_controlnet(self,
|
53 |
+
attn_style_fidelity: float, attn_ref_weight: float, attn_strength: float,
|
54 |
+
adain_style_fidelity: float, adain_ref_weight: float, adain_strength: float):
|
55 |
+
ref_opts = ReferenceOptions(reference_type=ReferenceType.ATTN_ADAIN,
|
56 |
+
attn_style_fidelity=attn_style_fidelity, attn_ref_weight=attn_ref_weight, attn_strength=attn_strength,
|
57 |
+
adain_style_fidelity=adain_style_fidelity, adain_ref_weight=adain_ref_weight, adain_strength=adain_strength)
|
58 |
+
controlnet = ReferenceAdvanced(ref_opts=ref_opts, timestep_keyframes=None)
|
59 |
+
return (controlnet,)
|
60 |
+
|
61 |
+
|
62 |
+
class ReferencePreprocessorNode:
|
63 |
+
@classmethod
|
64 |
+
def INPUT_TYPES(s):
|
65 |
+
return {
|
66 |
+
"required": {
|
67 |
+
"image": ("IMAGE", ),
|
68 |
+
"vae": ("VAE", ),
|
69 |
+
"latent_size": ("LATENT", ),
|
70 |
+
}
|
71 |
+
}
|
72 |
+
|
73 |
+
RETURN_TYPES = ("IMAGE",)
|
74 |
+
RETURN_NAMES = ("proc_IMAGE",)
|
75 |
+
FUNCTION = "preprocess_images"
|
76 |
+
|
77 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/Reference/preprocess"
|
78 |
+
|
79 |
+
def preprocess_images(self, vae: VAE, image: Tensor, latent_size: Tensor):
|
80 |
+
# first, resize image to match latents
|
81 |
+
image = image.movedim(-1,1)
|
82 |
+
image = comfy.utils.common_upscale(image, latent_size["samples"].shape[3] * 8, latent_size["samples"].shape[2] * 8, 'nearest-exact', "center")
|
83 |
+
image = image.movedim(1,-1)
|
84 |
+
# then, vae encode
|
85 |
+
try:
|
86 |
+
image = vae.vae_encode_crop_pixels(image)
|
87 |
+
except Exception:
|
88 |
+
image = VAEEncode.vae_encode_crop_pixels(image)
|
89 |
+
encoded = vae.encode(image[:,:,:,:3])
|
90 |
+
return (ReferencePreprocWrapper(condhint=encoded),)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_sparsectrl.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor
|
2 |
+
|
3 |
+
import folder_paths
|
4 |
+
from nodes import VAEEncode
|
5 |
+
import comfy.utils
|
6 |
+
from comfy.sd import VAE
|
7 |
+
|
8 |
+
from .utils import TimestepKeyframeGroup
|
9 |
+
from .control_sparsectrl import SparseMethod, SparseIndexMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper, SparseConst, SparseContextAware, get_idx_list_from_str
|
10 |
+
from .control import load_sparsectrl, load_controlnet, ControlNetAdvanced, SparseCtrlAdvanced
|
11 |
+
|
12 |
+
|
13 |
+
# node for SparseCtrl loading
|
14 |
+
class SparseCtrlLoaderAdvanced:
|
15 |
+
@classmethod
|
16 |
+
def INPUT_TYPES(s):
|
17 |
+
return {
|
18 |
+
"required": {
|
19 |
+
"sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
|
20 |
+
"use_motion": ("BOOLEAN", {"default": True}, ),
|
21 |
+
"motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
22 |
+
"motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
23 |
+
},
|
24 |
+
"optional": {
|
25 |
+
"sparse_method": ("SPARSE_METHOD", ),
|
26 |
+
"tk_optional": ("TIMESTEP_KEYFRAME", ),
|
27 |
+
"context_aware": (SparseContextAware.LIST, ),
|
28 |
+
"sparse_hint_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
29 |
+
"sparse_nonhint_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
30 |
+
"sparse_mask_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
31 |
+
}
|
32 |
+
}
|
33 |
+
|
34 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
35 |
+
FUNCTION = "load_controlnet"
|
36 |
+
|
37 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
|
38 |
+
|
39 |
+
def load_controlnet(self, sparsectrl_name: str, use_motion: bool, motion_strength: float, motion_scale: float, sparse_method: SparseMethod=SparseSpreadMethod(), tk_optional: TimestepKeyframeGroup=None,
|
40 |
+
context_aware=SparseContextAware.NEAREST_HINT, sparse_hint_mult=1.0, sparse_nonhint_mult=1.0, sparse_mask_mult=1.0):
|
41 |
+
sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
|
42 |
+
sparse_settings = SparseSettings(sparse_method=sparse_method, use_motion=use_motion, motion_strength=motion_strength, motion_scale=motion_scale,
|
43 |
+
context_aware=context_aware,
|
44 |
+
sparse_mask_mult=sparse_mask_mult, sparse_hint_mult=sparse_hint_mult, sparse_nonhint_mult=sparse_nonhint_mult)
|
45 |
+
sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
|
46 |
+
return (sparsectrl,)
|
47 |
+
|
48 |
+
|
49 |
+
class SparseCtrlMergedLoaderAdvanced:
|
50 |
+
@classmethod
|
51 |
+
def INPUT_TYPES(s):
|
52 |
+
return {
|
53 |
+
"required": {
|
54 |
+
"sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
|
55 |
+
"control_net_name": (folder_paths.get_filename_list("controlnet"), ),
|
56 |
+
"use_motion": ("BOOLEAN", {"default": True}, ),
|
57 |
+
"motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
58 |
+
"motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
59 |
+
},
|
60 |
+
"optional": {
|
61 |
+
"sparse_method": ("SPARSE_METHOD", ),
|
62 |
+
"tk_optional": ("TIMESTEP_KEYFRAME", ),
|
63 |
+
}
|
64 |
+
}
|
65 |
+
|
66 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
67 |
+
FUNCTION = "load_controlnet"
|
68 |
+
|
69 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl/experimental"
|
70 |
+
|
71 |
+
def load_controlnet(self, sparsectrl_name: str, control_net_name: str, use_motion: bool, motion_strength: float, motion_scale: float, sparse_method: SparseMethod=SparseSpreadMethod(), tk_optional: TimestepKeyframeGroup=None):
|
72 |
+
sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
|
73 |
+
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
74 |
+
sparse_settings = SparseSettings(sparse_method=sparse_method, use_motion=use_motion, motion_strength=motion_strength, motion_scale=motion_scale, merged=True)
|
75 |
+
# first, load normal controlnet
|
76 |
+
controlnet = load_controlnet(controlnet_path, timestep_keyframe=tk_optional)
|
77 |
+
# confirm that controlnet is ControlNetAdvanced
|
78 |
+
if controlnet is None or type(controlnet) != ControlNetAdvanced:
|
79 |
+
raise ValueError(f"controlnet_path must point to a normal ControlNet, but instead: {type(controlnet).__name__}")
|
80 |
+
# next, load sparsectrl, making sure to load motion portion
|
81 |
+
sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=SparseSettings.default())
|
82 |
+
# now, combine state dicts
|
83 |
+
new_state_dict = controlnet.control_model.state_dict()
|
84 |
+
for key, value in sparsectrl.control_model.motion_holder.motion_wrapper.state_dict().items():
|
85 |
+
new_state_dict[key] = value
|
86 |
+
# now, reload sparsectrl with real settings
|
87 |
+
sparsectrl = load_sparsectrl(sparsectrl_path, controlnet_data=new_state_dict, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
|
88 |
+
return (sparsectrl,)
|
89 |
+
|
90 |
+
|
91 |
+
class SparseIndexMethodNode:
|
92 |
+
@classmethod
|
93 |
+
def INPUT_TYPES(s):
|
94 |
+
return {
|
95 |
+
"required": {
|
96 |
+
"indexes": ("STRING", {"default": "0"}),
|
97 |
+
}
|
98 |
+
}
|
99 |
+
|
100 |
+
RETURN_TYPES = ("SPARSE_METHOD",)
|
101 |
+
FUNCTION = "get_method"
|
102 |
+
|
103 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
|
104 |
+
|
105 |
+
def get_method(self, indexes: str):
|
106 |
+
idxs = get_idx_list_from_str(indexes)
|
107 |
+
return (SparseIndexMethod(idxs),)
|
108 |
+
|
109 |
+
|
110 |
+
class SparseSpreadMethodNode:
|
111 |
+
@classmethod
|
112 |
+
def INPUT_TYPES(s):
|
113 |
+
return {
|
114 |
+
"required": {
|
115 |
+
"spread": (SparseSpreadMethod.LIST,),
|
116 |
+
}
|
117 |
+
}
|
118 |
+
|
119 |
+
RETURN_TYPES = ("SPARSE_METHOD",)
|
120 |
+
FUNCTION = "get_method"
|
121 |
+
|
122 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
|
123 |
+
|
124 |
+
def get_method(self, spread: str):
|
125 |
+
return (SparseSpreadMethod(spread=spread),)
|
126 |
+
|
127 |
+
|
128 |
+
class RgbSparseCtrlPreprocessor:
|
129 |
+
@classmethod
|
130 |
+
def INPUT_TYPES(s):
|
131 |
+
return {
|
132 |
+
"required": {
|
133 |
+
"image": ("IMAGE", ),
|
134 |
+
"vae": ("VAE", ),
|
135 |
+
"latent_size": ("LATENT", ),
|
136 |
+
},
|
137 |
+
"optional": {
|
138 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
139 |
+
}
|
140 |
+
}
|
141 |
+
|
142 |
+
RETURN_TYPES = ("IMAGE",)
|
143 |
+
RETURN_NAMES = ("proc_IMAGE",)
|
144 |
+
FUNCTION = "preprocess_images"
|
145 |
+
|
146 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl/preprocess"
|
147 |
+
|
148 |
+
def preprocess_images(self, vae: VAE, image: Tensor, latent_size: Tensor):
|
149 |
+
# first, resize image to match latents
|
150 |
+
image = image.movedim(-1,1)
|
151 |
+
image = comfy.utils.common_upscale(image, latent_size["samples"].shape[3] * 8, latent_size["samples"].shape[2] * 8, 'nearest-exact', "center")
|
152 |
+
image = image.movedim(1,-1)
|
153 |
+
# then, vae encode
|
154 |
+
try:
|
155 |
+
image = vae.vae_encode_crop_pixels(image)
|
156 |
+
except Exception:
|
157 |
+
image = VAEEncode.vae_encode_crop_pixels(image)
|
158 |
+
encoded = vae.encode(image[:,:,:,:3])
|
159 |
+
return (PreprocSparseRGBWrapper(condhint=encoded),)
|
160 |
+
|
161 |
+
|
162 |
+
class SparseWeightExtras:
|
163 |
+
@classmethod
|
164 |
+
def INPUT_TYPES(s):
|
165 |
+
return {
|
166 |
+
"optional": {
|
167 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
168 |
+
"sparse_hint_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
169 |
+
"sparse_nonhint_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
170 |
+
"sparse_mask_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
171 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 50}),
|
172 |
+
}
|
173 |
+
}
|
174 |
+
|
175 |
+
RETURN_TYPES = ("CN_WEIGHTS_EXTRAS", )
|
176 |
+
RETURN_NAMES = ("cn_extras", )
|
177 |
+
FUNCTION = "create_weight_extras"
|
178 |
+
|
179 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl/extras"
|
180 |
+
|
181 |
+
def create_weight_extras(self, cn_extras: dict[str]={}, sparse_hint_mult=1.0, sparse_nonhint_mult=1.0, sparse_mask_mult=1.0):
|
182 |
+
cn_extras = cn_extras.copy()
|
183 |
+
cn_extras[SparseConst.HINT_MULT] = sparse_hint_mult
|
184 |
+
cn_extras[SparseConst.NONHINT_MULT] = sparse_nonhint_mult
|
185 |
+
cn_extras[SparseConst.MASK_MULT] = sparse_mask_mult
|
186 |
+
return (cn_extras, )
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_weight.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor
|
2 |
+
import torch
|
3 |
+
from .utils import TimestepKeyframe, TimestepKeyframeGroup, ControlWeights, get_properly_arranged_t2i_weights, linear_conversion
|
4 |
+
from .logger import logger
|
5 |
+
|
6 |
+
|
7 |
+
WEIGHTS_RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT")
|
8 |
+
|
9 |
+
|
10 |
+
class DefaultWeights:
|
11 |
+
@classmethod
|
12 |
+
def INPUT_TYPES(s):
|
13 |
+
return {
|
14 |
+
"optional": {
|
15 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
16 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
17 |
+
}
|
18 |
+
}
|
19 |
+
|
20 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
21 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
22 |
+
FUNCTION = "load_weights"
|
23 |
+
|
24 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights"
|
25 |
+
|
26 |
+
def load_weights(self, cn_extras: dict[str]={}):
|
27 |
+
weights = ControlWeights.default(extras=cn_extras)
|
28 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
29 |
+
|
30 |
+
|
31 |
+
class ScaledSoftMaskedUniversalWeights:
|
32 |
+
@classmethod
|
33 |
+
def INPUT_TYPES(s):
|
34 |
+
return {
|
35 |
+
"required": {
|
36 |
+
"mask": ("MASK", ),
|
37 |
+
"min_base_multiplier": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
|
38 |
+
"max_base_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
|
39 |
+
#"lock_min": ("BOOLEAN", {"default": False}, ),
|
40 |
+
#"lock_max": ("BOOLEAN", {"default": False}, ),
|
41 |
+
},
|
42 |
+
"optional": {
|
43 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
44 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
45 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
50 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
51 |
+
FUNCTION = "load_weights"
|
52 |
+
|
53 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights"
|
54 |
+
|
55 |
+
def load_weights(self, mask: Tensor, min_base_multiplier: float, max_base_multiplier: float, lock_min=False, lock_max=False,
|
56 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
57 |
+
# normalize mask
|
58 |
+
mask = mask.clone()
|
59 |
+
x_min = 0.0 if lock_min else mask.min()
|
60 |
+
x_max = 1.0 if lock_max else mask.max()
|
61 |
+
if x_min == x_max:
|
62 |
+
mask = torch.ones_like(mask) * max_base_multiplier
|
63 |
+
else:
|
64 |
+
mask = linear_conversion(mask, x_min, x_max, min_base_multiplier, max_base_multiplier)
|
65 |
+
weights = ControlWeights.universal_mask(weight_mask=mask, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
66 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
67 |
+
|
68 |
+
|
69 |
+
class ScaledSoftUniversalWeights:
|
70 |
+
@classmethod
|
71 |
+
def INPUT_TYPES(s):
|
72 |
+
return {
|
73 |
+
"required": {
|
74 |
+
"base_multiplier": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 1.0, "step": 0.001}, ),
|
75 |
+
},
|
76 |
+
"optional": {
|
77 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
78 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
79 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
80 |
+
}
|
81 |
+
}
|
82 |
+
|
83 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
84 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
85 |
+
FUNCTION = "load_weights"
|
86 |
+
|
87 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights"
|
88 |
+
|
89 |
+
def load_weights(self, base_multiplier, uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
90 |
+
weights = ControlWeights.universal(base_multiplier=base_multiplier, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
91 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
92 |
+
|
93 |
+
|
94 |
+
class SoftControlNetWeightsSD15:
|
95 |
+
@classmethod
|
96 |
+
def INPUT_TYPES(s):
|
97 |
+
return {
|
98 |
+
"required": {
|
99 |
+
"output_0": ("FLOAT", {"default": 0.09941396206337118, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
100 |
+
"output_1": ("FLOAT", {"default": 0.12050177219802567, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
101 |
+
"output_2": ("FLOAT", {"default": 0.14606275417942507, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
102 |
+
"output_3": ("FLOAT", {"default": 0.17704576264172736, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
103 |
+
"output_4": ("FLOAT", {"default": 0.214600924414215, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
104 |
+
"output_5": ("FLOAT", {"default": 0.26012233262329093, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
105 |
+
"output_6": ("FLOAT", {"default": 0.3152997971191405, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
106 |
+
"output_7": ("FLOAT", {"default": 0.3821815722656249, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
107 |
+
"output_8": ("FLOAT", {"default": 0.4632503906249999, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
108 |
+
"output_9": ("FLOAT", {"default": 0.561515625, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
109 |
+
"output_10": ("FLOAT", {"default": 0.6806249999999999, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
110 |
+
"output_11": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
111 |
+
"middle_0": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
112 |
+
},
|
113 |
+
"optional": {
|
114 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
115 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
116 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
117 |
+
}
|
118 |
+
}
|
119 |
+
|
120 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
121 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
122 |
+
FUNCTION = "load_weights"
|
123 |
+
|
124 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/ControlNet"
|
125 |
+
|
126 |
+
def load_weights(self, output_0, output_1, output_2, output_3, output_4, output_5, output_6,
|
127 |
+
output_7, output_8, output_9, output_10, output_11, middle_0,
|
128 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
129 |
+
return CustomControlNetWeightsSD15.load_weights(self,
|
130 |
+
output_0=output_0, output_1=output_1, output_2=output_2, output_3=output_3,
|
131 |
+
output_4=output_4, output_5=output_5, output_6=output_6, output_7=output_7,
|
132 |
+
output_8=output_8, output_9=output_9, output_10=output_10, output_11=output_11,
|
133 |
+
middle_0=middle_0,
|
134 |
+
uncond_multiplier=uncond_multiplier, cn_extras=cn_extras)
|
135 |
+
|
136 |
+
|
137 |
+
class CustomControlNetWeightsSD15:
|
138 |
+
@classmethod
|
139 |
+
def INPUT_TYPES(s):
|
140 |
+
return {
|
141 |
+
"required": {
|
142 |
+
"output_0": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
143 |
+
"output_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
144 |
+
"output_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
145 |
+
"output_3": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
146 |
+
"output_4": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
147 |
+
"output_5": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
148 |
+
"output_6": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
149 |
+
"output_7": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
150 |
+
"output_8": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
151 |
+
"output_9": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
152 |
+
"output_10": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
153 |
+
"output_11": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
154 |
+
"middle_0": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
155 |
+
},
|
156 |
+
"optional": {
|
157 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
158 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
159 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
160 |
+
}
|
161 |
+
}
|
162 |
+
|
163 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
164 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
165 |
+
FUNCTION = "load_weights"
|
166 |
+
|
167 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/ControlNet"
|
168 |
+
|
169 |
+
def load_weights(self, output_0, output_1, output_2, output_3, output_4, output_5, output_6,
|
170 |
+
output_7, output_8, output_9, output_10, output_11, middle_0,
|
171 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
172 |
+
weights_output = [output_0, output_1, output_2, output_3, output_4, output_5, output_6,
|
173 |
+
output_7, output_8, output_9, output_10, output_11]
|
174 |
+
weights_middle = [middle_0]
|
175 |
+
weights = ControlWeights.controlnet(weights_output=weights_output, weights_middle=weights_middle, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
176 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
177 |
+
|
178 |
+
|
179 |
+
class CustomControlNetWeightsFlux:
|
180 |
+
@classmethod
|
181 |
+
def INPUT_TYPES(s):
|
182 |
+
return {
|
183 |
+
"required": {
|
184 |
+
"input_0": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
185 |
+
"input_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
186 |
+
"input_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
187 |
+
"input_3": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
188 |
+
"input_4": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
189 |
+
"input_5": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
190 |
+
"input_6": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
191 |
+
"input_7": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
192 |
+
"input_8": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
193 |
+
"input_9": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
194 |
+
"input_10": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
195 |
+
"input_11": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
196 |
+
"input_12": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
197 |
+
"input_13": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
198 |
+
"input_14": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
199 |
+
"input_15": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
200 |
+
"input_16": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
201 |
+
"input_17": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
202 |
+
"input_18": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
203 |
+
},
|
204 |
+
"optional": {
|
205 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
206 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
207 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
208 |
+
}
|
209 |
+
}
|
210 |
+
|
211 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
212 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
213 |
+
FUNCTION = "load_weights"
|
214 |
+
|
215 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/ControlNet"
|
216 |
+
|
217 |
+
def load_weights(self, input_0, input_1, input_2, input_3, input_4, input_5, input_6,
|
218 |
+
input_7, input_8, input_9, input_10, input_11, input_12, input_13,
|
219 |
+
input_14, input_15, input_16, input_17, input_18,
|
220 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
221 |
+
weights_input = [input_0, input_1, input_2, input_3, input_4, input_5,
|
222 |
+
input_6, input_7, input_8, input_9, input_10, input_11,
|
223 |
+
input_12, input_13, input_14, input_15, input_16, input_17, input_18]
|
224 |
+
weights = ControlWeights.controlnet(weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
225 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
226 |
+
|
227 |
+
|
228 |
+
class SoftT2IAdapterWeights:
|
229 |
+
@classmethod
|
230 |
+
def INPUT_TYPES(s):
|
231 |
+
return {
|
232 |
+
"required": {
|
233 |
+
"input_0": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
234 |
+
"input_1": ("FLOAT", {"default": 0.62, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
235 |
+
"input_2": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
236 |
+
"input_3": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
237 |
+
},
|
238 |
+
"optional": {
|
239 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
240 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
241 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
242 |
+
}
|
243 |
+
}
|
244 |
+
|
245 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
246 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
247 |
+
FUNCTION = "load_weights"
|
248 |
+
|
249 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/T2IAdapter"
|
250 |
+
|
251 |
+
def load_weights(self, input_0, input_1, input_2, input_3,
|
252 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
253 |
+
return CustomT2IAdapterWeights.load_weights(self, input_0=input_0, input_1=input_1, input_2=input_2, input_3=input_3,
|
254 |
+
uncond_multiplier=uncond_multiplier, cn_extras=cn_extras)
|
255 |
+
|
256 |
+
|
257 |
+
class CustomT2IAdapterWeights:
|
258 |
+
@classmethod
|
259 |
+
def INPUT_TYPES(s):
|
260 |
+
return {
|
261 |
+
"required": {
|
262 |
+
"input_0": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
263 |
+
"input_1": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
264 |
+
"input_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
265 |
+
"input_3": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
266 |
+
},
|
267 |
+
"optional": {
|
268 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
269 |
+
"cn_extras": ("CN_WEIGHTS_EXTRAS",),
|
270 |
+
"autosize": ("ACNAUTOSIZE", {"padding": 0}),
|
271 |
+
}
|
272 |
+
}
|
273 |
+
|
274 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
275 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
276 |
+
FUNCTION = "load_weights"
|
277 |
+
|
278 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/T2IAdapter"
|
279 |
+
|
280 |
+
def load_weights(self, input_0, input_1, input_2, input_3,
|
281 |
+
uncond_multiplier: float=1.0, cn_extras: dict[str]={}):
|
282 |
+
weights = [input_0, input_1, input_2, input_3]
|
283 |
+
weights = get_properly_arranged_t2i_weights(weights)
|
284 |
+
weights = ControlWeights.t2iadapter(weights_input=weights, uncond_multiplier=uncond_multiplier, extras=cn_extras)
|
285 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
ComfyUI-Advanced-ControlNet/adv_control/sampling.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Union
|
2 |
+
|
3 |
+
import comfy.sample
|
4 |
+
from comfy.model_patcher import ModelPatcher
|
5 |
+
from comfy.controlnet import ControlBase
|
6 |
+
from comfy.ldm.modules.attention import BasicTransformerBlock
|
7 |
+
|
8 |
+
|
9 |
+
from .control import convert_all_to_advanced, restore_all_controlnet_conns
|
10 |
+
from .control_reference import (ReferenceAdvanced, ReferenceInjections,
|
11 |
+
RefBasicTransformerBlock, RefTimestepEmbedSequential,
|
12 |
+
InjectionBasicTransformerBlockHolder, InjectionTimestepEmbedSequentialHolder,
|
13 |
+
_forward_inject_BasicTransformerBlock, factory_forward_inject_UNetModel,
|
14 |
+
handle_context_ref_setup,
|
15 |
+
REF_CONTROL_LIST_ALL, CONTEXTREF_CLEAN_FUNC)
|
16 |
+
from .control_lllite import (ControlLLLiteAdvanced)
|
17 |
+
from .utils import torch_dfs
|
18 |
+
|
19 |
+
|
20 |
+
def support_sliding_context_windows(model, positive, negative) -> tuple[bool, dict, dict]:
|
21 |
+
# convert to advanced, with report if anything was actually modified
|
22 |
+
modified, new_conds = convert_all_to_advanced([positive, negative])
|
23 |
+
positive, negative = new_conds
|
24 |
+
return modified, positive, negative
|
25 |
+
|
26 |
+
|
27 |
+
def has_sliding_context_windows(model):
|
28 |
+
motion_injection_params = getattr(model, "motion_injection_params", None)
|
29 |
+
if motion_injection_params is None:
|
30 |
+
return False
|
31 |
+
context_options = getattr(motion_injection_params, "context_options")
|
32 |
+
return context_options.context_length is not None
|
33 |
+
|
34 |
+
|
35 |
+
def get_contextref_obj(model):
|
36 |
+
motion_injection_params = getattr(model, "motion_injection_params", None)
|
37 |
+
if motion_injection_params is None:
|
38 |
+
return None
|
39 |
+
context_options = getattr(motion_injection_params, "context_options")
|
40 |
+
extras = getattr(context_options, "extras", None)
|
41 |
+
if extras is None:
|
42 |
+
return None
|
43 |
+
return getattr(extras, "context_ref", None)
|
44 |
+
|
45 |
+
|
46 |
+
def acn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
|
47 |
+
def get_refcn(control: ControlBase, order: int=-1):
|
48 |
+
ref_set: set[ReferenceAdvanced] = set()
|
49 |
+
if control is None:
|
50 |
+
return ref_set
|
51 |
+
if type(control) == ReferenceAdvanced and not control.is_context_ref:
|
52 |
+
control.order = order
|
53 |
+
order -= 1
|
54 |
+
ref_set.add(control)
|
55 |
+
ref_set.update(get_refcn(control.previous_controlnet, order=order))
|
56 |
+
return ref_set
|
57 |
+
|
58 |
+
def get_lllitecn(control: ControlBase):
|
59 |
+
cn_dict: dict[ControlLLLiteAdvanced,None] = {}
|
60 |
+
if control is None:
|
61 |
+
return cn_dict
|
62 |
+
if type(control) == ControlLLLiteAdvanced:
|
63 |
+
cn_dict[control] = None
|
64 |
+
cn_dict.update(get_lllitecn(control.previous_controlnet))
|
65 |
+
return cn_dict
|
66 |
+
|
67 |
+
def acn_sample(model: ModelPatcher, *args, **kwargs):
|
68 |
+
controlnets_modified = False
|
69 |
+
orig_positive = args[-3]
|
70 |
+
orig_negative = args[-2]
|
71 |
+
try:
|
72 |
+
orig_model_options = model.model_options
|
73 |
+
# check if positive or negative conds contain ref cn
|
74 |
+
positive = args[-3]
|
75 |
+
negative = args[-2]
|
76 |
+
# if context options present, perform some special actions that may be required
|
77 |
+
context_refs = []
|
78 |
+
if has_sliding_context_windows(model):
|
79 |
+
model.model_options = model.model_options.copy()
|
80 |
+
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
81 |
+
# convert all CNs to Advanced if needed
|
82 |
+
controlnets_modified, positive, negative = support_sliding_context_windows(model, positive, negative)
|
83 |
+
if controlnets_modified:
|
84 |
+
args = list(args)
|
85 |
+
args[-3] = positive
|
86 |
+
args[-2] = negative
|
87 |
+
args = tuple(args)
|
88 |
+
# enable ContextRef, if requested
|
89 |
+
existing_contextref_obj = get_contextref_obj(model)
|
90 |
+
if existing_contextref_obj is not None:
|
91 |
+
context_refs = handle_context_ref_setup(existing_contextref_obj, model.model_options["transformer_options"], positive, negative)
|
92 |
+
controlnets_modified = True
|
93 |
+
# look for Advanced ControlNets that will require intervention to work
|
94 |
+
ref_set = set()
|
95 |
+
lllite_dict: dict[ControlLLLiteAdvanced, None] = {} # dicts preserve insertion order since py3.7
|
96 |
+
if positive is not None:
|
97 |
+
for cond in positive:
|
98 |
+
if "control" in cond[1]:
|
99 |
+
ref_set.update(get_refcn(cond[1]["control"]))
|
100 |
+
lllite_dict.update(get_lllitecn(cond[1]["control"]))
|
101 |
+
if negative is not None:
|
102 |
+
for cond in negative:
|
103 |
+
if "control" in cond[1]:
|
104 |
+
ref_set.update(get_refcn(cond[1]["control"]))
|
105 |
+
lllite_dict.update(get_lllitecn(cond[1]["control"]))
|
106 |
+
# if lllite found, apply patches to a cloned model_options, and continue
|
107 |
+
if len(lllite_dict) > 0:
|
108 |
+
lllite_list = list(lllite_dict.keys())
|
109 |
+
model.model_options = model.model_options.copy()
|
110 |
+
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
111 |
+
lllite_list.reverse() # reverse so that patches will be applied in expected order
|
112 |
+
for lll in lllite_list:
|
113 |
+
lll.live_model_patches(model.model_options)
|
114 |
+
# if no ref cn found, do original function immediately
|
115 |
+
if len(ref_set) == 0 and len(context_refs) == 0:
|
116 |
+
return orig_comfy_sample(model, *args, **kwargs)
|
117 |
+
# otherwise, injection time
|
118 |
+
try:
|
119 |
+
# inject
|
120 |
+
# storage for all Reference-related injections
|
121 |
+
reference_injections = ReferenceInjections()
|
122 |
+
|
123 |
+
# first, handle attn module injection
|
124 |
+
all_modules = torch_dfs(model.model)
|
125 |
+
attn_modules: list[RefBasicTransformerBlock] = []
|
126 |
+
for module in all_modules:
|
127 |
+
if isinstance(module, BasicTransformerBlock):
|
128 |
+
attn_modules.append(module)
|
129 |
+
attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
|
130 |
+
attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
|
131 |
+
for i, module in enumerate(attn_modules):
|
132 |
+
injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i)
|
133 |
+
injection_holder.attn_weight = float(i) / float(len(attn_modules))
|
134 |
+
if hasattr(module, "_forward"): # backward compatibility
|
135 |
+
module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
|
136 |
+
else:
|
137 |
+
module.forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
|
138 |
+
module.injection_holder = injection_holder
|
139 |
+
reference_injections.attn_modules.append(module)
|
140 |
+
# figure out which module is middle block
|
141 |
+
if hasattr(model.model.diffusion_model, "middle_block"):
|
142 |
+
mid_modules = torch_dfs(model.model.diffusion_model.middle_block)
|
143 |
+
mid_attn_modules: list[RefBasicTransformerBlock] = [module for module in mid_modules if isinstance(module, BasicTransformerBlock)]
|
144 |
+
for module in mid_attn_modules:
|
145 |
+
module.injection_holder.is_middle = True
|
146 |
+
|
147 |
+
# next, handle gn module injection (TimestepEmbedSequential)
|
148 |
+
# TODO: figure out the logic behind these hardcoded indexes
|
149 |
+
if type(model.model).__name__ == "SDXL":
|
150 |
+
input_block_indices = [4, 5, 7, 8]
|
151 |
+
output_block_indices = [0, 1, 2, 3, 4, 5]
|
152 |
+
else:
|
153 |
+
input_block_indices = [4, 5, 7, 8, 10, 11]
|
154 |
+
output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
|
155 |
+
if hasattr(model.model.diffusion_model, "middle_block"):
|
156 |
+
module = model.model.diffusion_model.middle_block
|
157 |
+
injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=0, is_middle=True)
|
158 |
+
injection_holder.gn_weight = 0.0
|
159 |
+
module.injection_holder = injection_holder
|
160 |
+
reference_injections.gn_modules.append(module)
|
161 |
+
for w, i in enumerate(input_block_indices):
|
162 |
+
module = model.model.diffusion_model.input_blocks[i]
|
163 |
+
injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_input=True)
|
164 |
+
injection_holder.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
|
165 |
+
module.injection_holder = injection_holder
|
166 |
+
reference_injections.gn_modules.append(module)
|
167 |
+
for w, i in enumerate(output_block_indices):
|
168 |
+
module = model.model.diffusion_model.output_blocks[i]
|
169 |
+
injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_output=True)
|
170 |
+
injection_holder.gn_weight = float(w) / float(len(output_block_indices))
|
171 |
+
module.injection_holder = injection_holder
|
172 |
+
reference_injections.gn_modules.append(module)
|
173 |
+
# hack gn_module forwards and update weights
|
174 |
+
for i, module in enumerate(reference_injections.gn_modules):
|
175 |
+
module.injection_holder.gn_weight *= 2
|
176 |
+
|
177 |
+
# handle diffusion_model forward injection
|
178 |
+
reference_injections.diffusion_model_orig_forward = model.model.diffusion_model.forward
|
179 |
+
model.model.diffusion_model.forward = factory_forward_inject_UNetModel(reference_injections).__get__(model.model.diffusion_model, type(model.model.diffusion_model))
|
180 |
+
# store ordered ref cns in model's transformer options
|
181 |
+
new_model_options = model.model_options.copy()
|
182 |
+
new_model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
183 |
+
ref_list: list[ReferenceAdvanced] = list(ref_set)
|
184 |
+
new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_list, key=lambda x: x.order)
|
185 |
+
new_model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC] = reference_injections.clean_contextref_module_mem
|
186 |
+
model.model_options = new_model_options
|
187 |
+
# continue with original function
|
188 |
+
return orig_comfy_sample(model, *args, **kwargs)
|
189 |
+
finally:
|
190 |
+
# cleanup injections
|
191 |
+
# restore attn modules
|
192 |
+
attn_modules: list[RefBasicTransformerBlock] = reference_injections.attn_modules
|
193 |
+
for module in attn_modules:
|
194 |
+
module.injection_holder.restore(module)
|
195 |
+
module.injection_holder.clean_all()
|
196 |
+
del module.injection_holder
|
197 |
+
del attn_modules
|
198 |
+
# restore gn modules
|
199 |
+
gn_modules: list[RefTimestepEmbedSequential] = reference_injections.gn_modules
|
200 |
+
for module in gn_modules:
|
201 |
+
module.injection_holder.restore(module)
|
202 |
+
module.injection_holder.clean_all()
|
203 |
+
del module.injection_holder
|
204 |
+
del gn_modules
|
205 |
+
# restore diffusion_model forward function
|
206 |
+
model.model.diffusion_model.forward = reference_injections.diffusion_model_orig_forward.__get__(model.model.diffusion_model, type(model.model.diffusion_model))
|
207 |
+
# cleanup
|
208 |
+
reference_injections.cleanup()
|
209 |
+
finally:
|
210 |
+
# restore model_options
|
211 |
+
model.model_options = orig_model_options
|
212 |
+
# restore controlnets in conds, if needed
|
213 |
+
if controlnets_modified:
|
214 |
+
restore_all_controlnet_conns([orig_positive, orig_negative])
|
215 |
+
|
216 |
+
return acn_sample
|
ComfyUI-Advanced-ControlNet/adv_control/utils.py
ADDED
@@ -0,0 +1,981 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import Callable, Union
|
3 |
+
import torch
|
4 |
+
from torch import Tensor
|
5 |
+
import torch.nn.functional
|
6 |
+
from einops import rearrange
|
7 |
+
import numpy as np
|
8 |
+
import math
|
9 |
+
|
10 |
+
import comfy.ops
|
11 |
+
import comfy.utils
|
12 |
+
import comfy.sample
|
13 |
+
import comfy.samplers
|
14 |
+
import comfy.model_base
|
15 |
+
|
16 |
+
from comfy.controlnet import ControlBase
|
17 |
+
from comfy.model_patcher import ModelPatcher
|
18 |
+
from comfy.sd import VAE
|
19 |
+
|
20 |
+
from .logger import logger
|
21 |
+
|
22 |
+
BIGMIN = -(2**53-1)
|
23 |
+
BIGMAX = (2**53-1)
|
24 |
+
|
25 |
+
ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"
|
26 |
+
CONTROL_INIT_BY_ACN = "_control_init_by_ACN"
|
27 |
+
|
28 |
+
|
29 |
+
def load_torch_file_with_dict_factory(controlnet_data: dict[str, Tensor], orig_load_torch_file: Callable):
|
30 |
+
def load_torch_file_with_dict(*args, **kwargs):
|
31 |
+
# immediately restore load_torch_file to original version
|
32 |
+
comfy.utils.load_torch_file = orig_load_torch_file
|
33 |
+
return controlnet_data
|
34 |
+
return load_torch_file_with_dict
|
35 |
+
|
36 |
+
# wrapping len function so that it will save the thing len is trying to get the length of;
|
37 |
+
# this will be assumed to be the cond_or_uncond variable;
|
38 |
+
# automatically restores len to original function after running
|
39 |
+
def wrapper_len_factory(orig_len: Callable) -> Callable:
|
40 |
+
def wrapper_len(*args, **kwargs):
|
41 |
+
cond_or_uncond = args[0]
|
42 |
+
real_length = orig_len(*args, **kwargs)
|
43 |
+
if real_length > 0 and type(cond_or_uncond) == list and isinstance(cond_or_uncond[0], int) and (cond_or_uncond[0] in [0, 1]):
|
44 |
+
try:
|
45 |
+
to_return = IntWithCondOrUncond(real_length)
|
46 |
+
setattr(to_return, "cond_or_uncond", cond_or_uncond)
|
47 |
+
return to_return
|
48 |
+
finally:
|
49 |
+
__builtins__["len"] = orig_len
|
50 |
+
else:
|
51 |
+
return real_length
|
52 |
+
return wrapper_len
|
53 |
+
|
54 |
+
# wrapping cond_cat function so that it will wrap around len function to get cond_or_uncond variable value
|
55 |
+
# from comfy.samplers.calc_conds_batch
|
56 |
+
def wrapper_cond_cat_factory(orig_cond_cat: Callable):
|
57 |
+
def wrapper_cond_cat(*args, **kwargs):
|
58 |
+
__builtins__["len"] = wrapper_len_factory(__builtins__["len"])
|
59 |
+
return orig_cond_cat(*args, **kwargs)
|
60 |
+
return wrapper_cond_cat
|
61 |
+
orig_cond_cat = comfy.samplers.cond_cat
|
62 |
+
comfy.samplers.cond_cat = wrapper_cond_cat_factory(orig_cond_cat)
|
63 |
+
|
64 |
+
|
65 |
+
# wrapping apply_model so that len function will be cleaned up fairly soon after being injected
|
66 |
+
def apply_model_uncond_cleanup_factory(orig_apply_model, orig_len):
|
67 |
+
def apply_model_uncond_cleanup_wrapper(self, *args, **kwargs):
|
68 |
+
__builtins__["len"] = orig_len
|
69 |
+
return orig_apply_model(self, *args, **kwargs)
|
70 |
+
return apply_model_uncond_cleanup_wrapper
|
71 |
+
global_orig_len = __builtins__["len"]
|
72 |
+
orig_apply_model = comfy.model_base.BaseModel.apply_model
|
73 |
+
comfy.model_base.BaseModel.apply_model = apply_model_uncond_cleanup_factory(orig_apply_model, global_orig_len)
|
74 |
+
|
75 |
+
|
76 |
+
def uncond_multiplier_check_cn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
|
77 |
+
def contains_uncond_multiplier(control: Union[ControlBase, 'AdvancedControlBase']):
|
78 |
+
if control is None:
|
79 |
+
return False
|
80 |
+
if not isinstance(control, AdvancedControlBase):
|
81 |
+
return contains_uncond_multiplier(control.previous_controlnet)
|
82 |
+
# check if weights_override has an uncond_multiplier
|
83 |
+
if control.weights_override is not None and control.weights_override.has_uncond_multiplier:
|
84 |
+
return True
|
85 |
+
# check if any timestep_keyframes have an uncond_multiplier on their weights
|
86 |
+
if control.timestep_keyframes is not None:
|
87 |
+
for tk in control.timestep_keyframes.keyframes:
|
88 |
+
if tk.has_control_weights() and tk.control_weights.has_uncond_multiplier:
|
89 |
+
return True
|
90 |
+
return contains_uncond_multiplier(control.previous_controlnet)
|
91 |
+
|
92 |
+
# check if positive or negative conds contain Adv. Cns that use multiply_negative on weights
|
93 |
+
def uncond_multiplier_check_cn_sample(model: ModelPatcher, *args, **kwargs):
|
94 |
+
positive = args[-3]
|
95 |
+
negative = args[-2]
|
96 |
+
has_uncond_multiplier = False
|
97 |
+
if positive is not None:
|
98 |
+
for cond in positive:
|
99 |
+
if "control" in cond[1]:
|
100 |
+
has_uncond_multiplier = contains_uncond_multiplier(cond[1]["control"])
|
101 |
+
if has_uncond_multiplier:
|
102 |
+
break
|
103 |
+
if negative is not None and not has_uncond_multiplier:
|
104 |
+
for cond in negative:
|
105 |
+
if "control" in cond[1]:
|
106 |
+
has_uncond_multiplier = contains_uncond_multiplier(cond[1]["control"])
|
107 |
+
if has_uncond_multiplier:
|
108 |
+
break
|
109 |
+
try:
|
110 |
+
# if uncond_multiplier found, continue to use wrapped version of function
|
111 |
+
if has_uncond_multiplier:
|
112 |
+
return orig_comfy_sample(model, *args, **kwargs)
|
113 |
+
# otherwise, use original version of function to prevent even the smallest of slowdowns (0.XX%)
|
114 |
+
try:
|
115 |
+
wrapped_cond_cat = comfy.samplers.cond_cat
|
116 |
+
comfy.samplers.cond_cat = orig_cond_cat
|
117 |
+
return orig_comfy_sample(model, *args, **kwargs)
|
118 |
+
finally:
|
119 |
+
comfy.samplers.cond_cat = wrapped_cond_cat
|
120 |
+
finally:
|
121 |
+
# make sure len function is unwrapped by the time sampling is done, just in case
|
122 |
+
__builtins__["len"] = global_orig_len
|
123 |
+
return uncond_multiplier_check_cn_sample
|
124 |
+
# inject sample functions
|
125 |
+
comfy.sample.sample = uncond_multiplier_check_cn_sample_factory(comfy.sample.sample)
|
126 |
+
comfy.sample.sample_custom = uncond_multiplier_check_cn_sample_factory(comfy.sample.sample_custom, is_custom=True)
|
127 |
+
|
128 |
+
|
129 |
+
class IntWithCondOrUncond(int):
|
130 |
+
def __new__(cls, *args, **kwargs):
|
131 |
+
return super(IntWithCondOrUncond, cls).__new__(cls, *args, **kwargs)
|
132 |
+
|
133 |
+
def __init__(self, *args, **kwargs):
|
134 |
+
super().__init__()
|
135 |
+
self.cond_or_uncond = None
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
def get_properly_arranged_t2i_weights(initial_weights: list[float]):
|
140 |
+
new_weights = []
|
141 |
+
new_weights.extend([initial_weights[0]]*3)
|
142 |
+
new_weights.extend([initial_weights[1]]*3)
|
143 |
+
new_weights.extend([initial_weights[2]]*3)
|
144 |
+
new_weights.extend([initial_weights[3]]*3)
|
145 |
+
return new_weights
|
146 |
+
|
147 |
+
|
148 |
+
class ControlWeightType:
|
149 |
+
DEFAULT = "default"
|
150 |
+
UNIVERSAL = "universal"
|
151 |
+
T2IADAPTER = "t2iadapter"
|
152 |
+
CONTROLNET = "controlnet"
|
153 |
+
CONTROLNETPLUSPLUS = "controlnet++"
|
154 |
+
CONTROLLORA = "controllora"
|
155 |
+
CONTROLLLLITE = "controllllite"
|
156 |
+
SVD_CONTROLNET = "svd_controlnet"
|
157 |
+
SPARSECTRL = "sparsectrl"
|
158 |
+
|
159 |
+
|
160 |
+
class ControlWeights:
|
161 |
+
def __init__(self, weight_type: str, base_multiplier: float=1.0,
|
162 |
+
weights_input: list[float]=None, weights_middle: list[float]=None, weights_output: list[float]=None,
|
163 |
+
weight_func: Callable=None, weight_mask: Tensor=None,
|
164 |
+
uncond_multiplier=1.0, uncond_mask: Tensor=None, extras: dict[str]={},):
|
165 |
+
self.weight_type = weight_type
|
166 |
+
self.base_multiplier = base_multiplier
|
167 |
+
self.weights_input = weights_input
|
168 |
+
self.weights_middle = weights_middle
|
169 |
+
self.weights_output = weights_output
|
170 |
+
self.weight_func = weight_func
|
171 |
+
self.weight_mask = weight_mask
|
172 |
+
self.uncond_multiplier = float(uncond_multiplier)
|
173 |
+
self.has_uncond_multiplier = not math.isclose(self.uncond_multiplier, 1.0)
|
174 |
+
self.uncond_mask = uncond_mask if uncond_mask is not None else 1.0
|
175 |
+
self.has_uncond_mask = uncond_mask is not None
|
176 |
+
self.extras = extras
|
177 |
+
|
178 |
+
def get(self, idx: int, control: dict[str, list[Tensor]], key: str, default=1.0) -> Union[float, Tensor]:
|
179 |
+
# if weight_func present, use it
|
180 |
+
if self.weight_func is not None:
|
181 |
+
return self.weight_func(idx=idx, control=control, key=key)
|
182 |
+
# if weights is not none, return index
|
183 |
+
relevant_weights = None
|
184 |
+
if key == "middle":
|
185 |
+
relevant_weights = self.weights_middle
|
186 |
+
elif key == "input":
|
187 |
+
relevant_weights = self.weights_input
|
188 |
+
if relevant_weights is not None:
|
189 |
+
relevant_weights = list(reversed(relevant_weights))
|
190 |
+
else:
|
191 |
+
relevant_weights = self.weights_output
|
192 |
+
if relevant_weights is None:
|
193 |
+
return default
|
194 |
+
elif idx >= len(relevant_weights):
|
195 |
+
return default
|
196 |
+
return relevant_weights[idx]
|
197 |
+
|
198 |
+
def copy_with_new_weights(self, new_weights_input: list[float]=None, new_weights_middle: list[float]=None, new_weights_output: list[float]=None,
|
199 |
+
new_weight_func: Callable=None):
|
200 |
+
return ControlWeights(weight_type=self.weight_type, base_multiplier=self.base_multiplier,
|
201 |
+
weights_input=new_weights_input, weights_middle=new_weights_middle, weights_output=new_weights_output,
|
202 |
+
weight_func=new_weight_func, weight_mask=self.weight_mask,
|
203 |
+
uncond_multiplier=self.uncond_multiplier, extras=self.extras)
|
204 |
+
|
205 |
+
@classmethod
|
206 |
+
def default(cls, extras: dict[str]={}):
|
207 |
+
return cls(ControlWeightType.DEFAULT, extras=extras)
|
208 |
+
|
209 |
+
@classmethod
|
210 |
+
def universal(cls, base_multiplier: float, uncond_multiplier: float=1.0, extras: dict[str]={}):
|
211 |
+
return cls(ControlWeightType.UNIVERSAL, base_multiplier=base_multiplier, uncond_multiplier=uncond_multiplier, extras=extras)
|
212 |
+
|
213 |
+
@classmethod
|
214 |
+
def universal_mask(cls, weight_mask: Tensor, uncond_multiplier: float=1.0, extras: dict[str]={}):
|
215 |
+
return cls(ControlWeightType.UNIVERSAL, weight_mask=weight_mask, uncond_multiplier=uncond_multiplier, extras=extras)
|
216 |
+
|
217 |
+
@classmethod
|
218 |
+
def t2iadapter(cls, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}):
|
219 |
+
return cls(ControlWeightType.T2IADAPTER, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras)
|
220 |
+
|
221 |
+
@classmethod
|
222 |
+
def controlnet(cls, weights_output: list[float]=None, weights_middle: list[float]=None, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}):
|
223 |
+
return cls(ControlWeightType.CONTROLNET, weights_output=weights_output, weights_middle=weights_middle, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras)
|
224 |
+
|
225 |
+
@classmethod
|
226 |
+
def controllora(cls, weights_output: list[float]=None, weights_middle: list[float]=None, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}):
|
227 |
+
return cls(ControlWeightType.CONTROLLORA, weights_output=weights_output, weights_middle=weights_middle, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras)
|
228 |
+
|
229 |
+
@classmethod
|
230 |
+
def controllllite(cls, weights_output: list[float]=None, weights_middle: list[float]=None, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}):
|
231 |
+
return cls(ControlWeightType.CONTROLLLLITE, weights_output=weights_output, weights_middle=weights_middle, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras)
|
232 |
+
|
233 |
+
|
234 |
+
class StrengthInterpolation:
|
235 |
+
LINEAR = "linear"
|
236 |
+
EASE_IN = "ease-in"
|
237 |
+
EASE_OUT = "ease-out"
|
238 |
+
EASE_IN_OUT = "ease-in-out"
|
239 |
+
NONE = "none"
|
240 |
+
|
241 |
+
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
|
242 |
+
_LIST_WITH_NONE = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT, NONE]
|
243 |
+
|
244 |
+
@classmethod
|
245 |
+
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
|
246 |
+
diff = num_to - num_from
|
247 |
+
if method == cls.LINEAR:
|
248 |
+
weights = torch.linspace(num_from, num_to, length)
|
249 |
+
elif method == cls.EASE_IN:
|
250 |
+
index = torch.linspace(0, 1, length)
|
251 |
+
weights = diff * np.power(index, 2) + num_from
|
252 |
+
elif method == cls.EASE_OUT:
|
253 |
+
index = torch.linspace(0, 1, length)
|
254 |
+
weights = diff * (1 - np.power(1 - index, 2)) + num_from
|
255 |
+
elif method == cls.EASE_IN_OUT:
|
256 |
+
index = torch.linspace(0, 1, length)
|
257 |
+
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
|
258 |
+
else:
|
259 |
+
raise ValueError(f"Unrecognized interpolation method '{method}'.")
|
260 |
+
if reverse:
|
261 |
+
weights = weights.flip(dims=(0,))
|
262 |
+
return weights
|
263 |
+
|
264 |
+
|
265 |
+
class LatentKeyframe:
|
266 |
+
def __init__(self, batch_index: int, strength: float) -> None:
|
267 |
+
self.batch_index = batch_index
|
268 |
+
self.strength = strength
|
269 |
+
|
270 |
+
|
271 |
+
# always maintain sorted state (by batch_index of LatentKeyframe)
|
272 |
+
class LatentKeyframeGroup:
|
273 |
+
def __init__(self) -> None:
|
274 |
+
self.keyframes: list[LatentKeyframe] = []
|
275 |
+
|
276 |
+
def add(self, keyframe: LatentKeyframe) -> None:
|
277 |
+
added = False
|
278 |
+
# replace existing keyframe if same batch_index
|
279 |
+
for i in range(len(self.keyframes)):
|
280 |
+
if self.keyframes[i].batch_index == keyframe.batch_index:
|
281 |
+
self.keyframes[i] = keyframe
|
282 |
+
added = True
|
283 |
+
break
|
284 |
+
if not added:
|
285 |
+
self.keyframes.append(keyframe)
|
286 |
+
self.keyframes.sort(key=lambda k: k.batch_index)
|
287 |
+
|
288 |
+
def get_index(self, index: int) -> Union[LatentKeyframe, None]:
|
289 |
+
try:
|
290 |
+
return self.keyframes[index]
|
291 |
+
except IndexError:
|
292 |
+
return None
|
293 |
+
|
294 |
+
def __getitem__(self, index) -> LatentKeyframe:
|
295 |
+
return self.keyframes[index]
|
296 |
+
|
297 |
+
def is_empty(self) -> bool:
|
298 |
+
return len(self.keyframes) == 0
|
299 |
+
|
300 |
+
def clone(self) -> 'LatentKeyframeGroup':
|
301 |
+
cloned = LatentKeyframeGroup()
|
302 |
+
for tk in self.keyframes:
|
303 |
+
cloned.add(tk)
|
304 |
+
return cloned
|
305 |
+
|
306 |
+
|
307 |
+
class TimestepKeyframe:
|
308 |
+
def __init__(self,
|
309 |
+
start_percent: float = 0.0,
|
310 |
+
strength: float = 1.0,
|
311 |
+
control_weights: ControlWeights = None,
|
312 |
+
latent_keyframes: LatentKeyframeGroup = None,
|
313 |
+
null_latent_kf_strength: float = 0.0,
|
314 |
+
inherit_missing: bool = True,
|
315 |
+
guarantee_steps: int = 1,
|
316 |
+
mask_hint_orig: Tensor = None) -> None:
|
317 |
+
self.start_percent = float(start_percent)
|
318 |
+
self.start_t = 999999999.9
|
319 |
+
self.strength = strength
|
320 |
+
self.control_weights = control_weights
|
321 |
+
self.latent_keyframes = latent_keyframes
|
322 |
+
self.null_latent_kf_strength = null_latent_kf_strength
|
323 |
+
self.inherit_missing = inherit_missing
|
324 |
+
self.guarantee_steps = guarantee_steps
|
325 |
+
self.mask_hint_orig = mask_hint_orig
|
326 |
+
|
327 |
+
def has_control_weights(self):
|
328 |
+
return self.control_weights is not None
|
329 |
+
|
330 |
+
def has_latent_keyframes(self):
|
331 |
+
return self.latent_keyframes is not None
|
332 |
+
|
333 |
+
def has_mask_hint(self):
|
334 |
+
return self.mask_hint_orig is not None
|
335 |
+
|
336 |
+
|
337 |
+
@staticmethod
|
338 |
+
def default() -> 'TimestepKeyframe':
|
339 |
+
return TimestepKeyframe(start_percent=0.0, guarantee_steps=0)
|
340 |
+
|
341 |
+
|
342 |
+
# always maintain sorted state (by start_percent of TimestepKeyFrame)
|
343 |
+
class TimestepKeyframeGroup:
|
344 |
+
def __init__(self) -> None:
|
345 |
+
self.keyframes: list[TimestepKeyframe] = []
|
346 |
+
self.keyframes.append(TimestepKeyframe.default())
|
347 |
+
|
348 |
+
def add(self, keyframe: TimestepKeyframe) -> None:
|
349 |
+
# add to end of list, then sort
|
350 |
+
self.keyframes.append(keyframe)
|
351 |
+
self.keyframes = get_sorted_list_via_attr(self.keyframes, attr="start_percent")
|
352 |
+
|
353 |
+
def get_index(self, index: int) -> Union[TimestepKeyframe, None]:
|
354 |
+
try:
|
355 |
+
return self.keyframes[index]
|
356 |
+
except IndexError:
|
357 |
+
return None
|
358 |
+
|
359 |
+
def has_index(self, index: int) -> int:
|
360 |
+
return index >=0 and index < len(self.keyframes)
|
361 |
+
|
362 |
+
def __getitem__(self, index) -> TimestepKeyframe:
|
363 |
+
return self.keyframes[index]
|
364 |
+
|
365 |
+
def __len__(self) -> int:
|
366 |
+
return len(self.keyframes)
|
367 |
+
|
368 |
+
def is_empty(self) -> bool:
|
369 |
+
return len(self.keyframes) == 0
|
370 |
+
|
371 |
+
def clone(self) -> 'TimestepKeyframeGroup':
|
372 |
+
cloned = TimestepKeyframeGroup()
|
373 |
+
# already sorted, so don't use add function to make cloning quicker
|
374 |
+
for tk in self.keyframes:
|
375 |
+
cloned.keyframes.append(tk)
|
376 |
+
return cloned
|
377 |
+
|
378 |
+
@classmethod
|
379 |
+
def default(cls, keyframe: TimestepKeyframe) -> 'TimestepKeyframeGroup':
|
380 |
+
group = cls()
|
381 |
+
group.keyframes[0] = keyframe
|
382 |
+
return group
|
383 |
+
|
384 |
+
|
385 |
+
class AbstractPreprocWrapper:
|
386 |
+
error_msg = "Invalid use of [InsertHere] output. The output of [InsertHere] preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
|
387 |
+
def __init__(self, condhint):
|
388 |
+
self.condhint = condhint
|
389 |
+
|
390 |
+
def movedim(self, *args, **kwargs):
|
391 |
+
return self
|
392 |
+
|
393 |
+
def __getattr__(self, *args, **kwargs):
|
394 |
+
raise AttributeError(self.error_msg)
|
395 |
+
|
396 |
+
def __setattr__(self, name, value):
|
397 |
+
if name != "condhint":
|
398 |
+
raise AttributeError(self.error_msg)
|
399 |
+
super().__setattr__(name, value)
|
400 |
+
|
401 |
+
def __iter__(self, *args, **kwargs):
|
402 |
+
raise AttributeError(self.error_msg)
|
403 |
+
|
404 |
+
def __next__(self, *args, **kwargs):
|
405 |
+
raise AttributeError(self.error_msg)
|
406 |
+
|
407 |
+
def __len__(self, *args, **kwargs):
|
408 |
+
raise AttributeError(self.error_msg)
|
409 |
+
|
410 |
+
def __getitem__(self, *args, **kwargs):
|
411 |
+
raise AttributeError(self.error_msg)
|
412 |
+
|
413 |
+
def __setitem__(self, *args, **kwargs):
|
414 |
+
raise AttributeError(self.error_msg)
|
415 |
+
|
416 |
+
|
417 |
+
# depending on model, AnimateDiff may inject into GroupNorm, so make sure GroupNorm will be clean
|
418 |
+
class disable_weight_init_clean_groupnorm(comfy.ops.disable_weight_init):
|
419 |
+
class GroupNorm(comfy.ops.disable_weight_init.GroupNorm):
|
420 |
+
def forward_comfy_cast_weights(self, input):
|
421 |
+
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
422 |
+
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
423 |
+
|
424 |
+
def forward(self, input):
|
425 |
+
if self.comfy_cast_weights:
|
426 |
+
return self.forward_comfy_cast_weights(input)
|
427 |
+
else:
|
428 |
+
return torch.nn.functional.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
|
429 |
+
|
430 |
+
class manual_cast_clean_groupnorm(comfy.ops.manual_cast):
|
431 |
+
class GroupNorm(disable_weight_init_clean_groupnorm.GroupNorm):
|
432 |
+
comfy_cast_weights = True
|
433 |
+
|
434 |
+
|
435 |
+
# adapted from comfy/sample.py
|
436 |
+
def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int=1, match_dim1=False, match_shape=False, flux_shape=None):
|
437 |
+
mask = mask.clone()
|
438 |
+
if flux_shape is not None:
|
439 |
+
multiplier = multiplier * 0.5
|
440 |
+
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(round(flux_shape[-2]*multiplier), round(flux_shape[-1]*multiplier)), mode="bilinear")
|
441 |
+
mask = rearrange(mask, "b c h w -> b (h w) c")
|
442 |
+
else:
|
443 |
+
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(round(shape[-2]*multiplier), round(shape[-1]*multiplier)), mode="bilinear")
|
444 |
+
if match_dim1:
|
445 |
+
if match_shape and len(shape) < 4:
|
446 |
+
raise Exception(f"match_dim1 cannot be True if shape is under 4 dims; was {len(shape)}.")
|
447 |
+
mask = torch.cat([mask] * shape[1], dim=1)
|
448 |
+
if match_shape and len(shape) == 3 and len(mask.shape) != 3:
|
449 |
+
mask = mask.squeeze(1)
|
450 |
+
return mask
|
451 |
+
|
452 |
+
|
453 |
+
# applies min-max normalization, from:
|
454 |
+
# https://stackoverflow.com/questions/68791508/min-max-normalization-of-a-tensor-in-pytorch
|
455 |
+
def normalize_min_max(x: Tensor, new_min = 0.0, new_max = 1.0):
|
456 |
+
x_min, x_max = x.min(), x.max()
|
457 |
+
return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
|
458 |
+
|
459 |
+
def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
|
460 |
+
return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
|
461 |
+
|
462 |
+
def extend_to_batch_size(tensor: Tensor, batch_size: int):
|
463 |
+
if tensor.shape[0] > batch_size:
|
464 |
+
return tensor[:batch_size]
|
465 |
+
elif tensor.shape[0] < batch_size:
|
466 |
+
remainder = batch_size-tensor.shape[0]
|
467 |
+
return torch.cat([tensor] + [tensor[-1:]]*remainder, dim=0)
|
468 |
+
return tensor
|
469 |
+
|
470 |
+
def broadcast_image_to_extend(tensor, target_batch_size, batched_number, except_one=True):
|
471 |
+
current_batch_size = tensor.shape[0]
|
472 |
+
#print(current_batch_size, target_batch_size)
|
473 |
+
if except_one and current_batch_size == 1:
|
474 |
+
return tensor
|
475 |
+
|
476 |
+
per_batch = target_batch_size // batched_number
|
477 |
+
tensor = tensor[:per_batch]
|
478 |
+
|
479 |
+
if per_batch > tensor.shape[0]:
|
480 |
+
tensor = extend_to_batch_size(tensor=tensor, batch_size=per_batch)
|
481 |
+
|
482 |
+
current_batch_size = tensor.shape[0]
|
483 |
+
if current_batch_size == target_batch_size:
|
484 |
+
return tensor
|
485 |
+
else:
|
486 |
+
return torch.cat([tensor] * batched_number, dim=0)
|
487 |
+
|
488 |
+
|
489 |
+
# from https://stackoverflow.com/a/24621200
|
490 |
+
def deepcopy_with_sharing(obj, shared_attribute_names, memo=None):
|
491 |
+
'''
|
492 |
+
Deepcopy an object, except for a given list of attributes, which should
|
493 |
+
be shared between the original object and its copy.
|
494 |
+
|
495 |
+
obj is some object
|
496 |
+
shared_attribute_names: A list of strings identifying the attributes that
|
497 |
+
should be shared between the original and its copy.
|
498 |
+
memo is the dictionary passed into __deepcopy__. Ignore this argument if
|
499 |
+
not calling from within __deepcopy__.
|
500 |
+
'''
|
501 |
+
assert isinstance(shared_attribute_names, (list, tuple))
|
502 |
+
|
503 |
+
shared_attributes = {k: getattr(obj, k) for k in shared_attribute_names}
|
504 |
+
|
505 |
+
if hasattr(obj, '__deepcopy__'):
|
506 |
+
# Do hack to prevent infinite recursion in call to deepcopy
|
507 |
+
deepcopy_method = obj.__deepcopy__
|
508 |
+
obj.__deepcopy__ = None
|
509 |
+
|
510 |
+
for attr in shared_attribute_names:
|
511 |
+
del obj.__dict__[attr]
|
512 |
+
|
513 |
+
clone = deepcopy(obj)
|
514 |
+
|
515 |
+
for attr, val in shared_attributes.items():
|
516 |
+
setattr(obj, attr, val)
|
517 |
+
setattr(clone, attr, val)
|
518 |
+
|
519 |
+
if hasattr(obj, '__deepcopy__'):
|
520 |
+
# Undo hack
|
521 |
+
obj.__deepcopy__ = deepcopy_method
|
522 |
+
del clone.__deepcopy__
|
523 |
+
|
524 |
+
return clone
|
525 |
+
|
526 |
+
|
527 |
+
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
528 |
+
if not objects:
|
529 |
+
return objects
|
530 |
+
elif len(objects) <= 1:
|
531 |
+
return [x for x in objects]
|
532 |
+
# now that we know we have to sort, do it following these rules:
|
533 |
+
# a) if objects have same value of attribute, maintain their relative order
|
534 |
+
# b) perform sorting of the groups of objects with same attributes
|
535 |
+
unique_attrs = {}
|
536 |
+
for o in objects:
|
537 |
+
val_attr = getattr(o, attr)
|
538 |
+
attr_list: list = unique_attrs.get(val_attr, list())
|
539 |
+
attr_list.append(o)
|
540 |
+
if val_attr not in unique_attrs:
|
541 |
+
unique_attrs[val_attr] = attr_list
|
542 |
+
# now that we have the unique attr values grouped together in relative order, sort them by key
|
543 |
+
sorted_attrs = dict(sorted(unique_attrs.items()))
|
544 |
+
# now flatten out the dict into a list to return
|
545 |
+
sorted_list = []
|
546 |
+
for object_list in sorted_attrs.values():
|
547 |
+
sorted_list.extend(object_list)
|
548 |
+
return sorted_list
|
549 |
+
|
550 |
+
|
551 |
+
# DFS Search for Torch.nn.Module, Written by Lvmin
|
552 |
+
def torch_dfs(model: torch.nn.Module):
|
553 |
+
result = [model]
|
554 |
+
for child in model.children():
|
555 |
+
result += torch_dfs(child)
|
556 |
+
return result
|
557 |
+
|
558 |
+
|
559 |
+
class WeightTypeException(TypeError):
|
560 |
+
"Raised when weight not compatible with AdvancedControlBase object"
|
561 |
+
pass
|
562 |
+
|
563 |
+
|
564 |
+
class AdvancedControlBase:
|
565 |
+
def __init__(self, base: ControlBase, timestep_keyframes: TimestepKeyframeGroup, weights_default: ControlWeights, require_model=False, require_vae=False, allow_condhint_latents=False):
|
566 |
+
self.base = base
|
567 |
+
self.compatible_weights = [ControlWeightType.UNIVERSAL, ControlWeightType.DEFAULT]
|
568 |
+
self.add_compatible_weight(weights_default.weight_type)
|
569 |
+
# mask for which parts of controlnet output to keep
|
570 |
+
self.mask_cond_hint_original = None
|
571 |
+
self.mask_cond_hint = None
|
572 |
+
self.tk_mask_cond_hint_original = None
|
573 |
+
self.tk_mask_cond_hint = None
|
574 |
+
self.weight_mask_cond_hint = None
|
575 |
+
# actual index values
|
576 |
+
self.sub_idxs = None
|
577 |
+
self.full_latent_length = 0
|
578 |
+
self.context_length = 0
|
579 |
+
# timesteps
|
580 |
+
self.t: float = None
|
581 |
+
self.prev_t: float = None
|
582 |
+
self.batched_number: Union[int, IntWithCondOrUncond] = None
|
583 |
+
self.batch_size: int = 0
|
584 |
+
# weights + override
|
585 |
+
self.weights: ControlWeights = None
|
586 |
+
self.weights_default: ControlWeights = weights_default
|
587 |
+
self.weights_override: ControlWeights = None
|
588 |
+
# latent keyframe + override
|
589 |
+
self.latent_keyframes: LatentKeyframeGroup = None
|
590 |
+
self.latent_keyframe_override: LatentKeyframeGroup = None
|
591 |
+
# initialize timestep_keyframes
|
592 |
+
self.set_timestep_keyframes(timestep_keyframes)
|
593 |
+
# override some functions
|
594 |
+
self.get_control = self.get_control_inject
|
595 |
+
self.control_merge = self.control_merge_inject
|
596 |
+
self.pre_run = self.pre_run_inject
|
597 |
+
self.cleanup = self.cleanup_inject
|
598 |
+
self.set_previous_controlnet = self.set_previous_controlnet_inject
|
599 |
+
self.set_cond_hint = self.set_cond_hint_inject
|
600 |
+
# vae to store
|
601 |
+
self.adv_vae = None
|
602 |
+
# require model/vae to be passed into Apply Advanced ControlNet 🛂🅐🅒🅝 node
|
603 |
+
self.require_model = require_model
|
604 |
+
self.require_vae = require_vae
|
605 |
+
self.allow_condhint_latents = allow_condhint_latents
|
606 |
+
# disarm - when set to False, used to force usage of Apply Advanced ControlNet 🛂🅐🅒🅝 node (which will set it to True)
|
607 |
+
self.disarmed = not require_model
|
608 |
+
|
609 |
+
def patch_model(self, model: ModelPatcher):
|
610 |
+
pass
|
611 |
+
|
612 |
+
def add_compatible_weight(self, control_weight_type: str):
|
613 |
+
self.compatible_weights.append(control_weight_type)
|
614 |
+
|
615 |
+
def verify_all_weights(self, throw_error=True):
|
616 |
+
# first, check if override exists - if so, only need to check the override
|
617 |
+
if self.weights_override is not None:
|
618 |
+
if self.weights_override.weight_type not in self.compatible_weights:
|
619 |
+
msg = f"Weight override is type {self.weights_override.weight_type}, but loaded {type(self).__name__}" + \
|
620 |
+
f"only supports {self.compatible_weights} weights."
|
621 |
+
raise WeightTypeException(msg)
|
622 |
+
# otherwise, check all timestep keyframe weights
|
623 |
+
else:
|
624 |
+
for tk in self.timestep_keyframes.keyframes:
|
625 |
+
if tk.has_control_weights() and tk.control_weights.weight_type not in self.compatible_weights:
|
626 |
+
msg = f"Weight on Timestep Keyframe with start_percent={tk.start_percent} is type " + \
|
627 |
+
f"{tk.control_weights.weight_type}, but loaded {type(self).__name__} only supports {self.compatible_weights} weights."
|
628 |
+
raise WeightTypeException(msg)
|
629 |
+
|
630 |
+
def set_timestep_keyframes(self, timestep_keyframes: TimestepKeyframeGroup):
|
631 |
+
self.timestep_keyframes = timestep_keyframes if timestep_keyframes else TimestepKeyframeGroup()
|
632 |
+
# prepare first timestep_keyframe related stuff
|
633 |
+
self._current_timestep_keyframe = None
|
634 |
+
self._current_timestep_index = -1
|
635 |
+
self._current_used_steps = 0
|
636 |
+
self.weights = None
|
637 |
+
self.latent_keyframes = None
|
638 |
+
|
639 |
+
def prepare_current_timestep(self, t: Tensor, batched_number: int=1):
|
640 |
+
self.t = float(t[0])
|
641 |
+
# check if t has changed (otherwise do nothing, as step already accounted for)
|
642 |
+
if self.t == self.prev_t:
|
643 |
+
return
|
644 |
+
# get current step percent
|
645 |
+
curr_t: float = self.t
|
646 |
+
prev_index = self._current_timestep_index
|
647 |
+
# if met guaranteed steps (or no current keyframe), look for next keyframe in case need to switch
|
648 |
+
if self._current_timestep_keyframe is None or self._current_used_steps >= self._current_timestep_keyframe.guarantee_steps:
|
649 |
+
# if has next index, loop through and see if need to switch
|
650 |
+
if self.timestep_keyframes.has_index(self._current_timestep_index+1):
|
651 |
+
for i in range(self._current_timestep_index+1, len(self.timestep_keyframes)):
|
652 |
+
eval_tk = self.timestep_keyframes[i]
|
653 |
+
# check if start percent is less or equal to curr_t
|
654 |
+
if eval_tk.start_t >= curr_t:
|
655 |
+
self._current_timestep_index = i
|
656 |
+
self._current_timestep_keyframe = eval_tk
|
657 |
+
self._current_used_steps = 0
|
658 |
+
# keep track of control weights, latent keyframes, and masks,
|
659 |
+
# accounting for inherit_missing
|
660 |
+
if self._current_timestep_keyframe.has_control_weights():
|
661 |
+
self.weights = self._current_timestep_keyframe.control_weights
|
662 |
+
elif not self._current_timestep_keyframe.inherit_missing:
|
663 |
+
self.weights = self.weights_default
|
664 |
+
if self._current_timestep_keyframe.has_latent_keyframes():
|
665 |
+
self.latent_keyframes = self._current_timestep_keyframe.latent_keyframes
|
666 |
+
elif not self._current_timestep_keyframe.inherit_missing:
|
667 |
+
self.latent_keyframes = None
|
668 |
+
if self._current_timestep_keyframe.has_mask_hint():
|
669 |
+
self.tk_mask_cond_hint_original = self._current_timestep_keyframe.mask_hint_orig
|
670 |
+
elif not self._current_timestep_keyframe.inherit_missing:
|
671 |
+
del self.tk_mask_cond_hint_original
|
672 |
+
self.tk_mask_cond_hint_original = None
|
673 |
+
# if guarantee_steps greater than zero, stop searching for other keyframes
|
674 |
+
if self._current_timestep_keyframe.guarantee_steps > 0:
|
675 |
+
break
|
676 |
+
# if eval_tk is outside of percent range, stop looking further
|
677 |
+
else:
|
678 |
+
break
|
679 |
+
# update prev_t
|
680 |
+
self.prev_t = self.t
|
681 |
+
# update steps current keyframe is used
|
682 |
+
self._current_used_steps += 1
|
683 |
+
# if index changed, apply overrides
|
684 |
+
if prev_index != self._current_timestep_index:
|
685 |
+
if self.weights_override is not None:
|
686 |
+
self.weights = self.weights_override
|
687 |
+
if self.latent_keyframe_override is not None:
|
688 |
+
self.latent_keyframes = self.latent_keyframe_override
|
689 |
+
|
690 |
+
# make sure weights and latent_keyframes are in a workable state
|
691 |
+
# Note: each AdvancedControlBase should create their own get_universal_weights class
|
692 |
+
self.prepare_weights()
|
693 |
+
|
694 |
+
def prepare_weights(self):
|
695 |
+
if self.weights is None:
|
696 |
+
self.weights = self.weights_default
|
697 |
+
elif self.weights.weight_type == ControlWeightType.UNIVERSAL:
|
698 |
+
# if universal and weight_mask present, no need to convert
|
699 |
+
if self.weights.weight_mask is not None:
|
700 |
+
return
|
701 |
+
self.weights = self.get_universal_weights()
|
702 |
+
|
703 |
+
def get_universal_weights(self) -> ControlWeights:
|
704 |
+
return self.weights
|
705 |
+
|
706 |
+
def set_cond_hint_mask(self, mask_hint):
|
707 |
+
self.mask_cond_hint_original = mask_hint
|
708 |
+
return self
|
709 |
+
|
710 |
+
def set_cond_hint_inject(self, *args, **kwargs):
|
711 |
+
to_return = self.base.set_cond_hint(*args, **kwargs)
|
712 |
+
# if vae required, look in args and kwargs for it
|
713 |
+
if self.require_vae:
|
714 |
+
# check args first, as that's the default way vae param is used in ComfyUI
|
715 |
+
for arg in args:
|
716 |
+
if isinstance(arg, VAE):
|
717 |
+
self.adv_vae = arg
|
718 |
+
break
|
719 |
+
# if not in args, check kwargs now
|
720 |
+
if self.adv_vae is None:
|
721 |
+
if 'vae' in kwargs:
|
722 |
+
self.adv_vae = kwargs['vae']
|
723 |
+
return to_return
|
724 |
+
|
725 |
+
def pre_run_inject(self, model, percent_to_timestep_function):
|
726 |
+
self.base.pre_run(model, percent_to_timestep_function)
|
727 |
+
self.pre_run_advanced(model, percent_to_timestep_function)
|
728 |
+
|
729 |
+
def pre_run_advanced(self, model, percent_to_timestep_function):
|
730 |
+
# for each timestep keyframe, calculate the start_t
|
731 |
+
for tk in self.timestep_keyframes.keyframes:
|
732 |
+
tk.start_t = percent_to_timestep_function(tk.start_percent)
|
733 |
+
# clear variables
|
734 |
+
self.cleanup_advanced()
|
735 |
+
|
736 |
+
def set_previous_controlnet_inject(self, *args, **kwargs):
|
737 |
+
to_return = self.base.set_previous_controlnet(*args, **kwargs)
|
738 |
+
if not self.disarmed:
|
739 |
+
raise Exception(f"Type '{type(self).__name__}' must be used with Apply Advanced ControlNet 🛂🅐🅒🅝 node (with model_optional passed in); otherwise, it will not work.")
|
740 |
+
return to_return
|
741 |
+
|
742 |
+
def disarm(self):
|
743 |
+
self.disarmed = True
|
744 |
+
|
745 |
+
def should_run(self):
|
746 |
+
if math.isclose(self.strength, 0.0) or math.isclose(self._current_timestep_keyframe.strength, 0.0):
|
747 |
+
return False
|
748 |
+
if self.timestep_range is not None:
|
749 |
+
if self.t > self.timestep_range[0] or self.t < self.timestep_range[1]:
|
750 |
+
return False
|
751 |
+
return True
|
752 |
+
|
753 |
+
def get_control_inject(self, x_noisy, t, cond, batched_number):
|
754 |
+
self.batched_number = batched_number
|
755 |
+
self.batch_size = len(t)
|
756 |
+
# prepare timestep and everything related
|
757 |
+
self.prepare_current_timestep(t=t, batched_number=batched_number)
|
758 |
+
# if should not perform any actions for the controlnet, exit without doing any work
|
759 |
+
if self.strength == 0.0 or self._current_timestep_keyframe.strength == 0.0:
|
760 |
+
return self.default_control_actions(x_noisy, t, cond, batched_number)
|
761 |
+
# otherwise, perform normal function
|
762 |
+
return self.get_control_advanced(x_noisy, t, cond, batched_number)
|
763 |
+
|
764 |
+
def get_control_advanced(self, x_noisy, t, cond, batched_number):
|
765 |
+
return self.default_control_actions(x_noisy, t, cond, batched_number)
|
766 |
+
|
767 |
+
def default_control_actions(self, x_noisy, t, cond, batched_number):
|
768 |
+
control_prev = None
|
769 |
+
if self.previous_controlnet is not None:
|
770 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
771 |
+
return control_prev
|
772 |
+
|
773 |
+
def calc_weight(self, idx: int, x: Tensor, control: dict[str, list[Tensor]], key: str) -> Union[float, Tensor]:
|
774 |
+
if self.weights.weight_mask is not None:
|
775 |
+
# prepare weight mask
|
776 |
+
self.prepare_weight_mask_cond_hint(x, self.batched_number)
|
777 |
+
# adjust mask for current layer and return
|
778 |
+
return torch.pow(self.weight_mask_cond_hint, self.get_calc_pow(idx=idx, control=control, key=key))
|
779 |
+
return self.weights.get(idx=idx, control=control, key=key)
|
780 |
+
|
781 |
+
def get_calc_pow(self, idx: int, control: dict[str, list[Tensor]], key: str) -> int:
|
782 |
+
if key == "middle":
|
783 |
+
return 0
|
784 |
+
else:
|
785 |
+
c_len = len(control[key])
|
786 |
+
real_idx = c_len-idx
|
787 |
+
if key == "input":
|
788 |
+
real_idx = c_len - real_idx + 1
|
789 |
+
return real_idx
|
790 |
+
|
791 |
+
def calc_latent_keyframe_mults(self, x: Tensor, batched_number: int) -> Tensor:
|
792 |
+
# apply strengths, and get batch indeces to null out
|
793 |
+
# AKA latents that should not be influenced by ControlNet
|
794 |
+
final_mults = [1.0] * x.shape[0]
|
795 |
+
if self.latent_keyframes:
|
796 |
+
latent_count = x.shape[0] // batched_number
|
797 |
+
indeces_to_null = set(range(latent_count))
|
798 |
+
mapped_indeces = None
|
799 |
+
# if expecting subdivision, will need to translate between subset and actual idx values
|
800 |
+
if self.sub_idxs:
|
801 |
+
mapped_indeces = {}
|
802 |
+
for i, actual in enumerate(self.sub_idxs):
|
803 |
+
mapped_indeces[actual] = i
|
804 |
+
for keyframe in self.latent_keyframes:
|
805 |
+
real_index = keyframe.batch_index
|
806 |
+
# if negative, count from end
|
807 |
+
if real_index < 0:
|
808 |
+
real_index += latent_count if self.sub_idxs is None else self.full_latent_length
|
809 |
+
|
810 |
+
# if not mapping indeces, what you see is what you get
|
811 |
+
if mapped_indeces is None:
|
812 |
+
if real_index in indeces_to_null:
|
813 |
+
indeces_to_null.remove(real_index)
|
814 |
+
# otherwise, see if batch_index is even included in this set of latents
|
815 |
+
else:
|
816 |
+
real_index = mapped_indeces.get(real_index, None)
|
817 |
+
if real_index is None:
|
818 |
+
continue
|
819 |
+
indeces_to_null.remove(real_index)
|
820 |
+
|
821 |
+
# if real_index is outside the bounds of latents, don't apply
|
822 |
+
if real_index >= latent_count or real_index < 0:
|
823 |
+
continue
|
824 |
+
|
825 |
+
# apply strength for each batched cond/uncond
|
826 |
+
for b in range(batched_number):
|
827 |
+
final_mults[(latent_count*b)+real_index] = keyframe.strength
|
828 |
+
# null them out by multiplying by null_latent_kf_strength
|
829 |
+
for batch_index in indeces_to_null:
|
830 |
+
# apply null for each batched cond/uncond
|
831 |
+
for b in range(batched_number):
|
832 |
+
final_mults[(latent_count*b)+batch_index] = self._current_timestep_keyframe.null_latent_kf_strength
|
833 |
+
# convert final_mults into tensor and match expected dimension count
|
834 |
+
final_tensor = torch.tensor(final_mults, dtype=x.dtype, device=x.device)
|
835 |
+
while len(final_tensor.shape) < len(x.shape):
|
836 |
+
final_tensor = final_tensor.unsqueeze(-1)
|
837 |
+
return final_tensor
|
838 |
+
|
839 |
+
def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, flux_shape: tuple=None):
|
840 |
+
# handle weight's uncond_multiplier, if applicable
|
841 |
+
if self.weights.has_uncond_multiplier:
|
842 |
+
cond_or_uncond = self.batched_number.cond_or_uncond
|
843 |
+
actual_length = x.size(0) // batched_number
|
844 |
+
for idx, cond_type in enumerate(cond_or_uncond):
|
845 |
+
# if uncond, set to weight's uncond_multiplier
|
846 |
+
if cond_type == 1:
|
847 |
+
x[actual_length*idx:actual_length*(idx+1)] *= self.weights.uncond_multiplier
|
848 |
+
if self.weights.has_uncond_mask:
|
849 |
+
pass
|
850 |
+
|
851 |
+
if self.latent_keyframes is not None:
|
852 |
+
x[:] = x[:] * self.calc_latent_keyframe_mults(x=x, batched_number=batched_number)
|
853 |
+
# apply masks, resizing mask to required dims
|
854 |
+
if self.mask_cond_hint is not None:
|
855 |
+
masks = prepare_mask_batch(self.mask_cond_hint, x.shape, match_shape=True, flux_shape=flux_shape)
|
856 |
+
x[:] = x[:] * masks
|
857 |
+
if self.tk_mask_cond_hint is not None:
|
858 |
+
masks = prepare_mask_batch(self.tk_mask_cond_hint, x.shape, match_shape=True, flux_shape=flux_shape)
|
859 |
+
x[:] = x[:] * masks
|
860 |
+
# apply timestep keyframe strengths
|
861 |
+
if self._current_timestep_keyframe.strength != 1.0:
|
862 |
+
x[:] *= self._current_timestep_keyframe.strength
|
863 |
+
|
864 |
+
def control_merge_inject(self: 'AdvancedControlBase', control: dict[str, list[Tensor]], control_prev: dict, output_dtype):
|
865 |
+
out = {'input':[], 'middle':[], 'output': []}
|
866 |
+
|
867 |
+
for key in control:
|
868 |
+
control_output = control[key]
|
869 |
+
applied_to = set()
|
870 |
+
for i in range(len(control_output)):
|
871 |
+
x = control_output[i]
|
872 |
+
if x is not None:
|
873 |
+
if self.global_average_pooling:
|
874 |
+
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
875 |
+
|
876 |
+
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
|
877 |
+
applied_to.add(x)
|
878 |
+
self.apply_advanced_strengths_and_masks(x, self.batched_number)
|
879 |
+
x *= self.strength * self.calc_weight(i, x, control, key)
|
880 |
+
|
881 |
+
if output_dtype is not None and x.dtype != output_dtype:
|
882 |
+
x = x.to(output_dtype)
|
883 |
+
|
884 |
+
out[key].append(x)
|
885 |
+
|
886 |
+
if control_prev is not None:
|
887 |
+
for x in ['input', 'middle', 'output']:
|
888 |
+
o = out[x]
|
889 |
+
for i in range(len(control_prev[x])):
|
890 |
+
prev_val = control_prev[x][i]
|
891 |
+
if i >= len(o):
|
892 |
+
o.append(prev_val)
|
893 |
+
elif prev_val is not None:
|
894 |
+
if o[i] is None:
|
895 |
+
o[i] = prev_val
|
896 |
+
else:
|
897 |
+
if o[i].shape[0] < prev_val.shape[0]:
|
898 |
+
o[i] = prev_val + o[i]
|
899 |
+
else:
|
900 |
+
o[i] = prev_val + o[i] # TODO from base ComfyUI: change back to inplace add if shared tensors stop being an issue
|
901 |
+
return out
|
902 |
+
|
903 |
+
def prepare_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
|
904 |
+
self._prepare_mask("mask_cond_hint", self.mask_cond_hint_original, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
|
905 |
+
self.prepare_tk_mask_cond_hint(x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
|
906 |
+
|
907 |
+
def prepare_tk_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
|
908 |
+
return self._prepare_mask("tk_mask_cond_hint", self._current_timestep_keyframe.mask_hint_orig, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
|
909 |
+
|
910 |
+
def prepare_weight_mask_cond_hint(self, x_noisy: Tensor, batched_number, dtype=None):
|
911 |
+
return self._prepare_mask("weight_mask_cond_hint", self.weights.weight_mask, x_noisy, t=None, cond=None, batched_number=batched_number, dtype=dtype, direct_attn=True)
|
912 |
+
|
913 |
+
def _prepare_mask(self, attr_name, orig_mask: Tensor, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
|
914 |
+
# make mask appropriate dimensions, if present
|
915 |
+
if orig_mask is not None:
|
916 |
+
out_mask = getattr(self, attr_name)
|
917 |
+
multiplier = 1 if direct_attn else 8
|
918 |
+
if self.sub_idxs is not None or out_mask is None or x_noisy.shape[2] * multiplier != out_mask.shape[1] or x_noisy.shape[3] * multiplier != out_mask.shape[2]:
|
919 |
+
self._reset_attr(attr_name)
|
920 |
+
del out_mask
|
921 |
+
# TODO: perform upscale on only the sub_idxs masks at a time instead of all to conserve RAM
|
922 |
+
# resize mask and match batch count
|
923 |
+
out_mask = prepare_mask_batch(orig_mask, x_noisy.shape, multiplier=multiplier, match_shape=True)
|
924 |
+
actual_latent_length = x_noisy.shape[0] // batched_number
|
925 |
+
out_mask = extend_to_batch_size(out_mask, actual_latent_length if self.sub_idxs is None else self.full_latent_length)
|
926 |
+
if self.sub_idxs is not None:
|
927 |
+
out_mask = out_mask[self.sub_idxs]
|
928 |
+
# make cond_hint_mask length match x_noise
|
929 |
+
if x_noisy.shape[0] != out_mask.shape[0]:
|
930 |
+
out_mask = broadcast_image_to_extend(out_mask, x_noisy.shape[0], batched_number)
|
931 |
+
# default dtype to be same as x_noisy
|
932 |
+
if dtype is None:
|
933 |
+
dtype = x_noisy.dtype
|
934 |
+
setattr(self, attr_name, out_mask.to(dtype=dtype).to(x_noisy.device))
|
935 |
+
del out_mask
|
936 |
+
|
937 |
+
def _reset_attr(self, attr_name, new_value=None):
|
938 |
+
if hasattr(self, attr_name):
|
939 |
+
delattr(self, attr_name)
|
940 |
+
setattr(self, attr_name, new_value)
|
941 |
+
|
942 |
+
def cleanup_inject(self):
|
943 |
+
self.base.cleanup()
|
944 |
+
self.cleanup_advanced()
|
945 |
+
|
946 |
+
def cleanup_advanced(self):
|
947 |
+
self.sub_idxs = None
|
948 |
+
self.full_latent_length = 0
|
949 |
+
self.context_length = 0
|
950 |
+
self.t = None
|
951 |
+
self.prev_t = None
|
952 |
+
self.batched_number = None
|
953 |
+
self.batch_size = 0
|
954 |
+
self.weights = None
|
955 |
+
self.latent_keyframes = None
|
956 |
+
# timestep stuff
|
957 |
+
self._current_timestep_keyframe = None
|
958 |
+
self._current_timestep_index = -1
|
959 |
+
self._current_used_steps = 0
|
960 |
+
# clear mask hints
|
961 |
+
if self.mask_cond_hint is not None:
|
962 |
+
del self.mask_cond_hint
|
963 |
+
self.mask_cond_hint = None
|
964 |
+
if self.tk_mask_cond_hint_original is not None:
|
965 |
+
del self.tk_mask_cond_hint_original
|
966 |
+
self.tk_mask_cond_hint_original = None
|
967 |
+
if self.tk_mask_cond_hint is not None:
|
968 |
+
del self.tk_mask_cond_hint
|
969 |
+
self.tk_mask_cond_hint = None
|
970 |
+
if self.weight_mask_cond_hint is not None:
|
971 |
+
del self.weight_mask_cond_hint
|
972 |
+
self.weight_mask_cond_hint = None
|
973 |
+
|
974 |
+
def copy_to_advanced(self, copied: 'AdvancedControlBase'):
|
975 |
+
copied.mask_cond_hint_original = self.mask_cond_hint_original
|
976 |
+
copied.weights_override = self.weights_override
|
977 |
+
copied.latent_keyframe_override = self.latent_keyframe_override
|
978 |
+
copied.adv_vae = self.adv_vae
|
979 |
+
copied.require_vae = self.require_vae
|
980 |
+
copied.allow_condhint_latents = self.allow_condhint_latents
|
981 |
+
copied.disarmed = self.disarmed
|
ComfyUI-Advanced-ControlNet/pyproject.toml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "comfyui-advanced-controlnet"
|
3 |
+
description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks."
|
4 |
+
version = "1.3.0"
|
5 |
+
license = { file = "LICENSE" }
|
6 |
+
dependencies = []
|
7 |
+
|
8 |
+
[project.urls]
|
9 |
+
Repository = "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet"
|
10 |
+
|
11 |
+
# Used by Comfy Registry https://comfyregistry.org
|
12 |
+
[tool.comfy]
|
13 |
+
PublisherId = "kosinkadink"
|
14 |
+
DisplayName = "ComfyUI-Advanced-ControlNet"
|
15 |
+
Icon = ""
|
ComfyUI-Advanced-ControlNet/web/js/autosize.js
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { app } from '../../../scripts/app.js'
|
2 |
+
|
3 |
+
function addResizeHook(node, padding, useOldMin=false) {
|
4 |
+
let origOnCreated = node.onNodeCreated
|
5 |
+
node.onNodeCreated = function() {
|
6 |
+
let r = origOnCreated?.apply(this, arguments)
|
7 |
+
let size = this.computeSize();
|
8 |
+
size[0] += padding || 0;
|
9 |
+
if (useOldMin) {
|
10 |
+
//equal to LiteGraph.NODE_WIDTH*1.5*1.5
|
11 |
+
size[0] = Math.max(size[0], 315)
|
12 |
+
}
|
13 |
+
this.setSize(size);
|
14 |
+
return r
|
15 |
+
}
|
16 |
+
}
|
17 |
+
|
18 |
+
app.registerExtension({
|
19 |
+
name: "AdvancedControlNet.autosize",
|
20 |
+
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
21 |
+
//since python_module is based off folder path,
|
22 |
+
//it could be changed by users and should only be used as fallback
|
23 |
+
if (nodeData?.name?.startsWith("ACN_")
|
24 |
+
|| nodeData.python_module == 'custom_nodes.ComfyUI-Advanced-ControlNet') {
|
25 |
+
if (nodeData?.input?.hidden?.autosize) {
|
26 |
+
addResizeHook(nodeType.prototype, nodeData.input.hidden.autosize[1]?.padding)
|
27 |
+
} else if (!nodeData?.input?.optional?.autosize) {
|
28 |
+
addResizeHook(nodeType.prototype, 0, true)
|
29 |
+
}
|
30 |
+
}
|
31 |
+
},
|
32 |
+
async getCustomWidgets() {
|
33 |
+
return {
|
34 |
+
ACNAUTOSIZE(node, inputName, inputData) {
|
35 |
+
let w = {
|
36 |
+
name : inputName,
|
37 |
+
type : "ACN.AUTOSIZE",
|
38 |
+
value : "",
|
39 |
+
options : {"serialize": false},
|
40 |
+
computeSize : function(width) {
|
41 |
+
return [0, -4];
|
42 |
+
}
|
43 |
+
}
|
44 |
+
if (!node.widgets) {
|
45 |
+
node.widgets = []
|
46 |
+
}
|
47 |
+
node.widgets.push(w)
|
48 |
+
addResizeHook(node, inputData[1].padding);
|
49 |
+
return w;
|
50 |
+
}
|
51 |
+
}
|
52 |
+
}
|
53 |
+
});
|
ComfyUI-Advanced-ControlNet/web/js/documentation.js
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { app } from '../../../scripts/app.js'
|
2 |
+
|
3 |
+
function chainCallback(object, property, callback) {
|
4 |
+
if (object == undefined) {
|
5 |
+
//This should not happen.
|
6 |
+
console.error("Tried to add callback to non-existant object")
|
7 |
+
return;
|
8 |
+
}
|
9 |
+
if (property in object && object[property]) {
|
10 |
+
const callback_orig = object[property]
|
11 |
+
object[property] = function () {
|
12 |
+
const r = callback_orig.apply(this, arguments);
|
13 |
+
callback.apply(this, arguments);
|
14 |
+
return r
|
15 |
+
};
|
16 |
+
} else {
|
17 |
+
object[property] = callback;
|
18 |
+
}
|
19 |
+
}
|
20 |
+
var helpDOM;
|
21 |
+
function initHelpDOM() {
|
22 |
+
let parentDOM = document.createElement("div");
|
23 |
+
document.body.appendChild(parentDOM)
|
24 |
+
parentDOM.appendChild(helpDOM)
|
25 |
+
helpDOM.className = "litegraph";
|
26 |
+
let scrollbarStyle = document.createElement('style');
|
27 |
+
scrollbarStyle.innerHTML = `
|
28 |
+
<style id="scroll-properties">
|
29 |
+
* {
|
30 |
+
scrollbar-width: 6px;
|
31 |
+
scrollbar-color: #0003 #0000;
|
32 |
+
}
|
33 |
+
::-webkit-scrollbar {
|
34 |
+
background: transparent;
|
35 |
+
width: 6px;
|
36 |
+
}
|
37 |
+
::-webkit-scrollbar-thumb {
|
38 |
+
background: #0005;
|
39 |
+
border-radius: 20px
|
40 |
+
}
|
41 |
+
::-webkit-scrollbar-button {
|
42 |
+
display: none;
|
43 |
+
}
|
44 |
+
.VHS_loopedvideo::-webkit-media-controls-mute-button {
|
45 |
+
display:none;
|
46 |
+
}
|
47 |
+
.VHS_loopedvideo::-webkit-media-controls-fullscreen-button {
|
48 |
+
display:none;
|
49 |
+
}
|
50 |
+
</style>
|
51 |
+
`
|
52 |
+
parentDOM.appendChild(scrollbarStyle)
|
53 |
+
chainCallback(app.canvas, "onDrawForeground", function (ctx, visible_rect){
|
54 |
+
let n = helpDOM.node
|
55 |
+
if (!n || !n?.graph) {
|
56 |
+
parentDOM.style['left'] = '-5000px'
|
57 |
+
return
|
58 |
+
}
|
59 |
+
//draw : function(ctx, node, widgetWidth, widgetY, height) {
|
60 |
+
//update widget position, even if off screen
|
61 |
+
const transform = ctx.getTransform();
|
62 |
+
const scale = app.canvas.ds.scale;//gets the litegraph zoom
|
63 |
+
//calculate coordinates with account for browser zoom
|
64 |
+
const bcr = app.canvas.canvas.getBoundingClientRect()
|
65 |
+
const x = transform.e*scale/transform.a + bcr.x;
|
66 |
+
const y = transform.f*scale/transform.a + bcr.y;
|
67 |
+
//TODO: text reflows at low zoom. investigate alternatives
|
68 |
+
Object.assign(parentDOM.style, {
|
69 |
+
left: (x+(n.pos[0] + n.size[0]+15)*scale) + "px",
|
70 |
+
top: (y+(n.pos[1]-LiteGraph.NODE_TITLE_HEIGHT)*scale) + "px",
|
71 |
+
width: "400px",
|
72 |
+
minHeight: "100px",
|
73 |
+
maxHeight: "600px",
|
74 |
+
overflowY: 'scroll',
|
75 |
+
transformOrigin: '0 0',
|
76 |
+
transform: 'scale(' + scale + ',' + scale +')',
|
77 |
+
fontSize: '18px',
|
78 |
+
backgroundColor: LiteGraph.NODE_DEFAULT_BGCOLOR,
|
79 |
+
boxShadow: '0 0 10px black',
|
80 |
+
borderRadius: '4px',
|
81 |
+
padding: '3px',
|
82 |
+
zIndex: 3,
|
83 |
+
position: "absolute",
|
84 |
+
display: 'inline',
|
85 |
+
});
|
86 |
+
});
|
87 |
+
function setCollapse(el, doCollapse) {
|
88 |
+
if (doCollapse) {
|
89 |
+
el.children[0].children[0].innerHTML = '+'
|
90 |
+
Object.assign(el.children[1].style, {
|
91 |
+
color: '#CCC',
|
92 |
+
overflowX: 'hidden',
|
93 |
+
width: '0px',
|
94 |
+
minWidth: 'calc(100% - 20px)',
|
95 |
+
textOverflow: 'ellipsis',
|
96 |
+
whiteSpace: 'nowrap',
|
97 |
+
})
|
98 |
+
for (let child of el.children[1].children) {
|
99 |
+
if (child.style.display != 'none'){
|
100 |
+
child.origDisplay = child.style.display
|
101 |
+
}
|
102 |
+
child.style.display = 'none'
|
103 |
+
}
|
104 |
+
} else {
|
105 |
+
el.children[0].children[0].innerHTML = '-'
|
106 |
+
Object.assign(el.children[1].style, {
|
107 |
+
color: '',
|
108 |
+
overflowX: '',
|
109 |
+
width: '100%',
|
110 |
+
minWidth: '',
|
111 |
+
textOverflow: '',
|
112 |
+
whiteSpace: '',
|
113 |
+
})
|
114 |
+
for (let child of el.children[1].children) {
|
115 |
+
child.style.display = child.origDisplay
|
116 |
+
}
|
117 |
+
}
|
118 |
+
}
|
119 |
+
helpDOM.collapseOnClick = function() {
|
120 |
+
let doCollapse = this.children[0].innerHTML == '-'
|
121 |
+
setCollapse(this.parentElement, doCollapse)
|
122 |
+
}
|
123 |
+
helpDOM.selectHelp = function(name, value) {
|
124 |
+
//attempt to navigate to name in help
|
125 |
+
function collapseUnlessMatch(items,t) {
|
126 |
+
var match = items.querySelector('[vhs_title="' + t + '"]')
|
127 |
+
if (!match) {
|
128 |
+
for (let i of items.children) {
|
129 |
+
if (i.innerHTML.slice(0,t.length+5).includes(t)) {
|
130 |
+
match = i
|
131 |
+
break
|
132 |
+
}
|
133 |
+
}
|
134 |
+
}
|
135 |
+
if (!match) {
|
136 |
+
return null
|
137 |
+
}
|
138 |
+
//For longer documentation items with fewer collapsable elements,
|
139 |
+
//scroll to make sure the entirety of the selected item is visible
|
140 |
+
//This has the unfortunate side effect of trying to scroll the main
|
141 |
+
//window if the documentation windows is forcibly offscreen,
|
142 |
+
//but it's easy to simply scroll the main window back and seems to
|
143 |
+
//have no visual side effects
|
144 |
+
match.scrollIntoView(false)
|
145 |
+
window.scrollTo(0,0)
|
146 |
+
for (let i of items.querySelectorAll('.VHS_collapse')) {
|
147 |
+
if (i.contains(match)) {
|
148 |
+
setCollapse(i, false)
|
149 |
+
} else {
|
150 |
+
setCollapse(i, true)
|
151 |
+
}
|
152 |
+
}
|
153 |
+
return match
|
154 |
+
}
|
155 |
+
let target = collapseUnlessMatch(helpDOM, name)
|
156 |
+
if (target && value) {
|
157 |
+
collapseUnlessMatch(target, value)
|
158 |
+
}
|
159 |
+
}
|
160 |
+
|
161 |
+
helpDOM.addHelp = function(node, nodeType, description) {
|
162 |
+
if (!description) {
|
163 |
+
return
|
164 |
+
}
|
165 |
+
//Pad computed size for the clickable question mark
|
166 |
+
let originalComputeSize = node.computeSize
|
167 |
+
node.computeSize = function() {
|
168 |
+
let size = originalComputeSize.apply(this, arguments)
|
169 |
+
if (!this.title) {
|
170 |
+
return size
|
171 |
+
}
|
172 |
+
let title_width = this.title.length * 0.6 * LiteGraph.NODE_TEXT_SIZE
|
173 |
+
size[0] = Math.max(size[0], title_width + LiteGraph.NODE_TITLE_HEIGHT)
|
174 |
+
return size
|
175 |
+
}
|
176 |
+
|
177 |
+
node.description = description
|
178 |
+
chainCallback(node, "onDrawForeground", function (ctx) {
|
179 |
+
//draw question mark
|
180 |
+
ctx.save()
|
181 |
+
ctx.font = 'bold 20px Arial'
|
182 |
+
ctx.fillText("?", this.size[0]-17, -8)
|
183 |
+
ctx.restore()
|
184 |
+
})
|
185 |
+
chainCallback(node, "onMouseDown", function (e, pos, canvas) {
|
186 |
+
//On click would be preferred, but this'll be good enough
|
187 |
+
if (pos[1] < 0 && pos[0] + LiteGraph.NODE_TITLE_HEIGHT > this.size[0]) {
|
188 |
+
//corner question mark clicked
|
189 |
+
if (helpDOM.node == this) {
|
190 |
+
helpDOM.node = undefined
|
191 |
+
} else {
|
192 |
+
helpDOM.node = this;
|
193 |
+
helpDOM.innerHTML = this.description || "no help provided ".repeat(20)
|
194 |
+
for (let e of helpDOM.querySelectorAll('.VHS_collapse')) {
|
195 |
+
e.children[0].onclick = helpDOM.collapseOnClick
|
196 |
+
e.children[0].style.cursor = 'pointer'
|
197 |
+
}
|
198 |
+
for (let e of helpDOM.querySelectorAll('.VHS_precollapse')) {
|
199 |
+
setCollapse(e, true)
|
200 |
+
}
|
201 |
+
}
|
202 |
+
return true
|
203 |
+
}
|
204 |
+
})
|
205 |
+
let timeout = null
|
206 |
+
chainCallback(node, "onMouseMove", function (e, pos, canvas) {
|
207 |
+
if (timeout) {
|
208 |
+
clearTimeout(timeout)
|
209 |
+
timeout = null
|
210 |
+
}
|
211 |
+
if (helpDOM.node != this) {
|
212 |
+
return
|
213 |
+
}
|
214 |
+
timeout = setTimeout(() => {
|
215 |
+
let n = this
|
216 |
+
if (pos[0] > 0 && pos[0] < n.size[0]
|
217 |
+
&& pos[1] > 0 && pos[1] < n.size[1]) {
|
218 |
+
//TODO: provide help specific to element clicked
|
219 |
+
let inputRows = Math.max(n.inputs.length, n.outputs.length)
|
220 |
+
if (pos[1] < LiteGraph.NODE_SLOT_HEIGHT * inputRows) {
|
221 |
+
let row = Math.floor((pos[1] - 7) / LiteGraph.NODE_SLOT_HEIGHT)
|
222 |
+
if (pos[0] < n.size[0]/2) {
|
223 |
+
if (row < n.inputs.length) {
|
224 |
+
helpDOM.selectHelp(n.inputs[row].name)
|
225 |
+
}
|
226 |
+
} else {
|
227 |
+
if (row < n.outputs.length) {
|
228 |
+
helpDOM.selectHelp(n.outputs[row].name)
|
229 |
+
}
|
230 |
+
}
|
231 |
+
} else {
|
232 |
+
//probably widget, but widgets have variable height.
|
233 |
+
let basey = LiteGraph.NODE_SLOT_HEIGHT * inputRows + 6
|
234 |
+
for (let w of n.widgets) {
|
235 |
+
if (w.y) {
|
236 |
+
basey = w.y
|
237 |
+
}
|
238 |
+
let wheight = LiteGraph.NODE_WIDGET_HEIGHT+4
|
239 |
+
if (w.computeSize) {
|
240 |
+
wheight = w.computeSize(n.size[0])[1]
|
241 |
+
}
|
242 |
+
if (pos[1] < basey + wheight) {
|
243 |
+
helpDOM.selectHelp(w.name, w.value)
|
244 |
+
break
|
245 |
+
}
|
246 |
+
basey += wheight
|
247 |
+
}
|
248 |
+
}
|
249 |
+
}
|
250 |
+
}, 500)
|
251 |
+
})
|
252 |
+
chainCallback(node, "onMouseLeave", function (e, pos, canvas) {
|
253 |
+
if (timeout) {
|
254 |
+
clearTimeout(timeout)
|
255 |
+
timeout = null
|
256 |
+
}
|
257 |
+
});
|
258 |
+
}
|
259 |
+
}
|
260 |
+
|
261 |
+
|
262 |
+
|
263 |
+
app.registerExtension({
|
264 |
+
name: "AdvancedControlNet.documentation",
|
265 |
+
async init() {
|
266 |
+
if (app.VHSHelp) {
|
267 |
+
helpDOM = app.VHSHelp
|
268 |
+
} else {
|
269 |
+
helpDOM = document.createElement("div");
|
270 |
+
initHelpDOM()
|
271 |
+
app.VHSHelp = helpDOM
|
272 |
+
}
|
273 |
+
},
|
274 |
+
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
275 |
+
// NOTE: May need manual adjusting for the few non-namespaced nodes
|
276 |
+
if(nodeData?.name?.startsWith("ACN_") && nodeData.description) {
|
277 |
+
let description = nodeData.description
|
278 |
+
let el = document.createElement("div")
|
279 |
+
el.innerHTML = description
|
280 |
+
if (!el.children.length) {
|
281 |
+
//Is plaintext. Do minor convenience formatting
|
282 |
+
let chunks = description.split('\n')
|
283 |
+
nodeData.description = chunks[0]
|
284 |
+
description = chunks.join('<br>')
|
285 |
+
} else {
|
286 |
+
nodeData.description = el.querySelector('#VHS_shortdesc')?.innerHTML || el.children[1]?.firstChild?.innerHTML
|
287 |
+
}
|
288 |
+
chainCallback(nodeType.prototype, "onNodeCreated", function () {
|
289 |
+
helpDOM.addHelp(this, nodeType, description)
|
290 |
+
})
|
291 |
+
}
|
292 |
+
},
|
293 |
+
});
|