lehduong committed on
Commit
038856e
1 Parent(s): 07d760c

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/cond_and_image.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/onediffusion_appendix_faceid.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/onediffusion_appendix_faceid_3.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/onediffusion_appendix_multiview.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/onediffusion_appendix_text2multiview.pdf filter=lfs diff=lfs merge=lfs -text
41
+ assets/onediffusion_zeroshot.jpg filter=lfs diff=lfs merge=lfs -text
42
+ assets/promptguide_complex.jpg filter=lfs diff=lfs merge=lfs -text
43
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/text2multiview.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,168 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,407 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+
142
+ Section 2 -- Scope.
143
+
144
+ a. License grant.
145
+
146
+ 1. Subject to the terms and conditions of this Public License,
147
+ the Licensor hereby grants You a worldwide, royalty-free,
148
+ non-sublicensable, non-exclusive, irrevocable license to
149
+ exercise the Licensed Rights in the Licensed Material to:
150
+
151
+ a. reproduce and Share the Licensed Material, in whole or
152
+ in part, for NonCommercial purposes only; and
153
+
154
+ b. produce, reproduce, and Share Adapted Material for
155
+ NonCommercial purposes only.
156
+
157
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
158
+ Exceptions and Limitations apply to Your use, this Public
159
+ License does not apply, and You do not need to comply with
160
+ its terms and conditions.
161
+
162
+ 3. Term. The term of this Public License is specified in Section
163
+ 6(a).
164
+
165
+ 4. Media and formats; technical modifications allowed. The
166
+ Licensor authorizes You to exercise the Licensed Rights in
167
+ all media and formats whether now known or hereafter created,
168
+ and to make technical modifications necessary to do so. The
169
+ Licensor waives and/or agrees not to assert any right or
170
+ authority to forbid You from making technical modifications
171
+ necessary to exercise the Licensed Rights, including
172
+ technical modifications necessary to circumvent Effective
173
+ Technological Measures. For purposes of this Public License,
174
+ simply making modifications authorized by this Section 2(a)
175
+ (4) never produces Adapted Material.
176
+
177
+ 5. Downstream recipients.
178
+
179
+ a. Offer from the Licensor -- Licensed Material. Every
180
+ recipient of the Licensed Material automatically
181
+ receives an offer from the Licensor to exercise the
182
+ Licensed Rights under the terms and conditions of this
183
+ Public License.
184
+
185
+ b. No downstream restrictions. You may not offer or impose
186
+ any additional or different terms or conditions on, or
187
+ apply any Effective Technological Measures to, the
188
+ Licensed Material if doing so restricts exercise of the
189
+ Licensed Rights by any recipient of the Licensed
190
+ Material.
191
+
192
+ 6. No endorsement. Nothing in this Public License constitutes or
193
+ may be construed as permission to assert or imply that You
194
+ are, or that Your use of the Licensed Material is, connected
195
+ with, or sponsored, endorsed, or granted official status by,
196
+ the Licensor or others designated to receive attribution as
197
+ provided in Section 3(a)(1)(A)(i).
198
+
199
+ b. Other rights.
200
+
201
+ 1. Moral rights, such as the right of integrity, are not
202
+ licensed under this Public License, nor are publicity,
203
+ privacy, and/or other similar personality rights; however, to
204
+ the extent possible, the Licensor waives and/or agrees not to
205
+ assert any such rights held by the Licensor to the limited
206
+ extent necessary to allow You to exercise the Licensed
207
+ Rights, but not otherwise.
208
+
209
+ 2. Patent and trademark rights are not licensed under this
210
+ Public License.
211
+
212
+ 3. To the extent possible, the Licensor waives any right to
213
+ collect royalties from You for the exercise of the Licensed
214
+ Rights, whether directly or through a collecting society
215
+ under any voluntary or waivable statutory or compulsory
216
+ licensing scheme. In all other cases the Licensor expressly
217
+ reserves any right to collect such royalties, including when
218
+ the Licensed Material is used other than for NonCommercial
219
+ purposes.
220
+
221
+
222
+ Section 3 -- License Conditions.
223
+
224
+ Your exercise of the Licensed Rights is expressly made subject to the
225
+ following conditions.
226
+
227
+ a. Attribution.
228
+
229
+ 1. If You Share the Licensed Material (including in modified
230
+ form), You must:
231
+
232
+ a. retain the following if it is supplied by the Licensor
233
+ with the Licensed Material:
234
+
235
+ i. identification of the creator(s) of the Licensed
236
+ Material and any others designated to receive
237
+ attribution, in any reasonable manner requested by
238
+ the Licensor (including by pseudonym if
239
+ designated);
240
+
241
+ ii. a copyright notice;
242
+
243
+ iii. a notice that refers to this Public License;
244
+
245
+ iv. a notice that refers to the disclaimer of
246
+ warranties;
247
+
248
+ v. a URI or hyperlink to the Licensed Material to the
249
+ extent reasonably practicable;
250
+
251
+ b. indicate if You modified the Licensed Material and
252
+ retain an indication of any previous modifications; and
253
+
254
+ c. indicate the Licensed Material is licensed under this
255
+ Public License, and include the text of, or the URI or
256
+ hyperlink to, this Public License.
257
+
258
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
259
+ reasonable manner based on the medium, means, and context in
260
+ which You Share the Licensed Material. For example, it may be
261
+ reasonable to satisfy the conditions by providing a URI or
262
+ hyperlink to a resource that includes the required
263
+ information.
264
+
265
+ 3. If requested by the Licensor, You must remove any of the
266
+ information required by Section 3(a)(1)(A) to the extent
267
+ reasonably practicable.
268
+
269
+ 4. If You Share Adapted Material You produce, the Adapter's
270
+ License You apply must not prevent recipients of the Adapted
271
+ Material from complying with this Public License.
272
+
273
+
274
+ Section 4 -- Sui Generis Database Rights.
275
+
276
+ Where the Licensed Rights include Sui Generis Database Rights that
277
+ apply to Your use of the Licensed Material:
278
+
279
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280
+ to extract, reuse, reproduce, and Share all or a substantial
281
+ portion of the contents of the database for NonCommercial purposes
282
+ only;
283
+
284
+ b. if You include all or a substantial portion of the database
285
+ contents in a database in which You have Sui Generis Database
286
+ Rights, then the database in which You have Sui Generis Database
287
+ Rights (but not its individual contents) is Adapted Material; and
288
+
289
+ c. You must comply with the conditions in Section 3(a) if You Share
290
+ all or a substantial portion of the contents of the database.
291
+
292
+ For the avoidance of doubt, this Section 4 supplements and does not
293
+ replace Your obligations under this Public License where the Licensed
294
+ Rights include other Copyright and Similar Rights.
295
+
296
+
297
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298
+
299
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309
+
310
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319
+
320
+ c. The disclaimer of warranties and limitation of liability provided
321
+ above shall be interpreted in a manner that, to the extent
322
+ possible, most closely approximates an absolute disclaimer and
323
+ waiver of all liability.
324
+
325
+
326
+ Section 6 -- Term and Termination.
327
+
328
+ a. This Public License applies for the term of the Copyright and
329
+ Similar Rights licensed here. However, if You fail to comply with
330
+ this Public License, then Your rights under this Public License
331
+ terminate automatically.
332
+
333
+ b. Where Your right to use the Licensed Material has terminated under
334
+ Section 6(a), it reinstates:
335
+
336
+ 1. automatically as of the date the violation is cured, provided
337
+ it is cured within 30 days of Your discovery of the
338
+ violation; or
339
+
340
+ 2. upon express reinstatement by the Licensor.
341
+
342
+ For the avoidance of doubt, this Section 6(b) does not affect any
343
+ right the Licensor may have to seek remedies for Your violations
344
+ of this Public License.
345
+
346
+ c. For the avoidance of doubt, the Licensor may also offer the
347
+ Licensed Material under separate terms or conditions or stop
348
+ distributing the Licensed Material at any time; however, doing so
349
+ will not terminate this Public License.
350
+
351
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352
+ License.
353
+
354
+
355
+ Section 7 -- Other Terms and Conditions.
356
+
357
+ a. The Licensor shall not be bound by any additional or different
358
+ terms or conditions communicated by You unless expressly agreed.
359
+
360
+ b. Any arrangements, understandings, or agreements regarding the
361
+ Licensed Material not stated herein are separate from and
362
+ independent of the terms and conditions of this Public License.
363
+
364
+
365
+ Section 8 -- Interpretation.
366
+
367
+ a. For the avoidance of doubt, this Public License does not, and
368
+ shall not be interpreted to, reduce, limit, restrict, or impose
369
+ conditions on any use of the Licensed Material that could lawfully
370
+ be made without permission under this Public License.
371
+
372
+ b. To the extent possible, if any provision of this Public License is
373
+ deemed unenforceable, it shall be automatically reformed to the
374
+ minimum extent necessary to make it enforceable. If the provision
375
+ cannot be reformed, it shall be severed from this Public License
376
+ without affecting the enforceability of the remaining terms and
377
+ conditions.
378
+
379
+ c. No term or condition of this Public License will be waived and no
380
+ failure to comply consented to unless expressly agreed to by the
381
+ Licensor.
382
+
383
+ d. Nothing in this Public License constitutes or may be interpreted
384
+ as a limitation upon, or waiver of, any privileges and immunities
385
+ that apply to the Licensor or You, including from the legal
386
+ processes of any jurisdiction or authority.
387
+
388
+ =======================================================================
389
+
390
+ Creative Commons is not a party to its public
391
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
392
+ its public licenses to material it publishes and in those instances
393
+ will be considered the “Licensor.” The text of the Creative Commons
394
+ public licenses is dedicated to the public domain under the CC0 Public
395
+ Domain Dedication. Except for the limited purpose of indicating that
396
+ material is shared under a Creative Commons public license or as
397
+ otherwise permitted by the Creative Commons policies published at
398
+ creativecommons.org/policies, Creative Commons does not authorize the
399
+ use of the trademark "Creative Commons" or any other trademark or logo
400
+ of Creative Commons without its prior written consent including,
401
+ without limitation, in connection with any unauthorized modifications
402
+ to any of its public licenses or any other arrangements,
403
+ understandings, or agreements concerning use of licensed material. For
404
+ the avoidance of doubt, this paragraph does not form part of the
405
+ public licenses.
406
+
407
+ Creative Commons may be contacted at creativecommons.org.
PROMPT_GUIDE.md ADDED
@@ -0,0 +1,91 @@
1
+ # Prompt Guide
2
+
3
+ All examples are generated with a CFG scale of $4.2$ and $50$ steps, and are not cherry-picked unless otherwise stated. The negative prompt is set to:
4
+ ```
5
+ monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation
6
+ ```
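+
+ For reference, here is a minimal sketch (not part of the original guide) of how these defaults map onto the pipeline call shown in the repository README; the pipeline class and argument names are copied from the README's text-to-image example, and the prompt is only a placeholder:
+
+ ```
+ import torch
+ from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
+
+ # Abbreviated here; use the full negative prompt listed above.
+ NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, ..."
+
+ pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(
+     device=torch.device("cuda:0"), dtype=torch.bfloat16
+ )
+ output = pipeline(
+     prompt="[[text2image]] insert/your/detailed/caption/here",  # placeholder prompt
+     negative_prompt=NEGATIVE_PROMPT,
+     num_inference_steps=50,  # steps used throughout this guide
+     guidance_scale=4.2,      # CFG scale used throughout this guide
+     height=1024,
+     width=1024,
+ )
+ output.images[0].save("output.jpg")
+ ```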
7
+
8
+ ## 1. Text-to-Image
9
+
10
+ ### 1.1 Long and detailed prompts give (much) better results.
11
+
12
+ Since our training data comprised long and detailed prompts, the model is more likely to generate better images when given detailed prompts.
13
+
14
+
15
+ The model shows good text adherence with long and complex prompts, as shown in the images below. We use the first $20$ prompts from [simoryu's examples](https://cloneofsimo.github.io/compare_aura_sd3/). For the detailed prompts and results from other models, refer to that link.
16
+
17
+ <p align="center">
18
+ <img src="assets/promptguide_complex.jpg" alt="Text-to-Image results" width="800">
19
+ </p>
20
+
21
+
22
+ ### 1.2 Resolution
23
+
24
+ For text-to-image, the model generally works well with heights and widths in the range $[768, 1280]$ (height and width must be divisible by 16). For other tasks, it performs best at resolutions around $512$.
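+
+ As a small, purely illustrative helper (the `snap16` function below is not part of the repository), you can snap a requested resolution to the nearest multiple of 16 before calling the pipeline:
+
+ ```
+ def snap16(x: int) -> int:
+     """Round x to the nearest multiple of 16, as required for height/width."""
+     return max(16, round(x / 16) * 16)
+
+ height, width = snap16(900), snap16(1200)  # -> 896, 1200
+ ```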
25
+
26
+ ## 2. ID Customization & Subject-driven generation
27
+
28
+ - The expected length of source captions is $30$ to $75$ words. Empirically, we find that longer prompts help preserve the ID better, but they might hinder text adherence to the target caption.
29
+
30
+ - We find it is better to add some descriptions (e.g., from the source caption) to the target caption to preserve the identity, especially for complex subjects with delicate details. An example of the expected prompt format is sketched below the figure.
31
+
32
+ <p align="center">
33
+ <img src="assets/promptguide_idtask.jpg" alt="ablation id task" width="800">
34
+ </p>
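+
+ For a concrete illustration, the prompt below follows the `[[faceid]]` format from Section 5.1; the captions are invented placeholders (in practice each source caption should be roughly $30$ to $75$ words):
+
+ ```
+ source_captions = [
+     "A close-up photo of a young woman with short dark hair ...",  # caption of source image 1
+     "A portrait of the same woman wearing a blue jacket ...",      # caption of source image 2
+ ]
+ target_caption = "A photo of the woman hiking on a mountain trail at sunrise"
+
+ # [[faceid]] [[img0]] <target caption> [[img1]] <source caption 1> [[img2]] <source caption 2> ...
+ prompt = "[[faceid]] [[img0]] " + target_caption + " " + " ".join(
+     f"[[img{i}]] {caption}" for i, caption in enumerate(source_captions, start=1)
+ )
+ ```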
35
+
36
+ ## 3. Multiview generation
37
+
38
+ To mitigate multi-faced/Janus problems, we recommend not using captions that describe facial features or gaze (e.g., "looking at the camera").
39
+
40
+ ## 4. Image editing
41
+
42
+ We find it is generally better to set the guidance scale to a lower value, e.g., $[3, 3.5]$, to avoid over-saturated results.
43
+
44
+ ## 5. Special tokens and available colors
45
+
46
+ ### 5.1 Task Tokens
47
+
48
+ | Task | Token | Additional Tokens |
49
+ |:---------------------|:---------------------------|:------------------|
50
+ | Text to Image | `[[text2image]]` | |
51
+ | Deblurring | `[[deblurring]]` | |
52
+ | Inpainting | `[[image_inpainting]]` | |
53
+ | Canny-edge and Image | `[[canny2image]]` | |
54
+ | Depth and Image | `[[depth2image]]` | |
55
+ | Hed and Image | `[[hed2img]]` | |
56
+ | Pose and Image | `[[pose2image]]` | |
57
+ | Image editing with Instruction | `[[image_editing]]` | |
58
+ | Semantic map and Image| `[[semanticmap2image]]` | `<#00FFFF cyan mask: object/to/segment>` |
59
+ | Boundingbox and Image | `[[boundingbox2image]]` | `<#00FFFF cyan boundingbox: object/to/detect>` |
60
+ | ID customization | `[[faceid]]` | `[[img0]] target/caption [[img1]] caption/of/source/image_1 [[img2]] caption/of/source/image_2 [[img3]] caption/of/source/image_3` |
61
+ | Multiview | `[[multiview]]` | |
62
+ | Subject-Driven | `[[subject_driven]]` | `<item: name/of/subject> [[img0]] target/caption/goes/here [[img1]] insert/source/caption` |
63
+
64
+
65
+ Note that you can replace the cyan color above with any color from the table below, and you can provide multiple additional tokens to detect/segment multiple classes.
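+
+ For example, a semantic-map prompt that segments two classes using two different colors from the table below could be assembled like this (the class names and caption are placeholders):
+
+ ```
+ caption = "a red sports car parked next to a golden retriever on a city street"
+ prompt = (
+     "[[semanticmap2image]] "
+     "<#00FFFF cyan mask: car> "
+     "<#FF0000 red mask: dog> "
+     + caption
+ )
+ ```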
66
+
67
+ ### 5.2 Available colors
68
+
69
+
70
+ | Hex Code | Color Name |
71
+ |:---------|:-----------|
72
+ | #FF0000 | <span style="color: #FF0000">red</span> |
73
+ | #00FF00 | <span style="color: #00FF00">lime</span> |
74
+ | #0000FF | <span style="color: #0000FF">blue</span> |
75
+ | #FFFF00 | <span style="color: #FFFF00">yellow</span> |
76
+ | #FF00FF | <span style="color: #FF00FF">magenta</span> |
77
+ | #00FFFF | <span style="color: #00FFFF">cyan</span> |
78
+ | #FFA500 | <span style="color: #FFA500">orange</span> |
79
+ | #800080 | <span style="color: #800080">purple</span> |
80
+ | #A52A2A | <span style="color: #A52A2A">brown</span> |
81
+ | #008000 | <span style="color: #008000">green</span> |
82
+ | #FFC0CB | <span style="color: #FFC0CB">pink</span> |
83
+ | #008080 | <span style="color: #008080">teal</span> |
84
+ | #FF8C00 | <span style="color: #FF8C00">darkorange</span> |
85
+ | #8A2BE2 | <span style="color: #8A2BE2">blueviolet</span> |
86
+ | #006400 | <span style="color: #006400">darkgreen</span> |
87
+ | #FF4500 | <span style="color: #FF4500">orangered</span> |
88
+ | #000080 | <span style="color: #000080">navy</span> |
89
+ | #FFD700 | <span style="color: #FFD700">gold</span> |
90
+ | #40E0D0 | <span style="color: #40E0D0">turquoise</span> |
91
+ | #DA70D6 | <span style="color: #DA70D6">orchid</span> |
README.md CHANGED
@@ -1,14 +1,169 @@
1
- ---
2
- title: OneDiffusion Space
3
- emoji: 🖼
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.0.1
8
- app_file: app.py
9
- pinned: false
10
- license: cc-by-nc-4.0
11
- short_description: demo for onediffusion
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # One Diffusion to Generate Them All
2
+
3
+ <p align="left">
4
+ <a href="https://lehduong.github.io/OneDiffusion-homepage/">
5
+ <img alt="Build" src="https://img.shields.io/badge/Project%20Page-OneDiffusion-yellow">
6
+ </a>
7
+ <a href="https://arxiv.org/abs/2411.16318">
8
+ <img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-2411.16318-b31b1b.svg">
9
+ </a>
10
+ <a href="https://huggingface.co/spaces/lehduong/OneDiffusion">
11
+ <img alt="License" src="https://img.shields.io/badge/HF%20Demo-🤗-lightblue">
12
+ </a>
13
+ <a href="https://huggingface.co/lehduong/OneDiffusion">
14
+ <img alt="Build" src="https://img.shields.io/badge/HF%20Model-🤗-yellow">
15
+ </a>
16
+ </p>
17
+
18
+ <h4 align="left">
19
+ <p>
20
+ <a href=#news>News</a> |
21
+ <a href=#quick-start>Quick start</a> |
22
+ <a href=https://github.com/lehduong/OneDiffusion/blob/main/PROMPT_GUIDE.md>Prompt guide & Supported tasks </a> |
23
+ <a href=#qualitative-results>Qualitative results</a> |
24
+ <a href="#license">License</a> |
25
+ <a href="#citation">Citation</a>
26
+ </p>
27
+ </h4>
28
+
29
+
30
+ <p align="center">
31
+ <img src="assets/teaser.png" alt="Teaser Image" width="800">
32
+ </p>
33
+
34
+
35
+ This is the official repo of OneDiffusion, a versatile, large-scale diffusion model that seamlessly supports bidirectional image synthesis and understanding across diverse tasks.
36
+
37
+ ## News
38
+ - 📦 2024/12/10: Released model weights.
39
+ - 📝 2024/12/06: Added instruction-based image editing.
40
+ - ✨ 2024/12/02: Added subject-driven generation
41
+
42
+ ## Installation
43
+ ```
44
+ conda create -n onediffusion_env python=3.8 &&
45
+ conda activate onediffusion_env &&
46
+ pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118 &&
47
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git" &&
48
+ pip install -r requirements.txt
49
+ ```
50
+
51
+ ## Quick start
52
+
53
+ Check `inference.py` for more details. For text-to-image, you can use the code snippet below.
54
+
55
+ ```
56
+ import torch
57
+ from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
58
+
59
+ device = torch.device('cuda:0')
60
+
61
+ pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)
62
+
63
+ NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
64
+
65
+ output = pipeline(
66
+ prompt="[[text2image]] A bipedal black cat wearing a huge oversized witch hat, a wizards robe, casting a spell,in an enchanted forest. The scene is filled with fireflies and moss on surrounding rocks and trees",
67
+ negative_prompt=NEGATIVE_PROMPT,
68
+ num_inference_steps=50,
69
+ guidance_scale=4,
70
+ height=1024,
71
+ width=1024,
72
+ )
73
+ output.images[0].save('text2image_output.jpg')
74
+ ```
75
+
76
+ You can run the gradio demo with:
77
+ ```
78
+ python gradio_demo.py --captioner molmo # [molmo, llava, disable]
79
+ ```
80
+ The demo provides guidance and helps format the prompt properly for each task.
81
+ - By default, it loads the Molmo model for captioning source images, which significantly increases memory usage. You generally need a GPU with at least $40$ GB of memory to run the demo.
82
+ - Opting to use LLaVA can reduce this requirement to $\approx 27$ GB, though the resulting captions may be less accurate in some cases.
83
+ - You can also manually provide the caption for each input image and run with `disable` mode. In this mode, the returned caption is an empty string, but you should still press the `Generate Caption` button so that the code formats the input text properly. The memory requirement for this mode is $\approx 12$ GB.
84
+
85
+ Note that the memory requirements above can change if you use a higher resolution or more input images.
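+
+ If memory is tight, one option worth trying (untested here, and assuming `OneDiffusionPipeline` inherits from diffusers' `DiffusionPipeline`, which provides CPU offloading) is:
+
+ ```
+ import torch
+ from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
+
+ pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(dtype=torch.bfloat16)
+ # Keep submodules on CPU and move each to the GPU only while it runs;
+ # trades speed for lower peak GPU memory (assumption: inherited from diffusers).
+ pipeline.enable_model_cpu_offload()
+ ```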
86
+
87
+ ## Qualitative Results
88
+
89
+ ### 1. Text-to-Image
90
+ <p align="center">
91
+ <img src="assets/text2image.jpg" alt="Text-to-Image results" width="800">
92
+ </p>
93
+
94
+
95
+ ### 2. ID customization
96
+
97
+ <p align="center">
98
+ <img src="assets/onediffusion_appendix_faceid.jpg" alt="ID customization" width="800">
99
+ </p>
100
+
101
+ <p align="center">
102
+ <img src="assets/onediffusion_appendix_faceid_3.jpg" alt="ID customization non-human subject" width="800">
103
+ </p>
104
+
105
+ ### 3. Multiview generation
106
+
107
+ Single image to multiview:
108
+
109
+ <p align="center">
110
+ <img src="assets/onediffusion_appendix_multiview.jpg" alt="Image to multiview" width="800">
111
+ </p>
112
+
113
+ <p align="center">
114
+ <img src="assets/onediffusion_appendix_multiview_2.jpg" alt="image to multiview" width="800">
115
+ </p>
116
+
117
+ Text to multiview:
118
+
119
+ <p align="center">
120
+ <img src="assets/text2multiview.jpg" alt="Text to multiview image" width="800">
121
+ </p>
122
+
123
+ ### 4. Condition-to-Image and vice versa
124
+ <p align="center">
125
+ <img src="assets/cond_and_image.jpg" alt="Condition and Image" width="800">
126
+ </p>
127
+
128
+ ### 5. Subject-driven generation
129
+
130
+ We finetuned the model on the [Subjects200K](https://huggingface.co/datasets/Yuanshi/Subjects200K) dataset (along with all other tasks) for an additional 40K steps. The model is now capable of subject-driven generation.
131
+
132
+ <p align="center">
133
+ <img src="assets/subject_driven.jpg" alt="Subject driven generation" width="800">
134
+ </p>
135
+
136
+ ### 6. Text-guided image editing
137
+
138
+ We finetuned the model on the [OmniEdit](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M) dataset for an additional 30K steps.
139
+
140
+ <p align="center">
141
+ <img src="assets/onediffusion_editing.jpg" alt="Text-guide editing" width="800">
142
+ </p>
143
+
144
+ ### 7. Zero-shot Task combinations
145
+
146
+ We found that the model can handle multiple tasks in a zero-shot setting by combining condition images and task tokens without any fine-tuning, as shown in the examples below. However, its performance on these combined tasks might not be robust, and the model’s behavior may change if the order of task tokens or captions is altered. For example, when using both image inpainting and ID customization together, the target prompt and the caption of the masked image must be identical. If you plan to use such combinations, we recommend fine-tuning the model on these tasks to achieve better performance and simpler usage.
147
+
148
+
149
+ <p align="center">
150
+ <img src="assets/onediffusion_zeroshot.jpg" alt="Zero-shot task combinations" width="800">
151
+ </p>
152
+
153
+ ## License
154
+
155
+ The model is trained on several non-commercially licensed datasets (e.g., DL3DV, Unsplash); thus, the **model weights** are released under a CC BY-NC license, as described in [LICENSE](https://github.com/lehduong/onediffusion/blob/main/LICENSE).
156
+
157
+ ## Citation
158
+
159
+ ```bibtex
160
+ @misc{le2024diffusiongenerate,
161
+ title={One Diffusion to Generate Them All},
162
+ author={Duong H. Le and Tuan Pham and Sangho Lee and Christopher Clark and Aniruddha Kembhavi and Stephan Mandt and Ranjay Krishna and Jiasen Lu},
163
+ year={2024},
164
+ eprint={2411.16318},
165
+ archivePrefix={arXiv},
166
+ primaryClass={cs.CV},
167
+ url={https://arxiv.org/abs/2411.16318},
168
+ }
169
+ ```
assets/cond_and_image.jpg ADDED

Git LFS Details

  • SHA256: 6fcf6f6327d4a72a05dea636e7cacae6c2bdee4b61d7f583424c15f91e4bb903
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
assets/examples/id_customization/chenhao/image_0.png ADDED
assets/examples/id_customization/chenhao/image_1.png ADDED
assets/examples/id_customization/chenhao/image_2.png ADDED
assets/onediffusion_appendix_faceid.jpg ADDED

Git LFS Details

  • SHA256: 8a04d050bf0d2b0f6ec13934f09387bf7d0dac82c32b0c1d808d6013e25cf6ec
  • Pointer size: 132 Bytes
  • Size of remote file: 1.75 MB
assets/onediffusion_appendix_faceid_3.jpg ADDED

Git LFS Details

  • SHA256: dd8a9a2bb587e4093cb9b9ab36d06675c0fbc91bbd8f1ff3c45e7cc0fb1d211e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.68 MB
assets/onediffusion_appendix_multiview.jpg ADDED

Git LFS Details

  • SHA256: 70026d6376c2d52ad268be1f5d2b7dc80fee716fa76b6c2aa611f105cbb76614
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
assets/onediffusion_appendix_multiview_2.jpg ADDED
assets/onediffusion_appendix_text2multiview.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60a945f7c1c92e823dbd3c7876c843d2668d16cb1ae883f2ba4080d324056225
3
+ size 8278287
assets/onediffusion_editing.jpg ADDED
assets/onediffusion_zeroshot.jpg ADDED

Git LFS Details

  • SHA256: a243196d8e6ca959357af24a71d183e15ebb90910ef0deca56af70ebe59a83f3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.99 MB
assets/promptguide_complex.jpg ADDED

Git LFS Details

  • SHA256: e8e338d97b8e4f90b52e2fe5680f00a4538cdbe7c0423e07042bc1780aa94a51
  • Pointer size: 132 Bytes
  • Size of remote file: 2.05 MB
assets/promptguide_idtask.jpg ADDED
assets/subject_driven.jpg ADDED
assets/teaser.png ADDED

Git LFS Details

  • SHA256: 2c9ba3c39cdd6882d6c8172e45f71bef282d9439700c3adb6fa951ee394afedc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.78 MB
assets/text2image.jpg ADDED
assets/text2multiview.jpg ADDED

Git LFS Details

  • SHA256: 98bc67be460dd5cb1207be5ac1a7ae842fea0c65546a3c66b92b132aa4652cc4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.37 MB
docker/Dockerfile ADDED
@@ -0,0 +1,119 @@
1
+ # Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile
2
+ # ARG COMPAT=0
3
+ ARG PERSONAL=0
4
+ # FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
5
+ FROM nvcr.io/nvidia/pytorch:22.12-py3 as base
6
+
7
+ ENV HOST docker
8
+ ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
9
+ # https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes
10
+ ENV TZ America/Los_Angeles
11
+ RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
12
+
13
+ # git for installing dependencies
14
+ # tzdata to set time zone
15
+ # wget and unzip to download data
16
+ # [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.
17
+ # [2021-12-07] TD: openmpi-bin for MPI (multi-node training)
18
+ RUN apt-get update && apt-get install -y --no-install-recommends \
19
+ build-essential \
20
+ cmake \
21
+ curl \
22
+ ca-certificates \
23
+ sudo \
24
+ less \
25
+ htop \
26
+ git \
27
+ tzdata \
28
+ wget \
29
+ tmux \
30
+ zip \
31
+ unzip \
32
+ zsh stow subversion fasd \
33
+ && rm -rf /var/lib/apt/lists/*
34
+ # openmpi-bin \
35
+
36
+ # Allow running mpirun as root
37
+ # ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1
38
+
39
+ # # Create a non-root user and switch to it
40
+ # RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
41
+ # && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
42
+ # USER user
43
+
44
+ # All users can use /home/user as their home directory
45
+ ENV HOME=/home/user
46
+ RUN mkdir -p /home/user && chmod 777 /home/user
47
+ WORKDIR /home/user
48
+
49
+ # Set up personal environment
50
+ # FROM base-${COMPAT} as env-0
51
+ FROM base as env-0
52
+ FROM env-0 as env-1
53
+ # Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image
54
+ # https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile
55
+ ONBUILD COPY dotfiles ./dotfiles
56
+ ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami)
57
+ # nvcr pytorch image sets SHELL=/bin/bash
58
+ ONBUILD ENV SHELL=/bin/zsh
59
+
60
+ FROM env-${PERSONAL} as packages
61
+
62
+ # Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
63
+ ENV PIP_NO_CACHE_DIR=1
64
+
65
+ # # apex and pytorch-fast-transformers take a while to compile so we install them first
66
+ # TD [2022-04-28] apex is already installed. In case we need a newer commit:
67
+ # RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex
68
+
69
+ # xgboost conflicts with deepspeed
70
+ RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.7
71
+
72
+ # General packages that we don't care about the version
73
+ # zstandard to extract the_pile dataset
74
+ # psutil to get the number of cpu physical cores
75
+ # twine to upload package to PyPI
76
+ RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine gdown \
77
+ && python -m spacy download en_core_web_sm
78
+ # hydra
79
+ RUN pip install hydra-core==1.3.1 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
80
+ # Core packages
81
+ RUN pip install transformers==4.45.2 datasets==3.0.1 pytorch-lightning==2.2.1 triton==2.3.1 wandb==0.16.3 controlnet_aux==0.0.9 timm==0.6.7 torchmetrics==1.3.2
82
+ # torchmetrics 0.11.0 broke hydra's instantiate
83
+
84
+ # For MLPerf
85
+ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
86
+
87
+ RUN pip install accelerate==0.34.2
88
+
89
+ RUN pip install diffusers==0.30.3
90
+
91
+ RUN pip install deepspeed==0.15.2
92
+
93
+ RUN pip install sentencepiece==0.1.99
94
+
95
+ RUN pip install pillow==10.2.0
96
+
97
+ RUN pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
98
+
99
+ # Install FlashAttention
100
+ RUN pip install flash-attn==2.6.3
101
+
102
+ # Install CUDA extensions for fused dense
103
+ RUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib
104
+
105
+ RUN pip install jaxtyping mediapipe gradio
106
+
107
+ RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git"
108
+
109
+ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
110
+
111
+ RUN pip install opencv-python==4.5.5.64
112
+
113
+ RUN pip install opencv-python-headless==4.5.5.64
114
+
115
+ RUN pip install huggingface_hub==0.24
116
+
117
+ RUN pip install numpy==1.24.4
118
+
119
+
gradio_demo.py ADDED
@@ -0,0 +1,715 @@
1
+ import gradio as gr
2
+ import torch
3
+ import base64
4
+ import io
5
+ from PIL import Image
6
+ from transformers import (
7
+ LlavaNextProcessor, LlavaNextForConditionalGeneration,
8
+ T5EncoderModel, T5Tokenizer
9
+ )
10
+ from transformers import (
11
+ AutoProcessor, AutoModelForCausalLM, GenerationConfig,
12
+ T5EncoderModel, T5Tokenizer
13
+ )
14
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FluxPipeline
15
+ from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
16
+ from onediffusion.models.denoiser.nextdit import NextDiT
17
+ from onediffusion.dataset.utils import get_closest_ratio, ASPECT_RATIO_512
18
+ from typing import List, Optional
19
+ import matplotlib
20
+ import numpy as np
21
+ import cv2
22
+ import argparse
23
+
24
+ # Task-specific tokens
25
+ TASK2SPECIAL_TOKENS = {
26
+ "text2image": "[[text2image]]",
27
+ "deblurring": "[[deblurring]]",
28
+ "inpainting": "[[image_inpainting]]",
29
+ "canny": "[[canny2image]]",
30
+ "depth2image": "[[depth2image]]",
31
+ "hed2image": "[[hed2img]]",
32
+ "pose2image": "[[pose2image]]",
33
+ "semanticmap2image": "[[semanticmap2image]]",
34
+ "boundingbox2image": "[[boundingbox2image]]",
35
+ "image_editing": "[[image_editing]]",
36
+ "faceid": "[[faceid]]",
37
+ "multiview": "[[multiview]]",
38
+ "subject_driven": "[[subject_driven]]"
39
+ }
40
+ NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
41
+
42
+
43
+ class LlavaCaptionProcessor:
44
+ def __init__(self):
45
+ model_name = "llava-hf/llama3-llava-next-8b-hf"
46
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
48
+ self.processor = LlavaNextProcessor.from_pretrained(model_name)
49
+ self.model = LlavaNextForConditionalGeneration.from_pretrained(
50
+ model_name, torch_dtype=dtype, low_cpu_mem_usage=True
51
+ ).to(device)
52
+ self.SPECIAL_TOKENS = "assistant\n\n\n"
53
+
54
+ def generate_response(self, image: Image.Image, msg: str) -> str:
55
+ conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": msg}]}]
56
+ with torch.no_grad():
57
+ prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
58
+ inputs = self.processor(prompt, image, return_tensors="pt").to(self.model.device)
59
+ output = self.model.generate(**inputs, max_new_tokens=250)
60
+ response = self.processor.decode(output[0], skip_special_tokens=True)
61
+ return response.split(msg)[-1].strip()[len(self.SPECIAL_TOKENS):]
62
+
63
+ def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
64
+ if msg is None:
65
+ msg = f"Describe the contents of the photo in 150 words or fewer."
66
+ try:
67
+ return [self.generate_response(img, msg) for img in images]
68
+ except Exception as e:
69
+ print(f"Error in process: {str(e)}")
70
+ raise
71
+
72
+
73
+ class MolmoCaptionProcessor:
74
+ def __init__(self):
75
+ pretrained_model_name = 'allenai/Molmo-7B-O-0924'
76
+ self.processor = AutoProcessor.from_pretrained(
77
+ pretrained_model_name,
78
+ trust_remote_code=True,
79
+ torch_dtype='auto',
80
+ device_map='auto'
81
+ )
82
+ self.model = AutoModelForCausalLM.from_pretrained(
83
+ pretrained_model_name,
84
+ trust_remote_code=True,
85
+ torch_dtype='auto',
86
+ device_map='auto'
87
+ )
88
+
89
+ def generate_response(self, image: Image.Image, msg: str) -> str:
90
+ inputs = self.processor.process(
91
+ images=[image],
92
+ text=msg
93
+ )
94
+ # Move inputs to the correct device and make a batch of size 1
95
+ inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
96
+
97
+ # Generate output
98
+ output = self.model.generate_from_batch(
99
+ inputs,
100
+ GenerationConfig(max_new_tokens=250, stop_strings="<|endoftext|>"),
101
+ tokenizer=self.processor.tokenizer
102
+ )
103
+
104
+ # Only get generated tokens and decode them to text
105
+ generated_tokens = output[0, inputs['input_ids'].size(1):]
106
+ return self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
107
+
108
+
109
+ def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
110
+ if msg is None:
111
+ msg = f"Describe the contents of the photo in 150 words or fewer."
112
+ try:
113
+ return [self.generate_response(img, msg) for img in images]
114
+ except Exception as e:
115
+ print(f"Error in process: {str(e)}")
116
+ raise
117
+
118
+
119
+ class PlaceHolderCaptionProcessor:
120
+ def __init__(self):
121
+ pass
122
+
123
+ def generate_response(self, image: Image.Image, msg: str) -> str:
124
+ return ""
125
+
126
+ def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
127
+ return [""] * len(images)
128
+
129
+
130
+ def initialize_models(captioner_name):
131
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
132
+ pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)
133
+ if captioner_name == 'molmo':
134
+ captioner = MolmoCaptionProcessor()
135
+ elif captioner_name == 'llava':
136
+ captioner = LlavaCaptionProcessor()
137
+ else:
138
+ captioner = PlaceHolderCaptionProcessor()
139
+ return pipeline, captioner
140
+
141
+ def colorize_depth_maps(
142
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
143
+ ):
144
+ """
145
+ Colorize depth maps with reversed colors.
146
+ """
147
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
148
+
149
+ if isinstance(depth_map, torch.Tensor):
150
+ depth = depth_map.detach().squeeze().numpy()
151
+ elif isinstance(depth_map, np.ndarray):
152
+ depth = depth_map.copy().squeeze()
153
+ # reshape to [ (B,) H, W ]
154
+ if depth.ndim < 3:
155
+ depth = depth[np.newaxis, :, :]
156
+
157
+ # Normalize depth values to [0, 1]
158
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
159
+ # Invert the depth values to reverse the colors
160
+ depth = 1 - depth
161
+
162
+ # Use the colormap
163
+ cm = matplotlib.colormaps[cmap]
164
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # values from 0 to 1
165
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
166
+
167
+ if valid_mask is not None:
168
+ if isinstance(depth_map, torch.Tensor):
169
+ valid_mask = valid_mask.detach().numpy()
170
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
171
+ if valid_mask.ndim < 3:
172
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
173
+ else:
174
+ valid_mask = valid_mask[:, np.newaxis, :, :]
175
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
176
+ img_colored_np[~valid_mask] = 0
177
+
178
+ if isinstance(depth_map, torch.Tensor):
179
+ img_colored = torch.from_numpy(img_colored_np).float()
180
+ elif isinstance(depth_map, np.ndarray):
181
+ img_colored = img_colored_np
182
+
183
+ return img_colored
184
+
185
+
186
+ def format_prompt(task_type: str, captions: List[str]) -> str:
187
+ if not captions:
188
+ return ""
189
+ if task_type == "faceid":
190
+ img_prompts = [f"[[img{i}]] {caption}" for i, caption in enumerate(captions, start=1)]
191
+ return f"[[faceid]] [[img0]] insert/your/caption/here {' '.join(img_prompts)}"
192
+ elif task_type == "image_editing":
193
+ return f"[[image_editing]] insert/your/instruction/here"
194
+ elif task_type == "semanticmap2image":
195
+ return f"[[semanticmap2image]] <#00ffff Cyan mask: insert/concept/to/segment/here> {captions[0]}"
196
+ elif task_type == "boundingbox2image":
197
+ return f"[[boundingbox2image]] <#00ffff Cyan boundingbox: insert/concept/to/segment/here> {captions[0]}"
198
+ elif task_type == "multiview":
199
+ img_prompts = captions[0]
200
+ return f"[[multiview]] {img_prompts}"
201
+ elif task_type == "subject_driven":
202
+ return f"[[subject_driven]] <item: insert/item/here> [[img0]] insert/your/target/caption/here [[img1]] {captions[0]}"
203
+ else:
204
+ return f"{TASK2SPECIAL_TOKENS[task_type]} {captions[0]}"
205
+
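For reference, a small sketch of the strings `format_prompt` assembles; the captions are made-up placeholders, and the exact token for the last call depends on the `TASK2SPECIAL_TOKENS` mapping defined earlier in this file:

```python
print(format_prompt("semanticmap2image", ["a city street at night"]))
# -> "[[semanticmap2image]] <#00ffff Cyan mask: insert/concept/to/segment/here> a city street at night"

print(format_prompt("subject_driven", ["a plush dinosaur on a desk"]))
# -> "[[subject_driven]] <item: insert/item/here> [[img0]] insert/your/target/caption/here [[img1]] a plush dinosaur on a desk"

print(format_prompt("depth2image", ["a red car on a wet street"]))
# -> f"{TASK2SPECIAL_TOKENS['depth2image']} a red car on a wet street"  (token presumably "[[depth2image]]")
```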
206
+ def update_prompt(images: List[Image.Image], task_type: str, custom_msg: str = None):
207
+ if not images:
208
+ return format_prompt(task_type, []), "Please upload at least one image!"
209
+ try:
210
+ captions = captioner.process(images, custom_msg)
211
+ if not captions:
212
+ return "", "No valid images found!"
213
+ prompt = format_prompt(task_type, captions)
214
+ return prompt, f"Generated {len(captions)} captions successfully!"
215
+ except Exception as e:
216
+ return "", f"Error generating captions: {str(e)}"
217
+
218
+ def generate_image(images: List[Image.Image], prompt: str, negative_prompt: str, num_inference_steps: int, guidance_scale: float,
219
+ denoise_mask: List[str], task_type: str, azimuth: str, elevation: str, distance: str, focal_length: float,
220
+ height: int = 1024, width: int = 1024, scale_factor: float = 1.0, scale_watershed: float = 1.0,
221
+ noise_scale: float = None, progress=gr.Progress()):
222
+ try:
223
+ img2img_kwargs = {
224
+ 'prompt': prompt,
225
+ 'negative_prompt': negative_prompt,
226
+ 'num_inference_steps': num_inference_steps,
227
+ 'guidance_scale': guidance_scale,
228
+ 'height': height,
229
+ 'width': width,
230
+ 'forward_kwargs': {
231
+ 'scale_factor': scale_factor,
232
+ 'scale_watershed': scale_watershed
233
+ },
234
+ 'noise_scale': noise_scale # Added noise_scale here
235
+ }
236
+
237
+ if task_type == 'multiview':
238
+ # Parse azimuth, elevation, and distance into lists, allowing 'None' values
239
+ azimuths = [float(a.strip()) if a.strip().lower() != 'none' else None for a in azimuth.split(',')] if azimuth else []
240
+ elevations = [float(e.strip()) if e.strip().lower() != 'none' else None for e in elevation.split(',')] if elevation else []
241
+ distances = [float(d.strip()) if d.strip().lower() != 'none' else None for d in distance.split(',')] if distance else []
242
+
243
+ num_views = max(len(images), len(azimuths), len(elevations), len(distances))
244
+ if num_views == 0:
245
+ return None, "At least one image or camera parameter must be provided."
246
+
247
+ total_components = []
248
+ for i in range(num_views):
249
+ total_components.append(f"image_{i}")
250
+ total_components.append(f"camera_pose_{i}")
251
+
252
+ denoise_mask_int = [1 if comp in denoise_mask else 0 for comp in total_components]
253
+
254
+ if len(denoise_mask_int) != len(total_components):
255
+ return None, f"Denoise mask length mismatch: expected {len(total_components)} components."
256
+
257
+ # Pad the input lists to num_views length
258
+ images_padded = list(images)  # do not pad the image list with None entries
259
+ azimuths_padded = azimuths + [None] * (num_views - len(azimuths))
260
+ elevations_padded = elevations + [None] * (num_views - len(elevations))
261
+ distances_padded = distances + [None] * (num_views - len(distances))
262
+
263
+ # Prepare values
264
+ img2img_kwargs.update({
265
+ 'image': images_padded,
266
+ 'multiview_azimuths': azimuths_padded,
267
+ 'multiview_elevations': elevations_padded,
268
+ 'multiview_distances': distances_padded,
269
+ 'multiview_focal_length': focal_length, # Pass focal_length here
270
+ 'is_multiview': True,
271
+ 'denoise_mask': denoise_mask_int,
272
+ # 'predict_camera_poses': True,
273
+ })
274
+ else:
275
+ total_components = ["image_0"] + [f"image_{i+1}" for i in range(len(images))]
276
+ denoise_mask_int = [1 if comp in denoise_mask else 0 for comp in total_components]
277
+ if len(denoise_mask_int) != len(total_components):
278
+ return None, f"Denoise mask length mismatch: expected {len(total_components)} components."
279
+
280
+ img2img_kwargs.update({
281
+ 'image': images,
282
+ 'denoise_mask': denoise_mask_int
283
+ })
284
+
285
+ progress(0, desc="Generating image...")
286
+ if task_type == 'text2image':
287
+ output = pipeline(
288
+ prompt=prompt,
289
+ negative_prompt=negative_prompt,
290
+ num_inference_steps=num_inference_steps,
291
+ guidance_scale=guidance_scale,
292
+ height=height,
293
+ width=width,
294
+ scale_factor=scale_factor,
295
+ scale_watershed=scale_watershed,
296
+ noise_scale=noise_scale # Added noise_scale here
297
+ )
298
+ else:
299
+ output = pipeline.img2img(**img2img_kwargs)
300
+ progress(1, desc="Done!")
301
+
302
+ # Process the output images if task is 'depth2image' and predicting depth
303
+ if task_type == 'depth2image' and denoise_mask_int[-1] == 1:
304
+ processed_images = []
305
+ for img in output.images:
306
+ depth_map = np.array(img.convert('L')) # Convert to grayscale numpy array
307
+ min_depth = depth_map.min()
308
+ max_depth = depth_map.max()
309
+ colorized = colorize_depth_maps(depth_map, min_depth, max_depth)[0]
310
+ colorized = np.transpose(colorized, (1, 2, 0))
311
+ colorized = (colorized * 255).astype(np.uint8)
312
+ img_colorized = Image.fromarray(colorized)
313
+ processed_images.append(img_colorized)
314
+ output_images = processed_images + output.images
315
+ elif task_type in ['boundingbox2image', 'semanticmap2image'] and denoise_mask_int == [0,1] and images:
316
+ # Interpolate between input and output images
317
+ processed_images = []
318
+ for input_img, output_img in zip(images, output.images):
319
+ input_img_resized = input_img.resize(output_img.size)
320
+ blended_img = Image.blend(input_img_resized, output_img, alpha=0.5)
321
+ processed_images.append(blended_img)
322
+ output_images = processed_images + output.images
323
+ else:
324
+ output_images = output.images
325
+
326
+ return output_images, "Generation completed successfully!"
327
+
328
+ except Exception as e:
329
+ return None, f"Error during generation: {str(e)}"
330
+
331
+ def update_denoise_checkboxes(images_state: List[Image.Image], task_type: str, azimuth: str, elevation: str, distance: str):
332
+ if task_type == 'multiview':
333
+ azimuths = [a.strip() for a in azimuth.split(',')] if azimuth else []
334
+ elevations = [e.strip() for e in elevation.split(',')] if elevation else []
335
+ distances = [d.strip() for d in distance.split(',')] if distance else []
336
+ images_len = len(images_state)
337
+
338
+ num_views = max(images_len, len(azimuths), len(elevations), len(distances))
339
+ if num_views == 0:
340
+ return gr.update(choices=[], value=[]), "Please provide at least one image or camera parameter."
341
+
342
+ # Pad lists to the same length
343
+ azimuths += ['None'] * (num_views - len(azimuths))
344
+ elevations += ['None'] * (num_views - len(elevations))
345
+ distances += ['None'] * (num_views - len(distances))
346
+ # Do not add None to images_state
347
+
348
+ labels = []
349
+ values = []
350
+ for i in range(num_views):
351
+ labels.append(f"image_{i}")
352
+ labels.append(f"camera_pose_{i}")
353
+
354
+ # Default behavior: condition on provided inputs, generate missing ones
355
+ if i >= images_len:
356
+ values.append(f"image_{i}")
357
+ if azimuths[i].lower() == 'none' or elevations[i].lower() == 'none' or distances[i].lower() == 'none':
358
+ values.append(f"camera_pose_{i}")
359
+
360
+ return gr.update(choices=labels, value=values)
361
+ else:
362
+ labels = ["image_0"] + [f"image_{i+1}" for i in range(len(images_state))]
363
+ values = ["image_0"]
364
+ return gr.update(choices=labels, value=values)
365
+
366
+ def apply_mask(images_state):
367
+ if len(images_state) < 2:
368
+ return None, "Please upload at least two images: first as the base image, second as the mask."
369
+ base_img = images_state[0]
370
+ mask_img = images_state[1]
371
+
372
+ # Convert images to arrays
373
+ base_arr = np.array(base_img)
374
+ mask_arr = np.array(mask_img)
375
+
376
+ # Convert mask to grayscale
377
+ if mask_arr.ndim == 3:
378
+ gray_mask = cv2.cvtColor(mask_arr, cv2.COLOR_RGB2GRAY)
379
+ else:
380
+ gray_mask = mask_arr
381
+
382
+ # Create a binary mask where non-black pixels are True
383
+ binary_mask = gray_mask > 10
384
+
385
+ # Define the gray color
386
+ gray_color = np.array([128, 128, 128], dtype=np.uint8)
387
+
388
+ # Apply gray color where mask is True
389
+ masked_arr = base_arr.copy()
390
+ masked_arr[binary_mask] = gray_color
391
+
392
+ masked_img = Image.fromarray(masked_arr)
393
+ return [masked_img], "Mask applied successfully!"
394
+
395
+ def process_images_for_task_type(images_state: List[Image.Image], task_type: str):
396
+ # No changes needed here since we are processing the output images
397
+ return images_state, images_state
398
+
399
+ with gr.Blocks(title="OneDiffusion Demo") as demo:
400
+ gr.Markdown("""
401
+ # OneDiffusion Demo
402
+
403
+ **Welcome to the OneDiffusion Demo!**
404
+
405
+ This application allows you to generate images based on your input prompts for various tasks. Here's how to use it:
406
+
407
+ 1. **Select Task Type**: Choose the type of task you want to perform from the "Task Type" dropdown menu.
408
+
409
+ 2. **Upload Images**: Drag and drop images directly onto the upload area, or click to select files from your device.
410
+
411
+ 3. **Generate Captions**: **If you upload any images**, click the "Generate Captions" button to generate descriptive captions for them (depending on the task). You can enter a custom message in the "Custom message for captioner" textbox, e.g., "caption in 30 words" instead of 50 words.
412
+
413
+ 4. **Configure Generation Settings**: Expand the "Advanced Configuration" section to adjust parameters like the number of inference steps, guidance scale, image size, and more.
414
+
415
+ 5. **Generate Images**: After setting your preferences, click the "Generate Image" button. The generated images will appear in the "Generated Images" gallery.
416
+
417
+ 6. **Manage Images**: Use the "Delete Selected Images" or "Delete All Images" buttons to remove unwanted images from the gallery.
418
+
419
+ **Notes**:
420
+ - Check out the [Prompt Guide](https://github.com/lehduong/OneDiffusion/blob/main/PROMPT_GUIDE.md).
421
+
422
+ - For text-to-image:
423
+ + simply enter your prompt in this format "[[text2image]] your/prompt/here" and press the "Generate Image" button.
424
+
425
+ - For boundingbox2image/semanticmap2image/inpainting etc. tasks:
426
+ + To perform condition-to-image tasks such as semantic map to image, follow the steps above.
427
+ + For image-to-condition tasks, e.g., image to depth, change the denoise_mask checkboxes before generating images: UNCHECK the image_0 box and CHECK the image_1 box.
428
+
429
+ - For FaceID tasks:
430
+ + Use 3 or 4 images if a single input image does not give satisfactory results.
431
+ + All images will be resized and center-cropped to the input height and width. Choose the height and width so that faces in the input images won't be cropped out.
432
+ + The model works best with close-up portraits (both input and output).
433
+ + If the model does not follow your text prompt, try using a shorter caption for the source image(s).
434
+ + If you have non-human subjects and do not get satisfactory results, try copying the part of the source images' captions that describes the subject's properties, e.g., a monster with red eyes, sharp teeth, etc.
435
+
436
+ - For Multiview generation:
437
+ + The input camera elevation/azimuth ALWAYS starts at $0$. If you want to generate images at azimuths 30, 60, 90 and elevations 10, 20, 30 (with respect to the input image), the correct input azimuths are `0, 30, 60, 90` and the input elevations are `0, 10, 20, 30`. The camera distances will be `1.5, 1.5, 1.5, 1.5`.
438
+ + Only square images are supported (ideally at 512x512 resolution).
439
+ + Ensure the numbers of elevations, azimuths, and distances are equal.
440
+ + The model generally works well for 2-5 views (including both input and generated images). Since the model is trained with 3 views at 512x512 resolution, you might try a scale_factor in [1.1, 1.5] and a scale_watershed in [100, 400] for better extrapolation.
441
+ + For better results:
442
+ 1) try increasing num_inference_steps to 75-100.
443
+ 2) avoid aggressive changes in target camera poses; for example, to generate a novel view at an azimuth of 180, simultaneously generate 4 views at azimuths of 45, 90, 135, 180.
444
+
445
+ Enjoy creating images with OneDiffusion!
446
+ """)
447
+
448
+ with gr.Row():
449
+ with gr.Column():
450
+ images_state = gr.State([])
451
+ selected_indices_state = gr.State([])
452
+
453
+ with gr.Row():
454
+ gallery = gr.Gallery(
455
+ label="Input Images",
456
+ show_label=True,
457
+ columns=2,
458
+ rows=2,
459
+ height="auto",
460
+ object_fit="contain"
461
+ )
462
+
463
+ # File upload component: accepts file paths from both drag-and-drop and click-to-upload.
464
+ file_output = gr.File(
465
+ file_count="multiple",
466
+ file_types=["image"],
467
+ label="Drag and drop images here or click to upload",
468
+ height=100,
469
+ scale=2,
470
+ type="filepath" # Add this parameter
471
+ )
472
+
473
+ with gr.Row():
474
+ delete_button = gr.Button("Delete Selected Images")
475
+ delete_all_button = gr.Button("Delete All Images")
476
+
477
+ task_type = gr.Dropdown(
478
+ choices=list(TASK2SPECIAL_TOKENS.keys()),
479
+ value="text2image",
480
+ label="Task Type"
481
+ )
482
+
483
+ captioning_message = gr.Textbox(
484
+ lines=2,
485
+ value="Describe the contents of the photo in 50 words.",
486
+ label="Custom message for captioner"
487
+ )
488
+
489
+ auto_caption_btn = gr.Button("Generate Captions")
490
+
491
+ with gr.Column():
492
+ prompt = gr.Textbox(
493
+ lines=3,
494
+ placeholder="Enter your prompt here or use auto-caption...",
495
+ label="Prompt"
496
+ )
497
+ negative_prompt = gr.Textbox(
498
+ lines=3,
499
+ value=NEGATIVE_PROMPT,
500
+ placeholder="Enter negative prompt here...",
501
+ label="Negative Prompt"
502
+ )
503
+ caption_status = gr.Textbox(label="Caption Status")
504
+
505
+ num_steps = gr.Slider(
506
+ minimum=1,
507
+ maximum=200,
508
+ value=50,
509
+ step=1,
510
+ label="Number of Inference Steps"
511
+ )
512
+ guidance_scale = gr.Slider(
513
+ minimum=0.1,
514
+ maximum=10.0,
515
+ value=4,
516
+ step=0.1,
517
+ label="Guidance Scale"
518
+ )
519
+ height = gr.Number(value=1024, label="Height")
520
+ width = gr.Number(value=1024, label="Width")
521
+
522
+ with gr.Accordion("Advanced Configuration", open=False):
523
+ with gr.Row():
524
+ denoise_mask_checkbox = gr.CheckboxGroup(
525
+ label="Denoise Mask",
526
+ choices=["image_0"],
527
+ value=["image_0"]
528
+ )
529
+ azimuth = gr.Textbox(
530
+ value="0",
531
+ label="Azimuths (degrees, comma-separated, 'None' for missing)"
532
+ )
533
+ elevation = gr.Textbox(
534
+ value="0",
535
+ label="Elevations (degrees, comma-separated, 'None' for missing)"
536
+ )
537
+ distance = gr.Textbox(
538
+ value="1.5",
539
+ label="Distances (comma-separated, 'None' for missing)"
540
+ )
541
+ focal_length = gr.Number(
542
+ value=1.3887,
543
+ label="Focal Length of camera for multiview generation"
544
+ )
545
+ scale_factor = gr.Number(value=1.0, label="Scale Factor")
546
+ scale_watershed = gr.Number(value=1.0, label="Scale Watershed")
547
+ noise_scale = gr.Number(value=1.0, label="Noise Scale") # Added noise_scale input
548
+
549
+ output_images = gr.Gallery(
550
+ label="Generated Images",
551
+ show_label=True,
552
+ columns=4,
553
+ rows=2,
554
+ height="auto",
555
+ object_fit="contain"
556
+ )
557
+
558
+ with gr.Column():
559
+ generate_btn = gr.Button("Generate Image")
560
+ # apply_mask_btn = gr.Button("Apply Mask")
561
+
562
+ status = gr.Textbox(label="Generation Status")
563
+
564
+ # Event Handlers
565
+ def update_gallery(files, images_state):
566
+ if not files:
567
+ return images_state, images_state
568
+
569
+ new_images = []
570
+ for file in files:
571
+ try:
572
+ # Handle both file paths and file objects
573
+ if isinstance(file, dict): # For drag and drop files
574
+ file = file['path']
575
+ elif hasattr(file, 'name'): # For uploaded files
576
+ file = file.name
577
+
578
+ img = Image.open(file).convert('RGB')
579
+ new_images.append(img)
580
+ except Exception as e:
581
+ print(f"Error loading image: {str(e)}")
582
+ continue
583
+
584
+ images_state.extend(new_images)
585
+ return images_state, images_state
586
+
587
+ def on_image_select(evt: gr.SelectData, selected_indices_state):
588
+ selected_indices = selected_indices_state or []
589
+ index = evt.index
590
+ if index in selected_indices:
591
+ selected_indices.remove(index)
592
+ else:
593
+ selected_indices.append(index)
594
+ return selected_indices
595
+
596
+ def delete_images(selected_indices, images_state):
597
+ updated_images = [img for i, img in enumerate(images_state) if i not in selected_indices]
598
+ selected_indices_state = []
599
+ return updated_images, updated_images, selected_indices_state
600
+
601
+ def delete_all_images(images_state):
602
+ updated_images = []
603
+ selected_indices_state = []
604
+ return updated_images, updated_images, selected_indices_state
605
+
606
+ def update_height_width(images_state):
607
+ if images_state:
608
+ closest_ar = get_closest_ratio(
609
+ height=images_state[0].size[1],
610
+ width=images_state[0].size[0],
611
+ ratios=ASPECT_RATIO_512
612
+ )
613
+ height_val, width_val = int(closest_ar[0][0]), int(closest_ar[0][1])
614
+ else:
615
+ height_val, width_val = 1024, 1024 # Default values
616
+ return gr.update(value=height_val), gr.update(value=width_val)
617
+
618
+ # Connect events
619
+ file_output.change(
620
+ fn=update_gallery,
621
+ inputs=[file_output, images_state],
622
+ outputs=[images_state, gallery]
623
+ ).then(
624
+ fn=update_height_width,
625
+ inputs=[images_state],
626
+ outputs=[height, width]
627
+ ).then(
628
+ fn=update_denoise_checkboxes,
629
+ inputs=[images_state, task_type, azimuth, elevation, distance],
630
+ outputs=[denoise_mask_checkbox]
631
+ )
632
+
633
+ gallery.select(
634
+ fn=on_image_select,
635
+ inputs=[selected_indices_state],
636
+ outputs=[selected_indices_state]
637
+ )
638
+
639
+ delete_button.click(
640
+ fn=delete_images,
641
+ inputs=[selected_indices_state, images_state],
642
+ outputs=[images_state, gallery, selected_indices_state]
643
+ ).then(
644
+ fn=update_denoise_checkboxes,
645
+ inputs=[images_state, task_type, azimuth, elevation, distance],
646
+ outputs=[denoise_mask_checkbox]
647
+ )
648
+
649
+ delete_all_button.click(
650
+ fn=delete_all_images,
651
+ inputs=[images_state],
652
+ outputs=[images_state, gallery, selected_indices_state]
653
+ ).then(
654
+ fn=update_denoise_checkboxes,
655
+ inputs=[images_state, task_type, azimuth, elevation, distance],
656
+ outputs=[denoise_mask_checkbox]
657
+ )
658
+
659
+ task_type.change(
660
+ fn=update_denoise_checkboxes,
661
+ inputs=[images_state, task_type, azimuth, elevation, distance],
662
+ outputs=[denoise_mask_checkbox]
663
+ )
664
+
665
+ azimuth.change(
666
+ fn=update_denoise_checkboxes,
667
+ inputs=[images_state, task_type, azimuth, elevation, distance],
668
+ outputs=[denoise_mask_checkbox]
669
+ )
670
+
671
+ elevation.change(
672
+ fn=update_denoise_checkboxes,
673
+ inputs=[images_state, task_type, azimuth, elevation, distance],
674
+ outputs=[denoise_mask_checkbox]
675
+ )
676
+
677
+ distance.change(
678
+ fn=update_denoise_checkboxes,
679
+ inputs=[images_state, task_type, azimuth, elevation, distance],
680
+ outputs=[denoise_mask_checkbox]
681
+ )
682
+
683
+ generate_btn.click(
684
+ fn=generate_image,
685
+ inputs=[
686
+ images_state, prompt, negative_prompt, num_steps, guidance_scale,
687
+ denoise_mask_checkbox, task_type, azimuth, elevation, distance,
688
+ focal_length, height, width, scale_factor, scale_watershed, noise_scale # Added noise_scale here
689
+ ],
690
+ outputs=[output_images, status],
691
+ concurrency_id="gpu_queue"
692
+ )
693
+
694
+ auto_caption_btn.click(
695
+ fn=update_prompt,
696
+ inputs=[images_state, task_type, captioning_message],
697
+ outputs=[prompt, caption_status],
698
+ concurrency_id="gpu_queue"
699
+ )
700
+
701
+ # apply_mask_btn.click(
702
+ # fn=apply_mask,
703
+ # inputs=[images_state],
704
+ # outputs=[output_images, status]
705
+ # )
706
+
707
+ if __name__ == "__main__":
708
+ parser = argparse.ArgumentParser(description='Start the Gradio demo with specified captioner.')
709
+ parser.add_argument('--captioner', type=str, choices=['molmo', 'llava', 'disable'], default='molmo', help='Captioner to use: molmo, llava, disable.')
710
+ args = parser.parse_args()
711
+
712
+ # Initialize models with the specified captioner
713
+ pipeline, captioner = initialize_models(args.captioner)
714
+
715
+ demo.launch(share=True)
inference.py ADDED
@@ -0,0 +1,37 @@
1
+ import torch
2
+ from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
3
+ from PIL import Image
4
+
5
+ device = torch.device('cuda:0')
6
+ pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)
7
+
8
+ NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
9
+
10
+ ## Text-to-image
11
+ output = pipeline(
12
+ prompt="[[text2image]] A bipedal black cat wearing a huge oversized witch hat, a wizards robe, casting a spell,in an enchanted forest. The scene is filled with fireflies and moss on surrounding rocks and trees",
13
+ negative_prompt=NEGATIVE_PROMPT,
14
+ num_inference_steps=50,
15
+ guidance_scale=4,
16
+ height=1024,
17
+ width=1024,
18
+ )
19
+ output.images[0].save('text2image_output.jpg')
20
+
21
+ ## ID Customization
22
+ image = [
23
+ Image.open("assets/examples/id_customization/chenhao/image_0.png"),
24
+ Image.open("assets/examples/id_customization/chenhao/image_1.png"),
25
+ Image.open("assets/examples/id_customization/chenhao/image_2.png")
26
+ ]
27
+
28
+ # input = [noise, cond_1, cond_2, cond_3]
29
+ prompt = "[[faceid]] \
30
+ [[img0]] A woman dressed in traditional attire with intricate headpieces, posing gracefully with a serene expression. \
31
+ [[img1]] A woman with long dark hair, smiling warmly while wearing a floral dress. \
32
+ [[img2]] A woman in traditional clothing holding a lace parasol, with her hair styled elegantly. \
33
+ [[img3]] A woman in elaborate traditional attire and jewelry, with an ornate headdress, looking intently forward. \
34
+ "
35
+
36
+ ret = pipeline.img2img(image=image, num_inference_steps=75, prompt=prompt, denoise_mask=[1, 0, 0, 0], guidance_scale=4)
37
+ ret.images[0].save("idcustomization_output.jpg")
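The Gradio demo drives multiview generation through the same `img2img` entry point; below is a hedged sketch of single-image novel-view synthesis that mirrors the keyword arguments assembled in `app.py`'s `generate_image` (the camera values and the 512x512 size follow the demo notes, the input path is a placeholder, and the denoise mask interleaves one image/camera-pose pair per view):

```python
# Novel-view synthesis sketch: condition on one input view, denoise three target views.
cond = Image.open("assets/examples/your_input_view.png")  # placeholder path

ret = pipeline.img2img(
    image=[cond],
    prompt="[[multiview]] A photo of a toy robot on a wooden table.",
    negative_prompt=NEGATIVE_PROMPT,
    num_inference_steps=75,
    guidance_scale=4,
    height=512, width=512,
    multiview_azimuths=[0, 30, 60, 90],        # the input view is always azimuth/elevation 0
    multiview_elevations=[0, 0, 0, 0],
    multiview_distances=[1.5, 1.5, 1.5, 1.5],
    multiview_focal_length=1.3887,
    is_multiview=True,
    denoise_mask=[0, 0, 1, 0, 1, 0, 1, 0],     # keep image_0 and all camera poses fixed, denoise images 1-3
)
ret.images[0].save("multiview_output.jpg")
```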
onediffusion/dataset/multitask/multiview.py ADDED
@@ -0,0 +1,277 @@
1
+ import os
2
+ import json
3
+ import random
4
+ from PIL import Image
5
+ import torch
6
+ from typing import List, Tuple, Union
7
+ from torch.utils.data import Dataset
8
+ from torchvision import transforms
9
+ import torchvision.transforms as T
10
+ from onediffusion.dataset.utils import *
11
+ import glob
12
+
13
+ from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras
14
+ from onediffusion.dataset.transforms import CenterCropResizeImage
15
+ from pytorch3d.renderer import PerspectiveCameras
16
+
17
+ import numpy as np
18
+
19
+ def _cameras_from_opencv_projection(
20
+ R: torch.Tensor,
21
+ tvec: torch.Tensor,
22
+ camera_matrix: torch.Tensor,
23
+ image_size: torch.Tensor,
24
+ do_normalize_cameras,
25
+ normalize_scale,
26
+ ) -> PerspectiveCameras:
27
+ focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
28
+ principal_point = camera_matrix[:, :2, 2]
29
+
30
+ # Retype the image_size correctly and flip to width, height.
31
+ image_size_wh = image_size.to(R).flip(dims=(1,))
32
+
33
+ # Screen to NDC conversion:
34
+ # For non square images, we scale the points such that smallest side
35
+ # has range [-1, 1] and the largest side has range [-u, u], with u > 1.
36
+ # This convention is consistent with the PyTorch3D renderer, as well as
37
+ # the transformation function `get_ndc_to_screen_transform`.
38
+ scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0
39
+ scale = scale.expand(-1, 2)
40
+ c0 = image_size_wh / 2.0
41
+
42
+ # Get the PyTorch3D focal length and principal point.
43
+ focal_pytorch3d = focal_length / scale
44
+ p0_pytorch3d = -(principal_point - c0) / scale
45
+
46
+ # For R, T we flip x, y axes (opencv screen space has an opposite
47
+ # orientation of screen axes).
48
+ # We also transpose R (opencv multiplies points from the opposite=left side).
49
+ R_pytorch3d = R.clone().permute(0, 2, 1)
50
+ T_pytorch3d = tvec.clone()
51
+ R_pytorch3d[:, :, :2] *= -1
52
+ T_pytorch3d[:, :2] *= -1
53
+
54
+ cams = PerspectiveCameras(
55
+ R=R_pytorch3d,
56
+ T=T_pytorch3d,
57
+ focal_length=focal_pytorch3d,
58
+ principal_point=p0_pytorch3d,
59
+ image_size=image_size,
60
+ device=R.device,
61
+ )
62
+
63
+ if do_normalize_cameras:
64
+ cams, _ = normalize_cameras(cams, scale=normalize_scale)
65
+
66
+ cams = first_camera_transform(cams, rotation_only=False)
67
+ return cams
68
+
69
+ def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0):
70
+ cameras = _cameras_from_opencv_projection(
71
+ R=Rs,
72
+ tvec=Ts,
73
+ camera_matrix=Ks,
74
+ image_size=sizes,
75
+ do_normalize_cameras=do_normalize_cameras,
76
+ normalize_scale=normalize_scale
77
+ )
78
+
79
+ rays_embedding = cameras_to_rays(
80
+ cameras=cameras,
81
+ num_patches_x=target_size,
82
+ num_patches_y=target_size,
83
+ crop_parameters=None,
84
+ use_plucker=use_plucker
85
+ )
86
+
87
+ return rays_embedding.rays
88
+
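For reference, the rays returned here are (with `use_plucker=True`, the default) in Plücker form: a ray with origin $o$ and unit direction $d$ is encoded as the 6-vector

$$ r = (d,\; m), \qquad m = o \times d , $$

which is what `Rays.to_plucker` in `raydiff_utils.py` computes. `MultiviewDataset.__getitem__` below then tiles these 6 channels up to the VAE's 16 channels and rescales them.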
89
+ def convert_rgba_to_rgb_white_bg(image):
90
+ """Convert RGBA image to RGB with white background"""
91
+ if image.mode == 'RGBA':
92
+ # Create a white background
93
+ background = Image.new('RGBA', image.size, (255, 255, 255, 255))
94
+ # Composite the image onto the white background
95
+ return Image.alpha_composite(background, image).convert('RGB')
96
+ return image.convert('RGB')
97
+
98
+ class MultiviewDataset(Dataset):
99
+ def __init__(
100
+ self,
101
+ scene_folders: str,
102
+ samples_per_set: Union[int, Tuple[int, int]],  # single count or (min, max) range of views to sample per scene
103
+ transform=None,
104
+ caption_keys: Union[str, List] = "caption",
105
+ multiscale=False,
106
+ aspect_ratio_type=ASPECT_RATIO_512,
107
+ c2w_scaling=1.7,
108
+ default_max_distance=1,  # default max distance over all cameras of a scene
109
+ do_normalize=True, # whether normalize translation of c2w with max_distance
110
+ swap_xz=False,  # whether to swap the x and z axes of 3D scenes
111
+ valid_paths: str = "",
112
+ frame_sliding_windows: int = None  # restrict all sampled frames to a window of this size, so camera poses don't differ too much
113
+ ):
114
+ if not isinstance(samples_per_set, tuple) and not isinstance(samples_per_set, list):
115
+ samples_per_set = (samples_per_set, samples_per_set)
116
+ self.samples_range = samples_per_set # Tuple of (min_samples, max_samples)
117
+ self.transform = transform
118
+ self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys]
119
+ self.aspect_ratio = aspect_ratio_type
120
+ self.scene_folders = sorted(glob.glob(scene_folders))
121
+ # filter out scene folders that do not have transforms.json
122
+ self.scene_folders = list(filter(lambda x: os.path.exists(os.path.join(x, "transforms.json")), self.scene_folders))
123
+
124
+ # if valid_paths.txt exists, only use paths in that file
125
+ if os.path.exists(valid_paths):
126
+ with open(valid_paths, 'r') as f:
127
+ valid_scene_folders = f.read().splitlines()
128
+ self.scene_folders = sorted(valid_scene_folders)
129
+
130
+ self.c2w_scaling = c2w_scaling
131
+ self.do_normalize = do_normalize
132
+ self.default_max_distance = default_max_distance
133
+ self.swap_xz = swap_xz
134
+ self.frame_sliding_windows = frame_sliding_windows
135
+
136
+ if multiscale:
137
+ assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880]
138
+ if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
139
+ self.interpolate_model = T.InterpolationMode.LANCZOS
140
+ self.ratio_index = {}
141
+ self.ratio_nums = {}
142
+ for k, v in self.aspect_ratio.items():
143
+ self.ratio_index[float(k)] = [] # used for self.getitem
144
+ self.ratio_nums[float(k)] = 0 # used for batch-sampler
145
+
146
+ def __len__(self):
147
+ return len(self.scene_folders)
148
+
149
+ def __getitem__(self, idx):
150
+ try:
151
+ scene_path = self.scene_folders[idx]
152
+
153
+ if os.path.exists(os.path.join(scene_path, "images")):
154
+ image_folder = os.path.join(scene_path, "images")
155
+ downscale_factor = 1
156
+ elif os.path.exists(os.path.join(scene_path, "images_4")):
157
+ image_folder = os.path.join(scene_path, "images_4")
158
+ downscale_factor = 1 / 4
159
+ elif os.path.exists(os.path.join(scene_path, "images_8")):
160
+ image_folder = os.path.join(scene_path, "images_8")
161
+ downscale_factor = 1 / 8
162
+ else:
163
+ raise NotImplementedError
164
+
165
+ json_path = os.path.join(scene_path, "transforms.json")
166
+ caption_path = os.path.join(scene_path, "caption.json")
167
+ image_files = os.listdir(image_folder)
168
+
169
+ with open(json_path, 'r') as f:
170
+ json_data = json.load(f)
171
+ height, width = json_data['h'], json_data['w']
172
+
173
+ dh, dw = int(height * downscale_factor), int(width * downscale_factor)
174
+ fl_x, fl_y = json_data['fl_x'] * downscale_factor, json_data['fl_y'] * downscale_factor
175
+ cx = dw // 2
176
+ cy = dh // 2
177
+
178
+ frame_list = json_data['frames']
179
+
180
+ # Randomly select number of samples
181
+
182
+ samples_per_set = random.randint(self.samples_range[0], self.samples_range[1])
183
+
184
+ # uniformly for all scenes
185
+ if self.frame_sliding_windows is None:
186
+ selected_indices = random.sample(range(len(frame_list)), min(samples_per_set, len(frame_list)))
187
+ # limit the multiview to be in a sliding window (to avoid catastrophic difference in camera angles)
188
+ else:
189
+ # Determine the starting index of the sliding window
190
+ if len(frame_list) <= self.frame_sliding_windows:
191
+ # If the frame list is smaller than or equal to X, use the entire list
192
+ window_start = 0
193
+ window_end = len(frame_list)
194
+ else:
195
+ # Randomly select a starting point for the window
196
+ window_start = random.randint(0, len(frame_list) - self.frame_sliding_windows)
197
+ window_end = window_start + self.frame_sliding_windows
198
+
199
+ # Get the indices within the sliding window
200
+ window_indices = list(range(window_start, window_end))
201
+
202
+ # Randomly sample indices from the window
203
+ selected_indices = random.sample(window_indices, samples_per_set)
204
+
205
+ image_files = [os.path.basename(frame_list[i]['file_path']) for i in selected_indices]
206
+ image_paths = [os.path.join(image_folder, file) for file in image_files]
207
+
208
+ # Load images and convert RGBA to RGB with white background
209
+ images = [convert_rgba_to_rgb_white_bg(Image.open(image_path)) for image_path in image_paths]
210
+
211
+ if self.transform:
212
+ images = [self.transform(image) for image in images]
213
+ else:
214
+ closest_size, closest_ratio = self.aspect_ratio['1.0'], 1.0
215
+ closest_size = tuple(map(int, closest_size))
216
+ transform = T.Compose([
217
+ T.ToTensor(),
218
+ CenterCropResizeImage(closest_size),
219
+ T.Normalize([.5], [.5]),
220
+ ])
221
+ images = [transform(image) for image in images]
222
+ images = torch.stack(images)
223
+
224
+ c2ws = [frame_list[i]['transform_matrix'] for i in selected_indices]
225
+ c2ws = torch.tensor(c2ws).reshape(-1, 4, 4)
226
+ # max_distance = json_data.get('max_distance', self.default_max_distance)
227
+ # if 'max_distance' not in json_data.keys():
228
+ # print(f"not found `max_distance` in json path: {json_path}")
229
+
230
+ if self.swap_xz:
231
+ swap_xz = torch.tensor([[[0, 0, 1., 0],
232
+ [0, 1., 0, 0],
233
+ [-1., 0, 0, 0],
234
+ [0, 0, 0, 1.]]])
235
+ c2ws = swap_xz @ c2ws
236
+
237
+ # OPENGL to OPENCV
238
+ c2ws[:, 0:3, 1:3] *= -1
239
+ c2ws = c2ws[:, [1, 0, 2, 3], :]
240
+ c2ws[:, 2, :] *= -1
241
+
242
+ w2cs = torch.inverse(c2ws)
243
+ K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(len(c2ws), 1, 1)
244
+ Rs = w2cs[:, :3, :3]
245
+ Ts = w2cs[:, :3, 3]
246
+ sizes = torch.tensor([[dh, dw]]).repeat(len(c2ws), 1)
247
+
248
+ # get ray embedding and padding last dimension to 16 (num channels of VAE)
249
+ # rays_od = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, use_plucker=False, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
250
+ rays = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
251
+ rays = rays.reshape(samples_per_set, closest_size[0] // 8, closest_size[1] // 8, 6)
252
+ # padding = (0, 10) # pad the last dimension to 16
253
+ # rays = torch.nn.functional.pad(rays, padding, "constant", 0)
254
+ rays = torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658
255
+
256
+ if os.path.exists(caption_path):
257
+ with open(caption_path, 'r') as f:
258
+ caption_key = random.choice(self.caption_keys)
259
+ caption = json.load(f).get(caption_key, "")
260
+ else:
261
+ caption = ""
262
+
263
+ caption = "[[multiview]] " + caption if caption else "[[multiview]]"
264
+
265
+ return {
266
+ 'pixel_values': images,
267
+ 'rays': rays,
268
+ 'aspect_ratio': closest_ratio,
269
+ 'caption': caption,
270
+ 'height': dh,
271
+ 'width': dw,
272
+ # 'origins': rays_od[..., :3],
273
+ # 'dirs': rays_od[..., 3:6]
274
+ }
275
+ except Exception as e:
276
+ return self.__getitem__(random.randint(0, len(self.scene_folders) - 1))
277
+
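A minimal sketch of wiring the dataset into a loader (the scene glob is a placeholder; since each item stacks a variable number of views, `batch_size=1` or a custom collate is the simplest choice):

```python
from torch.utils.data import DataLoader

dataset = MultiviewDataset(
    scene_folders="data/scenes/*",   # placeholder glob; each scene needs images/ (or images_4/, images_8/) and transforms.json
    samples_per_set=(2, 3),          # sample 2-3 views per scene
    caption_keys="caption",
)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

batch = next(iter(loader))
# batch['pixel_values']: (1, num_views, 3, H, W); batch['rays']: (1, num_views, H/8, W/8, 16)
```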
onediffusion/dataset/raydiff_utils.py ADDED
@@ -0,0 +1,739 @@
1
+
2
+ """
3
+ Adapted from code originally written by David Novotny.
4
+ """
5
+
6
+ import torch
7
+ from pytorch3d.transforms import Rotate, Translate
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ from pytorch3d.renderer import PerspectiveCameras, RayBundle
13
+
14
+ def intersect_skew_line_groups(p, r, mask):
15
+ # p, r both of shape (B, N, n_intersected_lines, 3)
16
+ # mask of shape (B, N, n_intersected_lines)
17
+ p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
18
+ if p_intersect is None:
19
+ return None, None, None, None
20
+ _, p_line_intersect = point_line_distance(
21
+ p, r, p_intersect[..., None, :].expand_as(p)
22
+ )
23
+ intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
24
+ dim=-1
25
+ )
26
+ return p_intersect, p_line_intersect, intersect_dist_squared, r
27
+
28
+
29
+ def intersect_skew_lines_high_dim(p, r, mask=None):
30
+ # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
31
+ dim = p.shape[-1]
32
+ # make sure the heading vectors are l2-normed
33
+ if mask is None:
34
+ mask = torch.ones_like(p[..., 0])
35
+ r = torch.nn.functional.normalize(r, dim=-1)
36
+
37
+ eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
38
+ I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
39
+ sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
40
+
41
+ # I_eps = torch.zeros_like(I_min_cov.sum(dim=-3)) + 1e-10
42
+ # p_intersect = torch.pinverse(I_min_cov.sum(dim=-3) + I_eps).matmul(sum_proj)[..., 0]
43
+ p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
44
+
45
+ # I_min_cov.sum(dim=-3): torch.Size([1, 1, 3, 3])
46
+ # sum_proj: torch.Size([1, 1, 3, 1])
47
+
48
+ # p_intersect = np.linalg.lstsq(I_min_cov.sum(dim=-3).numpy(), sum_proj.numpy(), rcond=None)[0]
49
+
50
+ if torch.any(torch.isnan(p_intersect)):
51
+ print(p_intersect)
52
+ return None, None
55
+ return p_intersect, r
56
+
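In other words, `intersect_skew_lines_high_dim` finds the least-squares point closest to all (masked) lines: with points $p_i$ and unit directions $d_i$, it minimizes $\sum_i \lVert (I - d_i d_i^\top)(p - p_i) \rVert^2$, whose normal equations

$$ \Big(\sum_i (I - d_i d_i^\top)\Big)\, p^{*} \;=\; \sum_i (I - d_i d_i^\top)\, p_i $$

are solved above with `torch.linalg.lstsq`.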
57
+
58
+ def point_line_distance(p1, r1, p2):
59
+ df = p2 - p1
60
+ proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
61
+ line_pt_nearest = p2 - proj_vector
62
+ d = (proj_vector).norm(dim=-1)
63
+ return d, line_pt_nearest
64
+
65
+
66
+ def compute_optical_axis_intersection(cameras):
67
+ centers = cameras.get_camera_center()
68
+ principal_points = cameras.principal_point
69
+
70
+ one_vec = torch.ones((len(cameras), 1), device=centers.device)
71
+ optical_axis = torch.cat((principal_points, one_vec), -1)
72
+
73
+ # optical_axis = torch.cat(
74
+ # (principal_points, cameras.focal_length[:, 0].unsqueeze(1)), -1
75
+ # )
76
+
77
+ pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
78
+ pp2 = torch.diagonal(pp, dim1=0, dim2=1).T
79
+
80
+ directions = pp2 - centers
81
+ centers = centers.unsqueeze(0).unsqueeze(0)
82
+ directions = directions.unsqueeze(0).unsqueeze(0)
83
+
84
+ p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
85
+ p=centers, r=directions, mask=None
86
+ )
87
+
88
+ if p_intersect is None:
89
+ dist = None
90
+ else:
91
+ p_intersect = p_intersect.squeeze().unsqueeze(0)
92
+ dist = (p_intersect - centers).norm(dim=-1)
93
+
94
+ return p_intersect, dist, p_line_intersect, pp2, r
95
+
96
+
97
+ def normalize_cameras(cameras, scale=1.0):
98
+ """
99
+ Normalizes cameras such that the optical axes point to the origin, the rotation is
100
+ identity, and the norm of the translation of the first camera is 1.
101
+
102
+ Args:
103
+ cameras (pytorch3d.renderer.cameras.CamerasBase).
104
+ scale (float): Norm of the translation of the first camera.
105
+
106
+ Returns:
107
+ new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras.
108
+ undo_transform (function): Function that undoes the normalization.
109
+ """
110
+
111
+ # Let distance from first camera to origin be unit
112
+ new_cameras = cameras.clone()
113
+ new_transform = (
114
+ new_cameras.get_world_to_view_transform()
115
+ ) # potential R is not valid matrix
116
+ p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
117
+ cameras
118
+ )
119
+
120
+ if p_intersect is None:
121
+ print("Warning: optical axes code has a nan. Returning identity cameras.")
122
+ new_cameras.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype)
123
+ new_cameras.T[:] = torch.tensor(
124
+ [0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype
125
+ )
126
+ return new_cameras, lambda x: x
127
+
128
+ d = dist.squeeze(dim=1).squeeze(dim=0)[0]
129
+ # Degenerate case
130
+ if d == 0:
131
+ print(cameras.T)
132
+ print(new_transform.get_matrix()[:, 3, :3])
133
+ assert False
134
+ assert d != 0
135
+
136
+ # Can't figure out how to make scale part of the transform too without messing up R.
137
+ # Ideally, we would just wrap it all in a single Pytorch3D transform so that it
138
+ # would work with any structure (eg PointClouds, Meshes).
139
+ tR = Rotate(new_cameras.R[0].unsqueeze(0)).inverse()
140
+ tT = Translate(p_intersect)
141
+ t = tR.compose(tT)
142
+
143
+ new_transform = t.compose(new_transform)
144
+ new_cameras.R = new_transform.get_matrix()[:, :3, :3]
145
+ new_cameras.T = new_transform.get_matrix()[:, 3, :3] / d * scale
146
+
147
+ def undo_transform(cameras):
148
+ cameras_copy = cameras.clone()
149
+ cameras_copy.T *= d / scale
150
+ new_t = (
151
+ t.inverse().compose(cameras_copy.get_world_to_view_transform()).get_matrix()
152
+ )
153
+ cameras_copy.R = new_t[:, :3, :3]
154
+ cameras_copy.T = new_t[:, 3, :3]
155
+ return cameras_copy
156
+
157
+ return new_cameras, undo_transform
158
+
159
+ def first_camera_transform(cameras, rotation_only=True):
160
+ new_cameras = cameras.clone()
161
+ new_transform = new_cameras.get_world_to_view_transform()
162
+ tR = Rotate(new_cameras.R[0].unsqueeze(0))
163
+ if rotation_only:
164
+ t = tR.inverse()
165
+ else:
166
+ tT = Translate(new_cameras.T[0].unsqueeze(0))
167
+ t = tR.compose(tT).inverse()
168
+
169
+ new_transform = t.compose(new_transform)
170
+ new_cameras.R = new_transform.get_matrix()[:, :3, :3]
171
+ new_cameras.T = new_transform.get_matrix()[:, 3, :3]
172
+
173
+ return new_cameras
174
+
175
+
176
+ def get_identity_cameras_with_intrinsics(cameras):
177
+ D = len(cameras)
178
+ device = cameras.R.device
179
+
180
+ new_cameras = cameras.clone()
181
+ new_cameras.R = torch.eye(3, device=device).unsqueeze(0).repeat((D, 1, 1))
182
+ new_cameras.T = torch.zeros((D, 3), device=device)
183
+
184
+ return new_cameras
185
+
186
+
187
+ def normalize_cameras_batch(cameras, scale=1.0, normalize_first_camera=False):
188
+ new_cameras = []
189
+ undo_transforms = []
190
+ for cam in cameras:
191
+ if normalize_first_camera:
192
+ # Normalize cameras such that first camera is identity and origin is at
193
+ # first camera center.
194
+ normalized_cameras = first_camera_transform(cam, rotation_only=False)
195
+ undo_transform = None
196
+ else:
197
+ normalized_cameras, undo_transform = normalize_cameras(cam, scale=scale)
198
+ new_cameras.append(normalized_cameras)
199
+ undo_transforms.append(undo_transform)
200
+ return new_cameras, undo_transforms
201
+
202
+
203
+ class Rays(object):
204
+ def __init__(
205
+ self,
206
+ rays=None,
207
+ origins=None,
208
+ directions=None,
209
+ moments=None,
210
+ is_plucker=False,
211
+ moments_rescale=1.0,
212
+ ndc_coordinates=None,
213
+ crop_parameters=None,
214
+ num_patches_x=16,
215
+ num_patches_y=16,
216
+ ):
217
+ """
218
+ Ray class to keep track of current ray representation.
219
+
220
+ Args:
221
+ rays: (..., 6).
222
+ origins: (..., 3).
223
+ directions: (..., 3).
224
+ moments: (..., 3).
225
+ is_plucker: If True, rays are in plucker coordinates (Default: False).
226
+ moments_rescale: Rescale the moment component of the rays by a scalar.
227
+ ndc_coordinates: (..., 2): NDC coordinates of each ray.
228
+ """
229
+ if rays is not None:
230
+ self.rays = rays
231
+ self._is_plucker = is_plucker
232
+ elif origins is not None and directions is not None:
233
+ self.rays = torch.cat((origins, directions), dim=-1)
234
+ self._is_plucker = False
235
+ elif directions is not None and moments is not None:
236
+ self.rays = torch.cat((directions, moments), dim=-1)
237
+ self._is_plucker = True
238
+ else:
239
+ raise Exception("Invalid combination of arguments")
240
+
241
+ if moments_rescale != 1.0:
242
+ self.rescale_moments(moments_rescale)
243
+
244
+ if ndc_coordinates is not None:
245
+ self.ndc_coordinates = ndc_coordinates
246
+ elif crop_parameters is not None:
247
+ # (..., H, W, 2)
248
+ xy_grid = compute_ndc_coordinates(
249
+ crop_parameters,
250
+ num_patches_x=num_patches_x,
251
+ num_patches_y=num_patches_y,
252
+ )[..., :2]
253
+ xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2)
254
+ self.ndc_coordinates = xy_grid
255
+ else:
256
+ self.ndc_coordinates = None
257
+
258
+ def __getitem__(self, index):
259
+ return Rays(
260
+ rays=self.rays[index],
261
+ is_plucker=self._is_plucker,
262
+ ndc_coordinates=(
263
+ self.ndc_coordinates[index]
264
+ if self.ndc_coordinates is not None
265
+ else None
266
+ ),
267
+ )
268
+
269
+ def to_spatial(self, include_ndc_coordinates=False):
270
+ """
271
+ Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)
272
+
273
+ Returns:
274
+ torch.Tensor: (..., 6, H, W)
275
+ """
276
+ rays = self.to_plucker().rays
277
+ *batch_dims, P, D = rays.shape
278
+ H = W = int(np.sqrt(P))
279
+ assert H * W == P
280
+ rays = torch.transpose(rays, -1, -2) # (..., 6, H * W)
281
+ rays = rays.reshape(*batch_dims, D, H, W)
282
+ if include_ndc_coordinates:
283
+ ndc_coords = self.ndc_coordinates.transpose(-1, -2) # (..., 2, H * W)
284
+ ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W)
285
+ rays = torch.cat((rays, ndc_coords), dim=-3)
286
+ return rays
287
+
288
+ def rescale_moments(self, scale):
289
+ """
290
+ Rescale the moment component of the rays by a scalar. Might be desirable since
291
+ moments may come from a very narrow distribution.
292
+
293
+ Note that this modifies in place!
294
+ """
295
+ if self.is_plucker:
296
+ self.rays[..., 3:] *= scale
297
+ return self
298
+ else:
299
+ return self.to_plucker().rescale_moments(scale)
300
+
301
+ @classmethod
302
+ def from_spatial(cls, rays, moments_rescale=1.0, ndc_coordinates=None):
303
+ """
304
+ Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)
305
+
306
+ Args:
307
+ rays: (..., 6, H, W)
308
+
309
+ Returns:
310
+ Rays: (..., H * W, 6)
311
+ """
312
+ *batch_dims, D, H, W = rays.shape
313
+ rays = rays.reshape(*batch_dims, D, H * W)
314
+ rays = torch.transpose(rays, -1, -2)
315
+ return cls(
316
+ rays=rays,
317
+ is_plucker=True,
318
+ moments_rescale=moments_rescale,
319
+ ndc_coordinates=ndc_coordinates,
320
+ )
321
+
322
+ def to_point_direction(self, normalize_moment=True):
323
+ """
324
+ Convert to point direction representation <O, D>.
325
+
326
+ Returns:
327
+ rays: (..., 6).
328
+ """
329
+ if self._is_plucker:
330
+ direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1)
331
+ moment = self.rays[..., 3:]
332
+ if normalize_moment:
333
+ c = torch.linalg.norm(direction, dim=-1, keepdim=True)
334
+ moment = moment / c
335
+ points = torch.cross(direction, moment, dim=-1)
336
+ return Rays(
337
+ rays=torch.cat((points, direction), dim=-1),
338
+ is_plucker=False,
339
+ ndc_coordinates=self.ndc_coordinates,
340
+ )
341
+ else:
342
+ return self
343
+
344
+ def to_plucker(self):
345
+ """
346
+ Convert to plucker representation <D, OxD>.
347
+ """
348
+ if self.is_plucker:
349
+ return self
350
+ else:
351
+ ray = self.rays.clone()
352
+ ray_origins = ray[..., :3]
353
+ ray_directions = ray[..., 3:]
354
+ # Normalize ray directions to unit vectors
355
+ ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
356
+ plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
357
+ new_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
358
+ return Rays(
359
+ rays=new_ray, is_plucker=True, ndc_coordinates=self.ndc_coordinates
360
+ )
361
+
362
+ def get_directions(self, normalize=True):
363
+ if self.is_plucker:
364
+ directions = self.rays[..., :3]
365
+ else:
366
+ directions = self.rays[..., 3:]
367
+ if normalize:
368
+ directions = torch.nn.functional.normalize(directions, dim=-1)
369
+ return directions
370
+
371
+ def get_origins(self):
372
+ if self.is_plucker:
373
+ origins = self.to_point_direction().get_origins()
374
+ else:
375
+ origins = self.rays[..., :3]
376
+ return origins
377
+
378
+ def get_moments(self):
379
+ if self.is_plucker:
380
+ moments = self.rays[..., 3:]
381
+ else:
382
+ moments = self.to_plucker().get_moments()
383
+ return moments
384
+
385
+ def get_ndc_coordinates(self):
386
+ return self.ndc_coordinates
387
+
388
+ @property
389
+ def is_plucker(self):
390
+ return self._is_plucker
391
+
392
+ @property
393
+ def device(self):
394
+ return self.rays.device
395
+
396
+ def __repr__(self, *args, **kwargs):
397
+ ray_str = self.rays.__repr__(*args, **kwargs)[6:] # remove "tensor"
398
+ if self._is_plucker:
399
+ return "PluRay" + ray_str
400
+ else:
401
+ return "DirRay" + ray_str
402
+
403
+ def to(self, device):
404
+ self.rays = self.rays.to(device)
405
+
406
+ def clone(self):
407
+ return Rays(rays=self.rays.clone(), is_plucker=self._is_plucker)
408
+
409
+ @property
410
+ def shape(self):
411
+ return self.rays.shape
412
+
413
+ def visualize(self):
414
+ directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu()
415
+ moments = torch.nn.functional.normalize(self.get_moments(), dim=-1).cpu()
416
+ return (directions + 1) / 2, (moments + 1) / 2
417
+
418
+ def to_ray_bundle(self, length=0.3, recenter=True):
419
+ lengths = torch.ones_like(self.get_origins()[..., :2]) * length
420
+ lengths[..., 0] = 0
421
+ if recenter:
422
+ centers, _ = intersect_skew_lines_high_dim(
423
+ self.get_origins(), self.get_directions()
424
+ )
425
+ centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1)
426
+ else:
427
+ centers = self.get_origins()
428
+ return RayBundle(
429
+ origins=centers,
430
+ directions=self.get_directions(),
431
+ lengths=lengths,
432
+ xys=self.get_directions(),
433
+ )
434
+
435
+
436
+ def cameras_to_rays(
437
+ cameras,
438
+ crop_parameters,
439
+ use_half_pix=True,
440
+ use_plucker=True,
441
+ num_patches_x=16,
442
+ num_patches_y=16,
443
+ ):
444
+ """
445
+ Unprojects rays from camera center to grid on image plane.
446
+
447
+ Args:
448
+ cameras: Pytorch3D cameras to unproject. Can be batched.
449
+ crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale).
450
+ Shape is (B, 4).
451
+ use_half_pix: If True, use half pixel offset (Default: True).
452
+ use_plucker: If True, return rays in plucker coordinates (Default: True).
453
+ num_patches_x: Number of patches in x direction (Default: 16).
454
+ num_patches_y: Number of patches in y direction (Default: 16).
455
+ """
456
+ unprojected = []
457
+ crop_parameters_list = (
458
+ crop_parameters if crop_parameters is not None else [None for _ in cameras]
459
+ )
460
+ for camera, crop_param in zip(cameras, crop_parameters_list):
461
+ xyd_grid = compute_ndc_coordinates(
462
+ crop_parameters=crop_param,
463
+ use_half_pix=use_half_pix,
464
+ num_patches_x=num_patches_x,
465
+ num_patches_y=num_patches_y,
466
+ )
467
+
468
+ unprojected.append(
469
+ camera.unproject_points(
470
+ xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
471
+ )
472
+ )
473
+ unprojected = torch.stack(unprojected, dim=0) # (N, P, 3)
474
+ origins = cameras.get_camera_center().unsqueeze(1) # (N, 1, 3)
475
+ origins = origins.repeat(1, num_patches_x * num_patches_y, 1) # (N, P, 3)
476
+ directions = unprojected - origins
477
+
478
+ rays = Rays(
479
+ origins=origins,
480
+ directions=directions,
481
+ crop_parameters=crop_parameters,
482
+ num_patches_x=num_patches_x,
483
+ num_patches_y=num_patches_y,
484
+ )
485
+ if use_plucker:
486
+ return rays.to_plucker()
487
+ return rays
488
+
489
+
490
+ def rays_to_cameras(
491
+ rays,
492
+ crop_parameters,
493
+ num_patches_x=16,
494
+ num_patches_y=16,
495
+ use_half_pix=True,
496
+ sampled_ray_idx=None,
497
+ cameras=None,
498
+ focal_length=(3.453,),
499
+ ):
500
+ """
501
+ If cameras are provided, will use those intrinsics. Otherwise will use the provided
502
+ focal_length(s). Dataset default is 3.32.
503
+
504
+ Args:
505
+ rays (Rays): (N, P, 6)
506
+ crop_parameters (torch.Tensor): (N, 4)
507
+ """
508
+ device = rays.device
509
+ origins = rays.get_origins()
510
+ directions = rays.get_directions()
511
+ camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)
512
+
513
+ # Retrieve target rays
514
+ if cameras is None:
515
+ if len(focal_length) == 1:
516
+ focal_length = focal_length * rays.shape[0]
517
+ I_camera = PerspectiveCameras(focal_length=focal_length, device=device)
518
+ else:
519
+ # Use same intrinsics but reset to identity extrinsics.
520
+ I_camera = cameras.clone()
521
+ I_camera.R[:] = torch.eye(3, device=device)
522
+ I_camera.T[:] = torch.zeros(3, device=device)
523
+ I_patch_rays = cameras_to_rays(
524
+ cameras=I_camera,
525
+ num_patches_x=num_patches_x,
526
+ num_patches_y=num_patches_y,
527
+ use_half_pix=use_half_pix,
528
+ crop_parameters=crop_parameters,
529
+ ).get_directions()
530
+
531
+ if sampled_ray_idx is not None:
532
+ I_patch_rays = I_patch_rays[:, sampled_ray_idx]
533
+
534
+ # Compute optimal rotation to align rays
535
+ R = torch.zeros_like(I_camera.R)
536
+ for i in range(len(I_camera)):
537
+ R[i] = compute_optimal_rotation_alignment(
538
+ I_patch_rays[i],
539
+ directions[i],
540
+ )
541
+
542
+ # Construct and return rotated camera
543
+ cam = I_camera.clone()
544
+ cam.R = R
545
+ cam.T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
546
+ return cam
547
+
548
+
549
+ # https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/
550
+ def ql_decomposition(A):
551
+ P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float()
552
+ A_tilde = torch.matmul(A, P)
553
+ Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
554
+ Q = torch.matmul(Q_tilde, P)
555
+ L = torch.matmul(torch.matmul(P, R_tilde), P)
556
+ d = torch.diag(L)
557
+ Q[:, 0] *= torch.sign(d[0])
558
+ Q[:, 1] *= torch.sign(d[1])
559
+ Q[:, 2] *= torch.sign(d[2])
560
+ L[0] *= torch.sign(d[0])
561
+ L[1] *= torch.sign(d[1])
562
+ L[2] *= torch.sign(d[2])
563
+ return Q, L
564
+
565
+
566
+ def rays_to_cameras_homography(
567
+ rays,
568
+ crop_parameters,
569
+ num_patches_x=16,
570
+ num_patches_y=16,
571
+ use_half_pix=True,
572
+ sampled_ray_idx=None,
573
+ reproj_threshold=0.2,
574
+ ):
575
+ """
576
+ Args:
577
+ rays (Rays): (N, P, 6)
578
+ crop_parameters (torch.Tensor): (N, 4)
579
+ """
580
+ device = rays.device
581
+ origins = rays.get_origins()
582
+ directions = rays.get_directions()
583
+ camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)
584
+
585
+ # Retrieve target rays
586
+ I_camera = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device)
587
+ I_patch_rays = cameras_to_rays(
588
+ cameras=I_camera,
589
+ num_patches_x=num_patches_x,
590
+ num_patches_y=num_patches_y,
591
+ use_half_pix=use_half_pix,
592
+ crop_parameters=crop_parameters,
593
+ ).get_directions()
594
+
595
+ if sampled_ray_idx is not None:
596
+ I_patch_rays = I_patch_rays[:, sampled_ray_idx]
597
+
598
+ # Compute optimal rotation to align rays
599
+ Rs = []
600
+ focal_lengths = []
601
+ principal_points = []
602
+ for i in range(rays.shape[-3]):
603
+ R, f, pp = compute_optimal_rotation_intrinsics(
604
+ I_patch_rays[i],
605
+ directions[i],
606
+ reproj_threshold=reproj_threshold,
607
+ )
608
+ Rs.append(R)
609
+ focal_lengths.append(f)
610
+ principal_points.append(pp)
611
+
612
+ R = torch.stack(Rs)
613
+ focal_lengths = torch.stack(focal_lengths)
614
+ principal_points = torch.stack(principal_points)
615
+ T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
616
+ return PerspectiveCameras(
617
+ R=R,
618
+ T=T,
619
+ focal_length=focal_lengths,
620
+ principal_point=principal_points,
621
+ device=device,
622
+ )
623
+
624
+
625
+ def compute_optimal_rotation_alignment(A, B):
626
+ """
627
+ Compute optimal R that minimizes: || A - B @ R ||_F
628
+
629
+ Args:
630
+ A (torch.Tensor): (N, 3)
631
+ B (torch.Tensor): (N, 3)
632
+
633
+ Returns:
634
+ R (torch.tensor): (3, 3)
635
+ """
636
+ # normally with R @ B, this would be A @ B.T
637
+ H = B.T @ A
638
+ U, _, Vh = torch.linalg.svd(H, full_matrices=True)
639
+ s = torch.linalg.det(U @ Vh)
640
+ S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
641
+ return U @ S_prime @ Vh
642
+
643
+
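`compute_optimal_rotation_alignment` above is a plain orthogonal Procrustes solve, minimizing || A - B @ R ||_F over rotations. A quick, hedged check that it recovers a known rotation (again assuming the function is importable from this module):

import math
import torch

c, s = math.cos(0.3), math.sin(0.3)
R_true = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])  # rotation about the z-axis
B = torch.randn(100, 3)
A = B @ R_true                                   # directions B rotated by R_true
R_est = compute_optimal_rotation_alignment(A, B)
print(torch.allclose(R_est, R_true, atol=1e-4))  # True up to float32 precision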
644
+ def compute_optimal_rotation_intrinsics(
645
+ rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2
646
+ ):
647
+ """
648
+ Note: for some reason, f seems to be 1/f.
649
+
650
+ Args:
651
+ rays_origin (torch.Tensor): (N, 3)
652
+ rays_target (torch.Tensor): (N, 3)
653
+ z_threshold (float): Threshold for z value to be considered valid.
654
+
655
+ Returns:
656
+ R (torch.tensor): (3, 3)
657
+ focal_length (torch.tensor): (2,)
658
+ principal_point (torch.tensor): (2,)
659
+ """
660
+ device = rays_origin.device
661
+ z_mask = torch.logical_and(
662
+ torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold
663
+ )[:, 2]
664
+ rays_target = rays_target[z_mask]
665
+ rays_origin = rays_origin[z_mask]
666
+ rays_origin = rays_origin[:, :2] / rays_origin[:, -1:]
667
+ rays_target = rays_target[:, :2] / rays_target[:, -1:]
668
+
669
+ A, _ = cv2.findHomography(
670
+ rays_origin.cpu().numpy(),
671
+ rays_target.cpu().numpy(),
672
+ cv2.RANSAC,
673
+ reproj_threshold,
674
+ )
675
+ A = torch.from_numpy(A).float().to(device)
676
+
677
+ if torch.linalg.det(A) < 0:
678
+ A = -A
679
+
680
+ R, L = ql_decomposition(A)
681
+ L = L / L[2][2]
682
+
683
+ f = torch.stack((L[0][0], L[1][1]))
684
+ pp = torch.stack((L[2][0], L[2][1]))
685
+ return R, f, pp
686
+
687
+
688
+ def compute_ndc_coordinates(
689
+ crop_parameters=None,
690
+ use_half_pix=True,
691
+ num_patches_x=16,
692
+ num_patches_y=16,
693
+ device=None,
694
+ ):
695
+ """
696
+ Computes NDC Grid using crop_parameters. If crop_parameters is not provided,
697
+ then it assumes that the crop is the entire image (corresponding to an NDC grid
698
+ where top left corner is (1, 1) and bottom right corner is (-1, -1)).
699
+ """
700
+ if crop_parameters is None:
701
+ cc_x, cc_y, width = 0, 0, 2
702
+ else:
703
+ if len(crop_parameters.shape) > 1:
704
+ return torch.stack(
705
+ [
706
+ compute_ndc_coordinates(
707
+ crop_parameters=crop_param,
708
+ use_half_pix=use_half_pix,
709
+ num_patches_x=num_patches_x,
710
+ num_patches_y=num_patches_y,
711
+ )
712
+ for crop_param in crop_parameters
713
+ ],
714
+ dim=0,
715
+ )
716
+ device = crop_parameters.device
717
+ cc_x, cc_y, width, _ = crop_parameters
718
+
719
+ dx = 1 / num_patches_x
720
+ dy = 1 / num_patches_y
721
+ if use_half_pix:
722
+ min_y = 1 - dy
723
+ max_y = -min_y
724
+ min_x = 1 - dx
725
+ max_x = -min_x
726
+ else:
727
+ min_y = min_x = 1
728
+ max_y = -1 + 2 * dy
729
+ max_x = -1 + 2 * dx
730
+
731
+ y, x = torch.meshgrid(
732
+ torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device),
733
+ torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device),
734
+ indexing="ij",
735
+ )
736
+ x_prime = x * width / 2 - cc_x
737
+ y_prime = y * width / 2 - cc_y
738
+ xyd_grid = torch.stack([x_prime, y_prime, torch.ones_like(x)], dim=-1)
739
+ return xyd_grid
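This closes out the ray utilities. For reference, calling `compute_ndc_coordinates` with no crop parameters yields a full-image NDC grid of shape (num_patches_y, num_patches_x, 3) spanning [-1, 1] with half-pixel offsets; a minimal sketch, assuming the function is importable:

import torch

grid = compute_ndc_coordinates(num_patches_x=16, num_patches_y=16, use_half_pix=True)
print(grid.shape)   # torch.Size([16, 16, 3])
print(grid[0, 0])   # top-left patch centre: tensor([0.9375, 0.9375, 1.0000])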
onediffusion/dataset/transforms.py ADDED
@@ -0,0 +1,133 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ def crop(image, i, j, h, w):
5
+ """
6
+ Args:
7
+ image (torch.tensor): Image to be cropped. Size is (C, H, W)
8
+ """
9
+ if len(image.size()) != 3:
10
+ raise ValueError("image should be a 3D tensor")
11
+ return image[..., i : i + h, j : j + w]
12
+
13
+ def resize(image, target_size, interpolation_mode):
14
+ if len(target_size) != 2:
15
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
16
+ return F.interpolate(image.unsqueeze(0), size=target_size, mode=interpolation_mode, align_corners=False).squeeze(0)
17
+
18
+ def resize_scale(image, target_size, interpolation_mode):
19
+ if len(target_size) != 2:
20
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
21
+ H, W = image.size(-2), image.size(-1)
22
+ scale_ = target_size[0] / min(H, W)
23
+ return F.interpolate(image.unsqueeze(0), scale_factor=scale_, mode=interpolation_mode, align_corners=False).squeeze(0)
24
+
25
+ def resized_crop(image, i, j, h, w, size, interpolation_mode="bilinear"):
26
+ """
27
+ Do spatial cropping and resizing to the image
28
+ Args:
29
+ image (torch.tensor): Image to be cropped. Size is (C, H, W)
30
+ i (int): i in (i, j), i.e., the row coordinate of the upper left corner.
31
+ j (int): j in (i, j), i.e., the column coordinate of the upper left corner.
32
+ h (int): Height of the cropped region.
33
+ w (int): Width of the cropped region.
34
+ size (tuple(int, int)): height and width of resized image
35
+ Returns:
36
+ image (torch.tensor): Resized and cropped image. Size is (C, H, W)
37
+ """
38
+ if len(image.size()) != 3:
39
+ raise ValueError("image should be a 3D torch.tensor")
40
+ image = crop(image, i, j, h, w)
41
+ image = resize(image, size, interpolation_mode)
42
+ return image
43
+
44
+ def center_crop(image, crop_size):
45
+ if len(image.size()) != 3:
46
+ raise ValueError("image should be a 3D torch.tensor")
47
+ h, w = image.size(-2), image.size(-1)
48
+ th, tw = crop_size
49
+ if h < th or w < tw:
50
+ raise ValueError("height and width must be no smaller than crop_size")
51
+ i = int(round((h - th) / 2.0))
52
+ j = int(round((w - tw) / 2.0))
53
+ return crop(image, i, j, th, tw)
54
+
55
+ def center_crop_using_short_edge(image):
56
+ if len(image.size()) != 3:
57
+ raise ValueError("image should be a 3D torch.tensor")
58
+ h, w = image.size(-2), image.size(-1)
59
+ if h < w:
60
+ th, tw = h, h
61
+ i = 0
62
+ j = int(round((w - tw) / 2.0))
63
+ else:
64
+ th, tw = w, w
65
+ i = int(round((h - th) / 2.0))
66
+ j = 0
67
+ return crop(image, i, j, th, tw)
68
+
69
+ class CenterCropResizeImage:
70
+ """
71
+ Resize the image while maintaining aspect ratio, and then crop it to the desired size.
72
+ The resizing is done such that the area of padding/cropping is minimized.
73
+ """
74
+ def __init__(self, size, interpolation_mode="bilinear"):
75
+ if isinstance(size, tuple):
76
+ if len(size) != 2:
77
+ raise ValueError(f"Size should be a tuple (height, width), instead got {size}")
78
+ self.size = size
79
+ else:
80
+ self.size = (size, size)
81
+ self.interpolation_mode = interpolation_mode
82
+
83
+ def __call__(self, image):
84
+ """
85
+ Args:
86
+ image (torch.Tensor): Image to be resized and cropped. Size is (C, H, W)
87
+
88
+ Returns:
89
+ torch.Tensor: Resized and cropped image. Size is (C, target_height, target_width)
90
+ """
91
+ target_height, target_width = self.size
92
+ target_aspect = target_width / target_height
93
+
94
+ # Get current image shape and aspect ratio
95
+ _, height, width = image.shape
96
+ height, width = float(height), float(width)
97
+ current_aspect = width / height
98
+
99
+ # Calculate crop dimensions
100
+ if current_aspect > target_aspect:
101
+ # Image is wider than target, crop width
102
+ crop_height = height
103
+ crop_width = height * target_aspect
104
+ else:
105
+ # Image is taller than target, crop height
106
+ crop_height = width / target_aspect
107
+ crop_width = width
108
+
109
+ # Calculate crop coordinates (center crop)
110
+ y1 = (height - crop_height) / 2
111
+ x1 = (width - crop_width) / 2
112
+
113
+ # Perform the crop
114
+ cropped_image = crop(image, int(y1), int(x1), int(crop_height), int(crop_width))
115
+
116
+ # Resize the cropped image to the target size
117
+ resized_image = resize(cropped_image, self.size, self.interpolation_mode)
118
+
119
+ return resized_image
120
+
121
+ # Example usage
122
+ if __name__ == "__main__":
123
+ # Create a sample image tensor
124
+ sample_image = torch.rand(3, 480, 640) # (C, H, W)
125
+
126
+ # Initialize the transform
127
+ transform = CenterCropResizeImage(size=(224, 224), interpolation_mode="bilinear")
128
+
129
+ # Apply the transform
130
+ transformed_image = transform(sample_image)
131
+
132
+ print(f"Original image shape: {sample_image.shape}")
133
+ print(f"Transformed image shape: {transformed_image.shape}")
onediffusion/dataset/utils.py ADDED
@@ -0,0 +1,175 @@
1
+
2
+ ASPECT_RATIO_2880 = {
3
+ '0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0],
4
+ '0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0],
5
+ '0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0],
6
+ '0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0],
7
+ '0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0],
8
+ '1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0],
9
+ '1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0],
10
+ '1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0],
11
+ '2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0],
12
+ '3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0]
13
+ }
14
+
15
+ ASPECT_RATIO_2048 = {
16
+ '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0],
17
+ '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
18
+ '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
19
+ '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
20
+ '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
21
+ '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
22
+ '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
23
+ '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
24
+ '2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0],
25
+ '3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0]
26
+ }
27
+
28
+ ASPECT_RATIO_1024 = {
29
+ '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
30
+ '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
31
+ '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
32
+ '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
33
+ '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
34
+ '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
35
+ '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
36
+ '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
37
+ '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
38
+ '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
39
+ }
40
+
41
+ ASPECT_RATIO_512 = {
42
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
43
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
44
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
45
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
46
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
47
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
48
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
49
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
50
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
51
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
52
+ }
53
+
54
+
55
+ ASPECT_RATIO_384 = {
56
+ '0.25': [192.0, 768.0],
57
+ '0.26': [192.0, 736.0],
58
+ '0.27': [208.0, 768.0],
59
+ '0.28': [208.0, 736.0],
60
+ '0.33': [240.0, 720.0],
61
+ '0.4': [256.0, 640.0],
62
+ '0.42': [304.0, 720.0],
63
+ '0.48': [368.0, 768.0],
64
+ '0.5': [384.0, 768.0],
65
+ '0.52': [384.0, 736.0],
66
+ '0.57': [384.0, 672.0],
67
+ '0.6': [384.0, 640.0],
68
+ '0.73': [384.0, 528.0],
69
+ '0.77': [384.0, 496.0],
70
+ '0.83': [384.0, 464.0],
71
+ '0.89': [384.0, 432.0],
72
+ '0.92': [384.0, 416.0],
73
+ '1.0': [384.0, 384.0],
74
+ '1.09': [384.0, 352.0],
75
+ '1.14': [384.0, 336.0],
76
+ '1.2': [384.0, 320.0],
77
+ '1.26': [384.0, 304.0],
78
+ '1.33': [384.0, 288.0],
79
+ '1.41': [384.0, 272.0],
80
+ '1.6': [384.0, 240.0],
81
+ '1.71': [384.0, 224.0],
82
+ '2.0': [384.0, 192.0],
83
+ '2.4': [384.0, 160.0],
84
+ '2.88': [368.0, 128.0],
85
+ '3.0': [384.0, 128.0],
86
+ '3.43': [384.0, 112.0],
87
+ '4.0': [384.0, 96.0]
88
+ }
89
+
90
+ ASPECT_RATIO_256 = {
91
+ '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
92
+ '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
93
+ '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
94
+ '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
95
+ '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
96
+ '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
97
+ '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
98
+ '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
99
+ '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
100
+ '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
101
+ }
102
+
103
+ ASPECT_RATIO_256_TEST = {
104
+ '0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
105
+ '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
106
+ '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
107
+ '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
108
+ '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
109
+ '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
110
+ '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
111
+ '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
112
+ '2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
113
+ '4.0': [512.0, 128.0]
114
+ }
115
+
116
+ ASPECT_RATIO_512_TEST = {
117
+ '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
118
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
119
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
120
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
121
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
122
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
123
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
124
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
125
+ '2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
126
+ '4.0': [1024.0, 256.0]
127
+ }
128
+
129
+ ASPECT_RATIO_1024_TEST = {
130
+ '0.25': [512., 2048.], '0.28': [512., 1856.],
131
+ '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
132
+ '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
133
+ '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
134
+ '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
135
+ '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
136
+ '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
137
+ '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
138
+ '2.5': [1600., 640.], '3.0': [1728., 576.],
139
+ '4.0': [2048., 512.],
140
+ }
141
+
142
+ ASPECT_RATIO_2048_TEST = {
143
+ '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0],
144
+ '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
145
+ '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
146
+ '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
147
+ '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
148
+ '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
149
+ '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
150
+ '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
151
+ '2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0],
152
+ '4.0': [4096.0, 1024.0]
153
+ }
154
+
155
+ ASPECT_RATIO_2880_TEST = {
156
+ '0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0],
157
+ '0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0],
158
+ '0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0],
159
+ '0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0],
160
+ '0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0],
161
+ '1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0],
162
+ '1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0],
163
+ '1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0],
164
+ '2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0],
165
+ '4.0': [8192.0, 2048.0],
166
+ }
167
+
168
+ def get_chunks(lst, n):
169
+ for i in range(0, len(lst), n):
170
+ yield lst[i:i + n]
171
+
172
+ def get_closest_ratio(height: float, width: float, ratios: dict):
173
+ aspect_ratio = height / width
174
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
175
+ return ratios[closest_ratio], float(closest_ratio)
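These dictionaries define the aspect-ratio buckets used for multi-resolution handling: keys are height/width ratios, values are the bucketed [height, width]. A small worked example with `get_closest_ratio` and the `ASPECT_RATIO_512` table above:

size, ratio = get_closest_ratio(height=720, width=1280, ratios=ASPECT_RATIO_512)
print(size, ratio)  # [384.0, 672.0] 0.57  (720/1280 = 0.5625, nearest bucket key is '0.57')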
onediffusion/diffusion/pipelines/image_processor.py ADDED
@@ -0,0 +1,674 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import warnings
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torchvision.transforms as T
24
+ from PIL import Image, ImageFilter, ImageOps
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
28
+
29
+ from onediffusion.dataset.transforms import CenterCropResizeImage
30
+
31
+ PipelineImageInput = Union[
32
+ PIL.Image.Image,
33
+ np.ndarray,
34
+ torch.Tensor,
35
+ List[PIL.Image.Image],
36
+ List[np.ndarray],
37
+ List[torch.Tensor],
38
+ ]
39
+
40
+ PipelineDepthInput = PipelineImageInput
41
+
42
+
43
+ def is_valid_image(image):
44
+ return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
45
+
46
+
47
+ def is_valid_image_imagelist(images):
48
+ # check if the image input is one of the supported formats for image and image list:
49
+ # it can be one of the following 3 formats:
50
+ # (1) a 4d pytorch tensor or numpy array,
51
+ # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
52
+ # (3) a list of valid images
53
+ if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
54
+ return True
55
+ elif is_valid_image(images):
56
+ return True
57
+ elif isinstance(images, list):
58
+ return all(is_valid_image(image) for image in images)
59
+ return False
60
+
61
+
62
+ class VaeImageProcessorOneDiffuser(ConfigMixin):
63
+ """
64
+ Image processor for VAE.
65
+
66
+ Args:
67
+ do_resize (`bool`, *optional*, defaults to `True`):
68
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
69
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
70
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
71
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
72
+ resample (`str`, *optional*, defaults to `lanczos`):
73
+ Resampling filter to use when resizing the image.
74
+ do_normalize (`bool`, *optional*, defaults to `True`):
75
+ Whether to normalize the image to [-1,1].
76
+ do_binarize (`bool`, *optional*, defaults to `False`):
77
+ Whether to binarize the image to 0/1.
78
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
79
+ Whether to convert the images to RGB format.
80
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
81
+ Whether to convert the images to grayscale format.
82
+ """
83
+
84
+ config_name = CONFIG_NAME
85
+
86
+ @register_to_config
87
+ def __init__(
88
+ self,
89
+ do_resize: bool = True,
90
+ vae_scale_factor: int = 8,
91
+ vae_latent_channels: int = 4,
92
+ resample: str = "lanczos",
93
+ do_normalize: bool = True,
94
+ do_binarize: bool = False,
95
+ do_convert_rgb: bool = False,
96
+ do_convert_grayscale: bool = False,
97
+ ):
98
+ super().__init__()
99
+ if do_convert_rgb and do_convert_grayscale:
100
+ raise ValueError(
101
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
102
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
103
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
104
+ )
105
+
106
+ @staticmethod
107
+ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
108
+ """
109
+ Convert a numpy image or a batch of images to a PIL image.
110
+ """
111
+ if images.ndim == 3:
112
+ images = images[None, ...]
113
+ images = (images * 255).round().astype("uint8")
114
+ if images.shape[-1] == 1:
115
+ # special case for grayscale (single channel) images
116
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
117
+ else:
118
+ pil_images = [Image.fromarray(image) for image in images]
119
+
120
+ return pil_images
121
+
122
+ @staticmethod
123
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
124
+ """
125
+ Convert a PIL image or a list of PIL images to NumPy arrays.
126
+ """
127
+ if not isinstance(images, list):
128
+ images = [images]
129
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
130
+ images = np.stack(images, axis=0)
131
+
132
+ return images
133
+
134
+ @staticmethod
135
+ def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
136
+ """
137
+ Convert a NumPy image to a PyTorch tensor.
138
+ """
139
+ if images.ndim == 3:
140
+ images = images[..., None]
141
+
142
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
143
+ return images
144
+
145
+ @staticmethod
146
+ def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
147
+ """
148
+ Convert a PyTorch tensor to a NumPy image.
149
+ """
150
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
151
+ return images
152
+
153
+ @staticmethod
154
+ def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
155
+ """
156
+ Normalize an image array to [-1,1].
157
+ """
158
+ return 2.0 * images - 1.0
159
+
160
+ @staticmethod
161
+ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
162
+ """
163
+ Denormalize an image array to [0,1].
164
+ """
165
+ return (images / 2 + 0.5).clamp(0, 1)
166
+
167
+ @staticmethod
168
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
169
+ """
170
+ Converts a PIL image to RGB format.
171
+ """
172
+ image = image.convert("RGB")
173
+
174
+ return image
175
+
176
+ @staticmethod
177
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
178
+ """
179
+ Converts a PIL image to grayscale format.
180
+ """
181
+ image = image.convert("L")
182
+
183
+ return image
184
+
185
+ @staticmethod
186
+ def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
187
+ """
188
+ Applies Gaussian blur to an image.
189
+ """
190
+ image = image.filter(ImageFilter.GaussianBlur(blur_factor))
191
+
192
+ return image
193
+
194
+ @staticmethod
195
+ def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
196
+ """
197
+ Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
198
+ ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
199
+ processing are 512x512, the region will be expanded to 128x128.
200
+
201
+ Args:
202
+ mask_image (PIL.Image.Image): Mask image.
203
+ width (int): Width of the image to be processed.
204
+ height (int): Height of the image to be processed.
205
+ pad (int, optional): Padding to be added to the crop region. Defaults to 0.
206
+
207
+ Returns:
208
+ tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and
209
+ matches the original aspect ratio.
210
+ """
211
+
212
+ mask_image = mask_image.convert("L")
213
+ mask = np.array(mask_image)
214
+
215
+ # 1. find a rectangular region that contains all masked ares in an image
216
+ h, w = mask.shape
217
+ crop_left = 0
218
+ for i in range(w):
219
+ if not (mask[:, i] == 0).all():
220
+ break
221
+ crop_left += 1
222
+
223
+ crop_right = 0
224
+ for i in reversed(range(w)):
225
+ if not (mask[:, i] == 0).all():
226
+ break
227
+ crop_right += 1
228
+
229
+ crop_top = 0
230
+ for i in range(h):
231
+ if not (mask[i] == 0).all():
232
+ break
233
+ crop_top += 1
234
+
235
+ crop_bottom = 0
236
+ for i in reversed(range(h)):
237
+ if not (mask[i] == 0).all():
238
+ break
239
+ crop_bottom += 1
240
+
241
+ # 2. add padding to the crop region
242
+ x1, y1, x2, y2 = (
243
+ int(max(crop_left - pad, 0)),
244
+ int(max(crop_top - pad, 0)),
245
+ int(min(w - crop_right + pad, w)),
246
+ int(min(h - crop_bottom + pad, h)),
247
+ )
248
+
249
+ # 3. expands crop region to match the aspect ratio of the image to be processed
250
+ ratio_crop_region = (x2 - x1) / (y2 - y1)
251
+ ratio_processing = width / height
252
+
253
+ if ratio_crop_region > ratio_processing:
254
+ desired_height = (x2 - x1) / ratio_processing
255
+ desired_height_diff = int(desired_height - (y2 - y1))
256
+ y1 -= desired_height_diff // 2
257
+ y2 += desired_height_diff - desired_height_diff // 2
258
+ if y2 >= mask_image.height:
259
+ diff = y2 - mask_image.height
260
+ y2 -= diff
261
+ y1 -= diff
262
+ if y1 < 0:
263
+ y2 -= y1
264
+ y1 -= y1
265
+ if y2 >= mask_image.height:
266
+ y2 = mask_image.height
267
+ else:
268
+ desired_width = (y2 - y1) * ratio_processing
269
+ desired_width_diff = int(desired_width - (x2 - x1))
270
+ x1 -= desired_width_diff // 2
271
+ x2 += desired_width_diff - desired_width_diff // 2
272
+ if x2 >= mask_image.width:
273
+ diff = x2 - mask_image.width
274
+ x2 -= diff
275
+ x1 -= diff
276
+ if x1 < 0:
277
+ x2 -= x1
278
+ x1 -= x1
279
+ if x2 >= mask_image.width:
280
+ x2 = mask_image.width
281
+
282
+ return x1, y1, x2, y2
283
+
284
+ def _resize_and_fill(
285
+ self,
286
+ image: PIL.Image.Image,
287
+ width: int,
288
+ height: int,
289
+ ) -> PIL.Image.Image:
290
+ """
291
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
292
+ the image within the dimensions, filling empty with data from image.
293
+
294
+ Args:
295
+ image: The image to resize.
296
+ width: The width to resize the image to.
297
+ height: The height to resize the image to.
298
+ """
299
+
300
+ ratio = width / height
301
+ src_ratio = image.width / image.height
302
+
303
+ src_w = width if ratio < src_ratio else image.width * height // image.height
304
+ src_h = height if ratio >= src_ratio else image.height * width // image.width
305
+
306
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
307
+ res = Image.new("RGB", (width, height))
308
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
309
+
310
+ if ratio < src_ratio:
311
+ fill_height = height // 2 - src_h // 2
312
+ if fill_height > 0:
313
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
314
+ res.paste(
315
+ resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
316
+ box=(0, fill_height + src_h),
317
+ )
318
+ elif ratio > src_ratio:
319
+ fill_width = width // 2 - src_w // 2
320
+ if fill_width > 0:
321
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
322
+ res.paste(
323
+ resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
324
+ box=(fill_width + src_w, 0),
325
+ )
326
+
327
+ return res
328
+
329
+ def _resize_and_crop(
330
+ self,
331
+ image: PIL.Image.Image,
332
+ width: int,
333
+ height: int,
334
+ ) -> PIL.Image.Image:
335
+ """
336
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
337
+ the image within the dimensions, cropping the excess.
338
+
339
+ Args:
340
+ image: The image to resize.
341
+ width: The width to resize the image to.
342
+ height: The height to resize the image to.
343
+ """
344
+ ratio = width / height
345
+ src_ratio = image.width / image.height
346
+
347
+ src_w = width if ratio > src_ratio else image.width * height // image.height
348
+ src_h = height if ratio <= src_ratio else image.height * width // image.width
349
+
350
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
351
+ res = Image.new("RGB", (width, height))
352
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
353
+ return res
354
+
355
+ def resize(
356
+ self,
357
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
358
+ height: int,
359
+ width: int,
360
+ resize_mode: str = "default", # "default", "fill", "crop"
361
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
362
+ """
363
+ Resize image.
364
+
365
+ Args:
366
+ image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
367
+ The image input, can be a PIL image, numpy array or pytorch tensor.
368
+ height (`int`):
369
+ The height to resize to.
370
+ width (`int`):
371
+ The width to resize to.
372
+ resize_mode (`str`, *optional*, defaults to `default`):
373
+ The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
374
+ within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
375
+ will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
376
+ then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
377
+ the image to fit within the specified width and height, maintaining the aspect ratio, and then center
378
+ the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
379
+ supported for PIL image input.
380
+
381
+ Returns:
382
+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
383
+ The resized image.
384
+ """
385
+ if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
386
+ raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
387
+ if isinstance(image, PIL.Image.Image):
388
+ if resize_mode == "default":
389
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
390
+ elif resize_mode == "fill":
391
+ image = self._resize_and_fill(image, width, height)
392
+ elif resize_mode == "crop":
393
+ image = self._resize_and_crop(image, width, height)
394
+ else:
395
+ raise ValueError(f"resize_mode {resize_mode} is not supported")
396
+
397
+ elif isinstance(image, torch.Tensor):
398
+ image = torch.nn.functional.interpolate(
399
+ image,
400
+ size=(height, width),
401
+ )
402
+ elif isinstance(image, np.ndarray):
403
+ image = self.numpy_to_pt(image)
404
+ image = torch.nn.functional.interpolate(
405
+ image,
406
+ size=(height, width),
407
+ )
408
+ image = self.pt_to_numpy(image)
409
+ return image
410
+
411
+ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
412
+ """
413
+ Create a mask.
414
+
415
+ Args:
416
+ image (`PIL.Image.Image`):
417
+ The image input, should be a PIL image.
418
+
419
+ Returns:
420
+ `PIL.Image.Image`:
421
+ The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
422
+ """
423
+ image[image < 0.5] = 0
424
+ image[image >= 0.5] = 1
425
+
426
+ return image
427
+
428
+ def get_default_height_width(
429
+ self,
430
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
431
+ height: Optional[int] = None,
432
+ width: Optional[int] = None,
433
+ ) -> Tuple[int, int]:
434
+ """
435
+ This function return the height and width that are downscaled to the next integer multiple of
436
+ `vae_scale_factor`.
437
+
438
+ Args:
439
+ image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
440
+ The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
441
+ shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
442
+ have shape `[batch, channel, height, width]`.
443
+ height (`int`, *optional*, defaults to `None`):
444
+ The height in preprocessed image. If `None`, will use the height of `image` input.
445
+ width (`int`, *optional*`, defaults to `None`):
446
+ The width in preprocessed. If `None`, will use the width of the `image` input.
447
+ """
448
+
449
+ if height is None:
450
+ if isinstance(image, PIL.Image.Image):
451
+ height = image.height
452
+ elif isinstance(image, torch.Tensor):
453
+ height = image.shape[2]
454
+ else:
455
+ height = image.shape[1]
456
+
457
+ if width is None:
458
+ if isinstance(image, PIL.Image.Image):
459
+ width = image.width
460
+ elif isinstance(image, torch.Tensor):
461
+ width = image.shape[3]
462
+ else:
463
+ width = image.shape[2]
464
+
465
+ width, height = (
466
+ x - x % self.config.vae_scale_factor for x in (width, height)
467
+ ) # resize to integer multiple of vae_scale_factor
468
+
469
+ return height, width
470
+
471
+ def preprocess(
472
+ self,
473
+ image: PipelineImageInput,
474
+ height: Optional[int] = None,
475
+ width: Optional[int] = None,
476
+ resize_mode: str = "default", # "default", "fill", "crop"
477
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
478
+ do_crop: bool = True,
479
+ ) -> torch.Tensor:
480
+ """
481
+ Preprocess the image input.
482
+
483
+ Args:
484
+ image (`pipeline_image_input`):
485
+ The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
486
+ supported formats.
487
+ height (`int`, *optional*, defaults to `None`):
488
+ The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
489
+ height.
490
+ width (`int`, *optional*`, defaults to `None`):
491
+ The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
492
+ resize_mode (`str`, *optional*, defaults to `default`):
493
+ The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
494
+ the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
495
+ resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
496
+ center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
497
+ image to fit within the specified width and height, maintaining the aspect ratio, and then center the
498
+ image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
499
+ supported for PIL image input.
500
+ crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
501
+ The crop coordinates for each image in the batch. If `None`, will not crop the image.
502
+ """
503
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
504
+
505
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
506
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
507
+ if isinstance(image, torch.Tensor):
508
+ # if image is a pytorch tensor could have 2 possible shapes:
509
+ # 1. batch x height x width: we should insert the channel dimension at position 1
510
+ # 2. channel x height x width: we should insert batch dimension at position 0,
511
+ # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
512
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
513
+ image = image.unsqueeze(1)
514
+ else:
515
+ # if it is a numpy array, it could have 2 possible shapes:
516
+ # 1. batch x height x width: insert channel dimension on last position
517
+ # 2. height x width x channel: insert batch dimension on first position
518
+ if image.shape[-1] == 1:
519
+ image = np.expand_dims(image, axis=0)
520
+ else:
521
+ image = np.expand_dims(image, axis=-1)
522
+
523
+ if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
524
+ warnings.warn(
525
+ "Passing `image` as a list of 4d np.ndarray is deprecated."
526
+ "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
527
+ FutureWarning,
528
+ )
529
+ image = np.concatenate(image, axis=0)
530
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
531
+ warnings.warn(
532
+ "Passing `image` as a list of 4d torch.Tensor is deprecated."
533
+ "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
534
+ FutureWarning,
535
+ )
536
+ image = torch.cat(image, axis=0)
537
+
538
+ if not is_valid_image_imagelist(image):
539
+ raise ValueError(
540
+ f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
541
+ )
542
+ if not isinstance(image, list):
543
+ image = [image]
544
+
545
+ if isinstance(image[0], PIL.Image.Image):
546
+ pass
547
+ elif isinstance(image[0], np.ndarray):
548
+ image = self.numpy_to_pil(image)
549
+ elif isinstance(image[0], torch.Tensor):
550
+ image = self.pt_to_numpy(image)
551
+ image = self.numpy_to_pil(image)
552
+
553
+ if do_crop:
554
+ transforms = T.Compose([
555
+ T.Lambda(lambda image: image.convert('RGB')),
556
+ T.ToTensor(),
557
+ CenterCropResizeImage((height, width)),
558
+ T.Normalize([.5], [.5]),
559
+ ])
560
+ else:
561
+ transforms = T.Compose([
562
+ T.Lambda(lambda image: image.convert('RGB')),
563
+ T.ToTensor(),
564
+ T.Resize((height, width)),
565
+ T.Normalize([.5], [.5]),
566
+ ])
567
+ image = torch.stack([transforms(i) for i in image])
568
+
569
+ # expected range [0,1], normalize to [-1,1]
570
+ do_normalize = self.config.do_normalize
571
+ if do_normalize and image.min() < 0:
572
+ warnings.warn(
573
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
574
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
575
+ FutureWarning,
576
+ )
577
+ do_normalize = False
578
+ if do_normalize:
579
+ image = self.normalize(image)
580
+
581
+ if self.config.do_binarize:
582
+ image = self.binarize(image)
583
+
584
+ return image
585
+
586
+ def postprocess(
587
+ self,
588
+ image: torch.Tensor,
589
+ output_type: str = "pil",
590
+ do_denormalize: Optional[List[bool]] = None,
591
+ ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
592
+ """
593
+ Postprocess the image output from tensor to `output_type`.
594
+
595
+ Args:
596
+ image (`torch.Tensor`):
597
+ The image input, should be a pytorch tensor with shape `B x C x H x W`.
598
+ output_type (`str`, *optional*, defaults to `pil`):
599
+ The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
600
+ do_denormalize (`List[bool]`, *optional*, defaults to `None`):
601
+ Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
602
+ `VaeImageProcessor` config.
603
+
604
+ Returns:
605
+ `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
606
+ The postprocessed image.
607
+ """
608
+ if not isinstance(image, torch.Tensor):
609
+ raise ValueError(
610
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
611
+ )
612
+ if output_type not in ["latent", "pt", "np", "pil"]:
613
+ deprecation_message = (
614
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
615
+ "`pil`, `np`, `pt`, `latent`"
616
+ )
617
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
618
+ output_type = "np"
619
+
620
+ if output_type == "latent":
621
+ return image
622
+
623
+ if do_denormalize is None:
624
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
625
+
626
+ image = torch.stack(
627
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
628
+ )
629
+
630
+ if output_type == "pt":
631
+ return image
632
+
633
+ image = self.pt_to_numpy(image)
634
+
635
+ if output_type == "np":
636
+ return image
637
+
638
+ if output_type == "pil":
639
+ return self.numpy_to_pil(image)
640
+
641
+ def apply_overlay(
642
+ self,
643
+ mask: PIL.Image.Image,
644
+ init_image: PIL.Image.Image,
645
+ image: PIL.Image.Image,
646
+ crop_coords: Optional[Tuple[int, int, int, int]] = None,
647
+ ) -> PIL.Image.Image:
648
+ """
649
+ overlay the inpaint output to the original image
650
+ """
651
+
652
+ width, height = image.width, image.height
653
+
654
+ init_image = self.resize(init_image, width=width, height=height)
655
+ mask = self.resize(mask, width=width, height=height)
656
+
657
+ init_image_masked = PIL.Image.new("RGBa", (width, height))
658
+ init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
659
+ init_image_masked = init_image_masked.convert("RGBA")
660
+
661
+ if crop_coords is not None:
662
+ x, y, x2, y2 = crop_coords
663
+ w = x2 - x
664
+ h = y2 - y
665
+ base_image = PIL.Image.new("RGBA", (width, height))
666
+ image = self.resize(image, height=h, width=w, resize_mode="crop")
667
+ base_image.paste(image, (x, y))
668
+ image = base_image.convert("RGB")
669
+
670
+ image = image.convert("RGBA")
671
+ image.alpha_composite(init_image_masked)
672
+ image = image.convert("RGB")
673
+
674
+ return image
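This processor mirrors diffusers' `VaeImageProcessor`; the main departure is the `do_crop` branch of `preprocess`, which routes inputs through `CenterCropResizeImage` so the aspect ratio is preserved before resizing. A hedged usage sketch (the sizes and colors below are illustrative, not taken from the repo):

from PIL import Image

proc = VaeImageProcessorOneDiffuser(vae_scale_factor=8)
img = Image.new("RGB", (640, 480), color=(128, 64, 32))  # any RGB PIL image
batch = proc.preprocess(img, height=512, width=512)      # tensor of shape (1, 3, 512, 512), values in [-1, 1]
pil_out = proc.postprocess(batch, output_type="pil")     # back to a list of PIL images
print(batch.shape, pil_out[0].size)                      # torch.Size([1, 3, 512, 512]) (512, 512)

Note that the transform stack already maps pixels to [-1, 1], so the trailing normalize step in `preprocess` is skipped (with a warning) rather than applied twice.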
onediffusion/diffusion/pipelines/onediffusion.py ADDED
@@ -0,0 +1,1080 @@
1
+ import einops
2
+ import inspect
3
+ import torch
4
+ import numpy as np
5
+ import PIL
6
+ import os
7
+
8
+ from dataclasses import dataclass
9
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
+ from diffusers.utils import (
12
+ CONFIG_NAME,
13
+ DEPRECATED_REVISION_ARGS,
14
+ BaseOutput,
15
+ PushToHubMixin,
16
+ deprecate,
17
+ is_accelerate_available,
18
+ is_accelerate_version,
19
+ is_torch_npu_available,
20
+ is_torch_version,
21
+ logging,
22
+ numpy_to_pil,
23
+ replace_example_docstring,
24
+ )
25
+ from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+ from diffusers.utils import BaseOutput
28
+ # from diffusers.image_processor import VaeImageProcessor
29
+ from transformers import T5EncoderModel, T5Tokenizer
30
+ from typing import Any, Callable, Dict, List, Optional, Union
31
+ from PIL import Image
32
+
33
+ from onediffusion.models.denoiser.nextdit import NextDiT
34
+ from onediffusion.dataset.utils import *
35
+ from onediffusion.dataset.multitask.multiview import calculate_rays
36
+ from onediffusion.diffusion.pipelines.image_processor import VaeImageProcessorOneDiffuser
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ SUPPORTED_DEVICE_MAP = ["balanced"]
41
+
42
+ EXAMPLE_DOC_STRING = """
43
+ Examples:
44
+ ```py
45
+ >>> import torch
46
+ >>> from one_diffusion import OneDiffusionPipeline
47
+
48
+ >>> pipe = OneDiffusionPipeline.from_pretrained("path_to_one_diffuser_model")
49
+ >>> pipe = pipe.to("cuda")
50
+
51
+ >>> prompt = "A beautiful sunset over the ocean"
52
+ >>> image = pipe(prompt).images[0]
53
+ >>> image.save("beautiful_sunset.png")
54
+ ```
55
+ """
56
+
57
+ def create_c2w_matrix(azimuth_deg, elevation_deg, distance=1.0, target=np.array([0, 0, 0])):
58
+ """
59
+ Create a Camera-to-World (C2W) matrix from azimuth and elevation angles.
60
+
61
+ Parameters:
62
+ - azimuth_deg: Azimuth angle in degrees.
63
+ - elevation_deg: Elevation angle in degrees.
64
+ - distance: Distance from the target point.
65
+ - target: The point the camera is looking at in world coordinates.
66
+
67
+ Returns:
68
+ - C2W: A 4x4 NumPy array representing the Camera-to-World transformation matrix.
69
+ """
70
+ # Convert angles from degrees to radians
71
+ azimuth = np.deg2rad(azimuth_deg)
72
+ elevation = np.deg2rad(elevation_deg)
73
+
74
+ # Spherical to Cartesian conversion for camera position
75
+ x = distance * np.cos(elevation) * np.cos(azimuth)
76
+ y = distance * np.cos(elevation) * np.sin(azimuth)
77
+ z = distance * np.sin(elevation)
78
+ camera_position = np.array([x, y, z])
79
+
80
+ # Define the forward vector (from camera to target)
81
+ target = 2*camera_position - target
82
+ forward = target - camera_position
83
+ forward /= np.linalg.norm(forward)
84
+
85
+ # Define the world up vector
86
+ world_up = np.array([0, 0, 1])
87
+
88
+ # Compute the right vector
89
+ right = np.cross(world_up, forward)
90
+ if np.linalg.norm(right) < 1e-6:
91
+ # Handle the singularity when forward is parallel to world_up
92
+ world_up = np.array([0, 1, 0])
93
+ right = np.cross(world_up, forward)
94
+ right /= np.linalg.norm(right)
95
+
96
+ # Recompute the orthogonal up vector
97
+ up = np.cross(forward, right)
98
+
99
+ # Construct the rotation matrix
100
+ rotation = np.vstack([right, up, forward]).T # 3x3
101
+
102
+ # Construct the full C2W matrix
103
+ C2W = np.eye(4)
104
+ C2W[:3, :3] = rotation
105
+ C2W[:3, 3] = camera_position
106
+
107
+ return C2W
108
+
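A quick numerical check of `create_c2w_matrix` (assuming the helper is importable): at zero azimuth and elevation with unit distance, the camera sits at (1, 0, 0) and its forward axis, the third rotation column, points along +x.

import numpy as np

C2W = create_c2w_matrix(azimuth_deg=0.0, elevation_deg=0.0, distance=1.0)
print(np.round(C2W[:3, 3], 3))  # camera position: [1. 0. 0.]
print(np.round(C2W[:3, 2], 3))  # forward direction: [1. 0. 0.]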
109
+ @dataclass
110
+ class OneDiffusionPipelineOutput(BaseOutput):
111
+ """
112
+ Output class for OneDiffusion pipelines.
113
+
114
+ Args:
115
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
116
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
117
+ num_channels)`. PIL images or numpy array representing the denoised images of the diffusion pipeline.
118
+ """
119
+
120
+ images: Union[List[Image.Image], np.ndarray]
121
+ latents: Optional[torch.Tensor] = None
122
+
123
+
124
+ def retrieve_latents(
125
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
126
+ ):
127
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
128
+ return encoder_output.latent_dist.sample(generator)
129
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
130
+ return encoder_output.latent_dist.mode()
131
+ elif hasattr(encoder_output, "latents"):
132
+ return encoder_output.latents
133
+ else:
134
+ raise AttributeError("Could not access latents of provided encoder_output")
135
+
136
+
137
+ def calculate_shift(
138
+ image_seq_len,
139
+ base_seq_len: int = 256,
140
+ max_seq_len: int = 4096,
141
+ base_shift: float = 0.5,
142
+ max_shift: float = 1.16,
143
+ # max_clip: float = 1.5,
144
+ ):
145
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)  # 0.66 / 3840 ≈ 0.000171875 with the defaults
146
+ b = base_shift - m * base_seq_len  # 0.5 - 0.000171875 * 256 = 0.456 with the defaults
147
+ mu = image_seq_len * m + b
148
+ # mu = min(mu, max_clip)
149
+ return mu
150
+
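With the defaults above, `calculate_shift` is a linear interpolation of the time-shift parameter in the sequence length: m = (1.16 - 0.5) / (4096 - 256) ≈ 0.000171875 and b ≈ 0.456, so an image sequence length of 1024 gives mu ≈ 0.000171875 * 1024 + 0.456 ≈ 0.632, and the maximum length 4096 gives mu = max_shift = 1.16.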
151
+
152
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
153
+ def retrieve_timesteps(
154
+ scheduler,
155
+ num_inference_steps: Optional[int] = None,
156
+ device: Optional[Union[str, torch.device]] = None,
157
+ timesteps: Optional[List[int]] = None,
158
+ sigmas: Optional[List[float]] = None,
159
+ **kwargs,
160
+ ):
161
+ """
162
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
163
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
164
+
165
+ Args:
166
+ scheduler (`SchedulerMixin`):
167
+ The scheduler to get timesteps from.
168
+ num_inference_steps (`int`):
169
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
170
+ must be `None`.
171
+ device (`str` or `torch.device`, *optional*):
172
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
173
+ timesteps (`List[int]`, *optional*):
174
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
175
+ `num_inference_steps` and `sigmas` must be `None`.
176
+ sigmas (`List[float]`, *optional*):
177
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
178
+ `num_inference_steps` and `timesteps` must be `None`.
179
+
180
+ Returns:
181
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
182
+ second element is the number of inference steps.
183
+ """
184
+ if timesteps is not None and sigmas is not None:
185
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
186
+ if timesteps is not None:
187
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
188
+ if not accepts_timesteps:
189
+ raise ValueError(
190
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
191
+ f" timestep schedules. Please check whether you are using the correct scheduler."
192
+ )
193
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
194
+ timesteps = scheduler.timesteps
195
+ num_inference_steps = len(timesteps)
196
+ elif sigmas is not None:
197
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
198
+ if not accept_sigmas:
199
+ raise ValueError(
200
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
201
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
202
+ )
203
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
204
+ timesteps = scheduler.timesteps
205
+ num_inference_steps = len(timesteps)
206
+ else:
207
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
208
+ timesteps = scheduler.timesteps
209
+ return timesteps, num_inference_steps
210
+
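A hedged usage sketch of the helper above (with `retrieve_timesteps` in scope), assuming a recent `diffusers` release whose `FlowMatchEulerDiscreteScheduler.set_timesteps` accepts `sigmas` and `mu`; the numeric values are placeholders:

```python
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

# Dynamic shifting must be enabled for the `mu` kwarg to take effect.
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)

num_inference_steps = 30
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = 0.632  # e.g. the output of calculate_shift for a 1024-token latent sequence

timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps, device="cpu", sigmas=sigmas, mu=mu
)
print(num_inference_steps, timesteps[:3])
```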
211
+
212
+
213
+ class OneDiffusionPipeline(DiffusionPipeline):
214
+ r"""
215
+ Pipeline for text-to-image generation using OneDiffuser.
216
+
217
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
218
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
219
+
220
+ Args:
221
+ transformer ([`NextDiT`]):
222
+ Conditional transformer (NextDiT) architecture to denoise the encoded image latents.
223
+ vae ([`AutoencoderKL`]):
224
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
225
+ text_encoder ([`T5EncoderModel`]):
226
+ Frozen text-encoder. OneDiffuser uses the T5 model as text encoder.
227
+ tokenizer (`T5Tokenizer`):
228
+ Tokenizer of class T5Tokenizer.
229
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
230
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ transformer: NextDiT,
236
+ vae: AutoencoderKL,
237
+ text_encoder: T5EncoderModel,
238
+ tokenizer: T5Tokenizer,
239
+ scheduler: FlowMatchEulerDiscreteScheduler,
240
+ ):
241
+ super().__init__()
242
+ self.register_modules(
243
+ transformer=transformer,
244
+ vae=vae,
245
+ text_encoder=text_encoder,
246
+ tokenizer=tokenizer,
247
+ scheduler=scheduler,
248
+ )
249
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
250
+ self.image_processor = VaeImageProcessorOneDiffuser(vae_scale_factor=self.vae_scale_factor)
251
+
252
+ def enable_vae_slicing(self):
253
+ self.vae.enable_slicing()
254
+
255
+ def disable_vae_slicing(self):
256
+ self.vae.disable_slicing()
257
+
258
+ def enable_sequential_cpu_offload(self, gpu_id=0):
259
+ if is_accelerate_available():
260
+ from accelerate import cpu_offload
261
+ else:
262
+ raise ImportError("Please install accelerate via `pip install accelerate`")
263
+
264
+ device = torch.device(f"cuda:{gpu_id}")
265
+
266
+ for cpu_offloaded_model in [self.transformer, self.text_encoder, self.vae]:
267
+ if cpu_offloaded_model is not None:
268
+ cpu_offload(cpu_offloaded_model, device)
269
+
270
+ @property
271
+ def _execution_device(self):
272
+ if self.device != torch.device("meta") or not hasattr(self.transformer, "_hf_hook"):
273
+ return self.device
274
+ for module in self.transformer.modules():
275
+ if (
276
+ hasattr(module, "_hf_hook")
277
+ and hasattr(module._hf_hook, "execution_device")
278
+ and module._hf_hook.execution_device is not None
279
+ ):
280
+ return torch.device(module._hf_hook.execution_device)
281
+ return self.device
282
+
283
+ def encode_prompt(
284
+ self,
285
+ prompt,
286
+ device,
287
+ num_images_per_prompt,
288
+ do_classifier_free_guidance,
289
+ negative_prompt=None,
290
+ max_length=300,
291
+ ):
292
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
293
+
294
+ text_inputs = self.tokenizer(
295
+ prompt,
296
+ padding="max_length",
297
+ max_length=max_length,
298
+ truncation=True,
299
+ add_special_tokens=True,
300
+ return_tensors="pt",
301
+ )
302
+ text_input_ids = text_inputs.input_ids
303
+ attention_mask = text_inputs.attention_mask
304
+
305
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
306
+
307
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
308
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
309
+ logger.warning(
310
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
311
+ f" {max_length} tokens: {removed_text}"
312
+ )
313
+
314
+ text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
315
+ prompt_embeds = text_encoder_output[0].to(torch.float32)
316
+
317
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
318
+ bs_embed, seq_len, _ = prompt_embeds.shape
319
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
320
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
321
+
322
+ # duplicate attention mask for each generation per prompt
323
+ attention_mask = attention_mask.repeat(1, num_images_per_prompt)
324
+ attention_mask = attention_mask.view(bs_embed * num_images_per_prompt, -1)
325
+
326
+ # get unconditional embeddings for classifier free guidance
327
+ if do_classifier_free_guidance:
328
+ uncond_tokens: List[str]
329
+ if negative_prompt is None:
330
+ uncond_tokens = [""] * batch_size
331
+ elif type(prompt) is not type(negative_prompt):
332
+ raise TypeError(
333
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
334
+ f" {type(prompt)}."
335
+ )
336
+ elif isinstance(negative_prompt, str):
337
+ uncond_tokens = [negative_prompt]
338
+ elif batch_size != len(negative_prompt):
339
+ raise ValueError(
340
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
341
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
342
+ " the batch size of `prompt`."
343
+ )
344
+ else:
345
+ uncond_tokens = negative_prompt
346
+
347
+ max_length = text_input_ids.shape[-1]
348
+ uncond_input = self.tokenizer(
349
+ uncond_tokens,
350
+ padding="max_length",
351
+ max_length=max_length,
352
+ truncation=True,
353
+ return_tensors="pt",
354
+ )
355
+
356
+ uncond_encoder_output = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device))
357
+ negative_prompt_embeds = uncond_encoder_output[0].to(torch.float32)
358
+
359
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
360
+ seq_len = negative_prompt_embeds.shape[1]
361
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
362
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
363
+
364
+ # duplicate unconditional attention mask for each generation per prompt
365
+ uncond_attention_mask = uncond_input.attention_mask.repeat(1, num_images_per_prompt)
366
+ uncond_attention_mask = uncond_attention_mask.view(batch_size * num_images_per_prompt, -1)
367
+
368
+ # For classifier free guidance, we need to do two forward passes.
369
+ # Here we concatenate the unconditional and text embeddings into a single batch
370
+ # to avoid doing two forward passes
371
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
372
+ attention_mask = torch.cat([uncond_attention_mask, attention_mask])
373
+
374
+ return prompt_embeds.to(device), attention_mask.to(device)
375
+
376
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
377
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
378
+ if isinstance(generator, list) and len(generator) != batch_size:
379
+ raise ValueError(
380
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
381
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
382
+ )
383
+
384
+ if latents is None:
385
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
386
+ else:
387
+ latents = latents.to(device)
388
+
389
+ # scale the initial noise by the standard deviation required by the scheduler
390
+ latents = latents * self.scheduler.init_noise_sigma
391
+ return latents
392
+
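The shape bookkeeping above is the usual diffusers pattern: one latent tensor per requested image, with the spatial dimensions divided by the VAE scale factor. A small standalone sketch; the channel count and scale factor mirror this repo's defaults but are assumptions here:

```python
import torch

batch_size, num_images_per_prompt = 1, 2
num_channels_latents = 16      # NextDiT in_channels default in this repo
vae_scale_factor = 8           # 2 ** (len(vae.config.block_out_channels) - 1) for the usual 4-block VAE
height, width = 1024, 768

shape = (
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height // vae_scale_factor,  # 128
    width // vae_scale_factor,   # 96
)
latents = torch.randn(shape)
print(latents.shape)             # torch.Size([2, 16, 128, 96])
```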
393
+ @torch.no_grad()
394
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
395
+ def __call__(
396
+ self,
397
+ prompt: Union[str, List[str]] = None,
398
+ height: Optional[int] = None,
399
+ width: Optional[int] = None,
400
+ num_inference_steps: int = 50,
401
+ guidance_scale: float = 5.0,
402
+ negative_prompt: Optional[Union[str, List[str]]] = None,
403
+ num_images_per_prompt: Optional[int] = 1,
404
+ eta: float = 0.0,
405
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
406
+ latents: Optional[torch.FloatTensor] = None,
407
+ output_type: Optional[str] = "pil",
408
+ return_dict: bool = True,
409
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
410
+ callback_steps: int = 1,
411
+ forward_kwargs: Optional[Dict[str, Any]] = {},
412
+ **kwargs,
413
+ ):
414
+ r"""
415
+ Function invoked when calling the pipeline for generation.
416
+
417
+ Args:
418
+ prompt (`str` or `List[str]`, *optional*):
419
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
420
+ height (`int`, *optional*, defaults to `self.transformer.config.input_size[-2] * 8`):
421
+ The height in pixels of the generated image.
422
+ width (`int`, *optional*, defaults to `self.transformer.config.input_size[-1] * 8`):
423
+ The width in pixels of the generated image.
424
+ num_inference_steps (`int`, *optional*, defaults to 50):
425
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
426
+ expense of slower inference.
427
+ guidance_scale (`float`, *optional*, defaults to 5.0):
428
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
429
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
430
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
431
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
432
+ usually at the expense of lower image quality.
433
+ negative_prompt (`str` or `List[str]`, *optional*):
434
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
435
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
436
+ less than `1`).
437
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
438
+ The number of images to generate per prompt.
439
+ eta (`float`, *optional*, defaults to 0.0):
440
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
441
+ [`schedulers.DDIMScheduler`], will be ignored for others.
442
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
443
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
444
+ to make generation deterministic.
445
+ latents (`torch.FloatTensor`, *optional*):
446
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
447
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
448
+ tensor will be generated by sampling using the supplied random `generator`.
449
+ output_type (`str`, *optional*, defaults to `"pil"`):
450
+ The output format of the generated image. Choose between
451
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
452
+ return_dict (`bool`, *optional*, defaults to `True`):
453
+ Whether or not to return an [`OneDiffusionPipelineOutput`] instead of a
454
+ plain tuple.
455
+ callback (`Callable`, *optional*):
456
+ A function that will be called every `callback_steps` steps during inference. The function will be
457
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
458
+ callback_steps (`int`, *optional*, defaults to 1):
459
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
460
+ called at every step.
461
+
462
+ Examples:
463
+
464
+ Returns:
465
+ [`OneDiffusionPipelineOutput`] or `tuple`:
466
+ [`OneDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` whose first element is a list
467
+ with the generated images and whose second element is `None`.
470
+ """
471
+ height = height or self.transformer.config.input_size[-2] * 8 # TODO: Hardcoded downscale factor of vae
472
+ width = width or self.transformer.config.input_size[-1] * 8
473
+
474
+ # check inputs. Raise error if not correct
475
+ self.check_inputs(prompt, height, width, callback_steps)
476
+
477
+ # define call parameters
478
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
479
+ device = self._execution_device
480
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
481
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf
482
+ do_classifier_free_guidance = guidance_scale > 1.0
483
+
484
+ encoder_hidden_states, encoder_attention_mask = self.encode_prompt(
485
+ prompt,
486
+ device,
487
+ num_images_per_prompt,
488
+ do_classifier_free_guidance,
489
+ negative_prompt,
490
+ )
491
+
492
+ # set timesteps
493
+ # # self.scheduler.set_timesteps(num_inference_steps, device=device)
494
+ # timesteps = self.scheduler.timesteps
495
+ timesteps = None
496
+
497
+ # prepare latent variables
498
+ num_channels_latents = self.transformer.config.in_channels
499
+ latents = self.prepare_latents(
500
+ batch_size * num_images_per_prompt,
501
+ num_channels_latents,
502
+ height,
503
+ width,
504
+ self.dtype,
505
+ device,
506
+ generator,
507
+ latents,
508
+ )
509
+
510
+ # prepare extra step kwargs
511
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
512
+
513
+ # 5. Prepare timesteps
514
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
515
+ image_seq_len = latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
516
+ mu = calculate_shift(
517
+ image_seq_len,
518
+ self.scheduler.config.base_image_seq_len,
519
+ self.scheduler.config.max_image_seq_len,
520
+ self.scheduler.config.base_shift,
521
+ self.scheduler.config.max_shift,
522
+ )
523
+ timesteps, num_inference_steps = retrieve_timesteps(
524
+ self.scheduler,
525
+ num_inference_steps,
526
+ device,
527
+ timesteps,
528
+ sigmas,
529
+ mu=mu,
530
+ )
531
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
532
+ self._num_timesteps = len(timesteps)
533
+
534
+ # denoising loop
536
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
537
+ for i, t in enumerate(timesteps):
538
+ # expand the latents if we are doing classifier free guidance
539
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
540
+ # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
541
+
542
+ # predict the noise residual
543
+ noise_pred = self.transformer(
544
+ samples=latent_model_input.to(self.dtype),
545
+ timesteps=torch.tensor([t] * latent_model_input.shape[0], device=device),
546
+ encoder_hidden_states=encoder_hidden_states.to(self.dtype),
547
+ encoder_attention_mask=encoder_attention_mask,
548
+ **forward_kwargs
549
+ )
550
+
551
+ # perform guidance
552
+ if do_classifier_free_guidance:
553
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
554
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
555
+
556
+ # compute the previous noisy sample x_t -> x_t-1
557
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
558
+
559
+ # call the callback, if provided
560
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
561
+ progress_bar.update()
562
+ if callback is not None and i % callback_steps == 0:
563
+ callback(i, t, latents)
564
+
565
+ # scale and decode the image latents with vae
566
+ latents = 1 / self.vae.config.scaling_factor * latents
567
+ if latents.ndim == 5:
568
+ latents = latents.squeeze(1)
569
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
570
+
571
+ image = (image / 2 + 0.5).clamp(0, 1)
572
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
573
+
574
+ if output_type == "pil":
575
+ image = self.numpy_to_pil(image)
576
+
577
+ if not return_dict:
578
+ return (image, None)
579
+
580
+ return OneDiffusionPipelineOutput(images=image)
581
+
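A hedged end-to-end sketch of the text-to-image path above. The checkpoint path is a placeholder, and any task-specific prompt formatting the released weights may expect is not shown here:

```python
import torch

pipe = OneDiffusionPipeline.from_pretrained("path/to/onediffusion-checkpoint")  # placeholder path
pipe = pipe.to("cuda")

result = pipe(
    prompt="a photo of a red vintage car parked by the sea",
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=5.0,
    num_images_per_prompt=1,
)
result.images[0].save("sample.png")
```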
582
+ @torch.no_grad()
583
+ def img2img(
584
+ self,
585
+ prompt: Union[str, List[str]] = None,
586
+ image: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
587
+ height: Optional[int] = None,
588
+ width: Optional[int] = None,
589
+ num_inference_steps: int = 50,
590
+ guidance_scale: float = 5.0,
591
+ denoise_mask: Optional[List[int]] = [1, 0],
592
+ negative_prompt: Optional[Union[str, List[str]]] = None,
593
+ num_images_per_prompt: Optional[int] = 1,
594
+ eta: float = 0.0,
595
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
596
+ latents: Optional[torch.FloatTensor] = None,
597
+ output_type: Optional[str] = "pil",
598
+ return_dict: bool = True,
599
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
600
+ callback_steps: int = 1,
601
+ do_crop: bool = True,
602
+ is_multiview: bool = False,
603
+ multiview_azimuths: Optional[List[int]] = [0, 30, 60, 90],
604
+ multiview_elevations: Optional[List[int]] = [0, 0, 0, 0],
605
+ multiview_distances: float = 1.7,
606
+ multiview_c2ws: Optional[List[torch.Tensor]] = None,
607
+ multiview_intrinsics: Optional[torch.Tensor] = None,
608
+ multiview_focal_length: float = 1.3887,
609
+ forward_kwargs: Optional[Dict[str, Any]] = {},
610
+ noise_scale: float = 1.0,
611
+ **kwargs,
612
+ ):
613
+ # Convert single image to list for consistent handling
614
+ if isinstance(image, PIL.Image.Image):
615
+ image = [image]
616
+
617
+ if height is None or width is None:
618
+ closest_ar = get_closest_ratio(height=image[0].size[1], width=image[0].size[0], ratios=ASPECT_RATIO_512)
619
+ height, width = int(closest_ar[0][0]), int(closest_ar[0][1])
620
+
621
+ if not isinstance(multiview_distances, list) and not isinstance(multiview_distances, tuple):
622
+ multiview_distances = [multiview_distances] * len(multiview_azimuths)
623
+
624
+ # height = height or self.transformer.config.input_size[-2] * 8 # TODO: Hardcoded downscale factor of vae
625
+ # width = width or self.transformer.config.input_size[-1] * 8
626
+
627
+ # 1. check inputs. Raise error if not correct
628
+ self.check_inputs(prompt, height, width, callback_steps)
629
+
630
+ # Additional input validation for image list
631
+ if not all(isinstance(img, PIL.Image.Image) for img in image):
632
+ raise ValueError("All elements in image list must be PIL.Image objects")
633
+
634
+ # 2. define call parameters
635
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
636
+ device = self._execution_device
637
+ do_classifier_free_guidance = guidance_scale > 1.0
638
+
639
+ # 3. Encode input prompt
640
+ encoder_hidden_states, encoder_attention_mask = self.encode_prompt(
641
+ prompt,
642
+ device,
643
+ num_images_per_prompt,
644
+ do_classifier_free_guidance,
645
+ negative_prompt,
646
+ )
647
+
648
+ # 4. Preprocess all images
649
+ if image is not None and len(image) > 0:
650
+ processed_image = self.image_processor.preprocess(image, height=height, width=width, do_crop=do_crop)
651
+ else:
652
+ processed_image = None
653
+
654
+ # # Stack processed images along the sequence dimension
655
+ # if len(processed_images) > 1:
656
+ # processed_image = torch.cat(processed_images, dim=0)
657
+ # else:
658
+ # processed_image = processed_images[0]
659
+
660
+ timesteps = None
661
+
662
+ # 6. prepare latent variables
663
+ num_channels_latents = self.transformer.config.in_channels
664
+ if processed_image is not None:
665
+ cond_latents = self.prepare_latents(
666
+ batch_size * num_images_per_prompt,
667
+ num_channels_latents,
668
+ height,
669
+ width,
670
+ self.dtype,
671
+ device,
672
+ generator,
673
+ latents,
674
+ image=processed_image,
675
+ )
676
+ else:
677
+ cond_latents = None
678
+
679
+ # 7. prepare extra step kwargs
680
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
681
+ denoise_mask = torch.tensor(denoise_mask, device=device)
682
+ denoise_indices = torch.where(denoise_mask == 1)[0]
683
+ cond_indices = torch.where(denoise_mask == 0)[0]
684
+ seq_length = denoise_mask.shape[0]
685
+
686
+ latents = self.prepare_init_latents(
687
+ batch_size * num_images_per_prompt,
688
+ seq_length,
689
+ num_channels_latents,
690
+ height,
691
+ width,
692
+ self.dtype,
693
+ device,
694
+ generator,
695
+ )
696
+
697
+ # 5. Prepare timesteps
698
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
699
+ # image_seq_len = latents.shape[1] * latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
700
+ image_seq_len = noise_scale * sum(denoise_mask) * latents.shape[-1] * latents.shape[-2] / self.transformer.config.patch_size[-1] / self.transformer.config.patch_size[-2]
701
+ # image_seq_len = 256
702
+ mu = calculate_shift(
703
+ image_seq_len,
704
+ self.scheduler.config.base_image_seq_len,
705
+ self.scheduler.config.max_image_seq_len,
706
+ self.scheduler.config.base_shift,
707
+ self.scheduler.config.max_shift,
708
+ )
709
+ timesteps, num_inference_steps = retrieve_timesteps(
710
+ self.scheduler,
711
+ num_inference_steps,
712
+ device,
713
+ timesteps,
714
+ sigmas,
715
+ mu=mu,
716
+ )
717
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
718
+ self._num_timesteps = len(timesteps)
719
+
720
+ if is_multiview:
721
+ cond_indices_images = [index // 2 for index in cond_indices if index % 2 == 0]
722
+ cond_indices_rays = [index // 2 for index in cond_indices if index % 2 == 1]
723
+
724
+ multiview_elevations = [element for element in multiview_elevations if element is not None]
725
+ multiview_azimuths = [element for element in multiview_azimuths if element is not None]
726
+ multiview_distances = [element for element in multiview_distances if element is not None]
727
+
728
+ if multiview_c2ws is None:
729
+ multiview_c2ws = [
730
+ torch.tensor(create_c2w_matrix(azimuth, elevation, distance)) for azimuth, elevation, distance in zip(multiview_azimuths, multiview_elevations, multiview_distances)
731
+ ]
732
+ c2ws = torch.stack(multiview_c2ws).float()
733
+ else:
734
+ c2ws = torch.Tensor(multiview_c2ws).float()
735
+
736
+ c2ws[:, 0:3, 1:3] *= -1
737
+ c2ws = c2ws[:, [1, 0, 2, 3], :]
738
+ c2ws[:, 2, :] *= -1
739
+
740
+ w2cs = torch.inverse(c2ws)
741
+ if multiview_intrinsics is None:
742
+ multiview_intrinsics = torch.Tensor([[[multiview_focal_length, 0, 0.5], [0, multiview_focal_length, 0.5], [0, 0, 1]]]).repeat(c2ws.shape[0], 1, 1)
743
+ K = multiview_intrinsics
744
+ Rs = w2cs[:, :3, :3]
745
+ Ts = w2cs[:, :3, 3]
746
+ sizes = torch.Tensor([[1, 1]]).repeat(c2ws.shape[0], 1)
747
+
748
+ assert height == width
749
+ cond_rays = calculate_rays(K, sizes, Rs, Ts, height // 8)
750
+ cond_rays = cond_rays.reshape(-1, height // 8, width // 8, 6)
751
+ # padding = (0, 10)
752
+ # cond_rays = torch.nn.functional.pad(cond_rays, padding, "constant", 0)
753
+ cond_rays = torch.cat([cond_rays, cond_rays, cond_rays[..., :4]], dim=-1) * 1.658
754
+ cond_rays = cond_rays[None].repeat(batch_size * num_images_per_prompt, 1, 1, 1, 1)
755
+ cond_rays = cond_rays.permute(0, 1, 4, 2, 3)
756
+ cond_rays = cond_rays.to(device, dtype=self.dtype)
757
+
758
+ latents = einops.rearrange(latents, "b (f n) c h w -> b f n c h w", n=2)
759
+ if cond_latents is not None:
760
+ latents[:, cond_indices_images, 0] = cond_latents
761
+ latents[:, cond_indices_rays, 1] = cond_rays
762
+ latents = einops.rearrange(latents, "b f n c h w -> b (f n) c h w")
763
+ else:
764
+ if cond_latents is not None:
765
+ latents[:, cond_indices] = cond_latents
766
+
767
+ # denoising loop
769
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
770
+ for i, t in enumerate(timesteps):
771
+ # expand the latents if we are doing classifier free guidance
772
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
773
+ input_t = torch.broadcast_to(einops.repeat(torch.Tensor([t]).to(device), "1 -> 1 f 1 1 1", f=latent_model_input.shape[1]), latent_model_input.shape).clone()
774
+
775
+ if is_multiview:
776
+ input_t = einops.rearrange(input_t, "b (f n) c h w -> b f n c h w", n=2)
777
+ input_t[:, cond_indices_images, 0] = self.scheduler.timesteps[-1]
778
+ input_t[:, cond_indices_rays, 1] = self.scheduler.timesteps[-1]
779
+ input_t = einops.rearrange(input_t, "b f n c h w -> b (f n) c h w")
780
+ else:
781
+ input_t[:, cond_indices] = self.scheduler.timesteps[-1]
782
+
783
+ # predict the noise residual
784
+ noise_pred = self.transformer(
785
+ samples=latent_model_input.to(self.dtype),
786
+ timesteps=input_t,
787
+ encoder_hidden_states=encoder_hidden_states.to(self.dtype),
788
+ encoder_attention_mask=encoder_attention_mask,
789
+ **forward_kwargs
790
+ )
791
+
792
+ # perform guidance
793
+ if do_classifier_free_guidance:
794
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
795
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
796
+
797
+ # compute the previous noisy sample x_t -> x_t-1
798
+ bs, n_frame = noise_pred.shape[:2]
799
+ noise_pred = einops.rearrange(noise_pred, "b f c h w -> (b f) c h w")
800
+ latents = einops.rearrange(latents, "b f c h w -> (b f) c h w")
801
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
802
+ latents = einops.rearrange(latents, "(b f) c h w -> b f c h w", b=bs, f=n_frame)
803
+ if is_multiview:
804
+ latents = einops.rearrange(latents, "b (f n) c h w -> b f n c h w", n=2)
805
+ if cond_latents is not None:
806
+ latents[:, cond_indices_images, 0] = cond_latents
807
+ latents[:, cond_indices_rays, 1] = cond_rays
808
+ latents = einops.rearrange(latents, "b f n c h w -> b (f n) c h w")
809
+ else:
810
+ if cond_latents is not None:
811
+ latents[:, cond_indices] = cond_latents
812
+
813
+ # call the callback, if provided
814
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
815
+ progress_bar.update()
816
+ if callback is not None and i % callback_steps == 0:
817
+ callback(i, t, latents)
818
+
819
+ decoded_latents = latents / 1.658
820
+ # scale and decode the image latents with vae
821
+ latents = 1 / self.vae.config.scaling_factor * latents
822
+ if latents.ndim == 5:
823
+ latents = latents[:, denoise_indices]
824
+ latents = einops.rearrange(latents, "b f c h w -> (b f) c h w")
825
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
826
+
827
+ image = (image / 2 + 0.5).clamp(0, 1)
828
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
829
+
830
+ if output_type == "pil":
831
+ image = self.numpy_to_pil(image)
832
+
833
+ if not return_dict:
834
+ return (image, None)
835
+
836
+ return OneDiffusionPipelineOutput(images=image, latents=decoded_latents)
837
+
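A hedged sketch of the image-conditioned path, reusing the `pipe` object from the previous snippet. `denoise_mask` has one entry per latent slot: `1` marks a slot to be denoised, `0` marks a slot held fixed to the encoded conditioning image (the file path and prompt are placeholders):

```python
import PIL.Image

cond_image = PIL.Image.open("input.jpg").convert("RGB")  # placeholder input image

result = pipe.img2img(
    prompt="a watercolor painting of the same scene",
    image=[cond_image],
    denoise_mask=[1, 0],      # slot 0: generate, slot 1: conditioning image
    num_inference_steps=50,
    guidance_scale=5.0,
)
result.images[0].save("edited.png")
```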
838
+ def prepare_extra_step_kwargs(self, generator, eta):
839
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
840
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
841
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
842
+ # and should be between [0, 1]
843
+
844
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
845
+ extra_step_kwargs = {}
846
+ if accepts_eta:
847
+ extra_step_kwargs["eta"] = eta
848
+
849
+ # check if the scheduler accepts generator
850
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
851
+ if accepts_generator:
852
+ extra_step_kwargs["generator"] = generator
853
+ return extra_step_kwargs
854
+
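The signature introspection above can be sanity-checked in isolation; a toy sketch (the scheduler class below is made up for illustration):

```python
import inspect

class ToyScheduler:
    # Accepts `generator` but not `eta`, roughly mirroring a flow-matching scheduler's `step`.
    def step(self, model_output, timestep, sample, generator=None):
        return sample

params = inspect.signature(ToyScheduler.step).parameters
print("eta" in params, "generator" in params)  # False True
```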
855
+ def check_inputs(self, prompt, height, width, callback_steps):
856
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
857
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
858
+
859
+ if height % 16 != 0 or width % 16 != 0:
860
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
861
+
862
+ if (callback_steps is None) or (
863
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
864
+ ):
865
+ raise ValueError(
866
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
867
+ f" {type(callback_steps)}."
868
+ )
869
+
870
+ def get_timesteps(self, num_inference_steps, strength, device):
871
+ # get the original timestep using init_timestep
872
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
873
+
874
+ t_start = max(num_inference_steps - init_timestep, 0)
875
+ timesteps = self.scheduler.timesteps[t_start:]
876
+
877
+ return timesteps, num_inference_steps - t_start
878
+
879
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, image=None):
880
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
881
+ if isinstance(generator, list) and len(generator) != batch_size:
882
+ raise ValueError(
883
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
884
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
885
+ )
886
+
887
+ if latents is None:
888
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
889
+ else:
890
+ latents = latents.to(device)
891
+
892
+ if image is None:
893
+ # scale the initial noise by the standard deviation required by the scheduler
894
+ # latents = latents * self.scheduler.init_noise_sigma
895
+ return latents
896
+
897
+ image = image.to(device=device, dtype=dtype)
898
+
899
+ if isinstance(generator, list) and len(generator) != batch_size:
900
+ raise ValueError(
901
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
902
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
903
+ )
904
+ elif isinstance(generator, list):
905
+ if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
906
+ image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
907
+ elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
908
+ raise ValueError(
909
+ f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
910
+ )
911
+ init_latents = [
912
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
913
+ for i in range(batch_size)
914
+ ]
915
+ init_latents = torch.cat(init_latents, dim=0)
916
+ else:
917
+ init_latents = retrieve_latents(self.vae.encode(image.to(self.vae.dtype)), generator=generator)
918
+
919
+ init_latents = self.vae.config.scaling_factor * init_latents
920
+ init_latents = init_latents.to(device=device, dtype=dtype)
921
+
922
+ init_latents = einops.rearrange(init_latents, "(bs views) c h w -> bs views c h w", bs=batch_size, views=init_latents.shape[0]//batch_size)
923
+ # latents = einops.rearrange(latents, "b c h w -> b 1 c h w")
924
+ # latents = torch.concat([latents, init_latents], dim=1)
925
+ return init_latents
926
+
927
+ def prepare_init_latents(self, batch_size, seq_length, num_channels_latents, height, width, dtype, device, generator, latents=None):
928
+ shape = (batch_size, seq_length, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
929
+ if isinstance(generator, list) and len(generator) != batch_size:
930
+ raise ValueError(
931
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
932
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
933
+ )
934
+
935
+ if latents is None:
936
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
937
+ else:
938
+ latents = latents.to(device)
939
+
940
+ return latents
941
+
942
+ @torch.no_grad()
943
+ def generate(
944
+ self,
945
+ prompt: Union[str, List[str]],
946
+ num_inference_steps: int = 50,
947
+ guidance_scale: float = 5.0,
948
+ negative_prompt: Optional[Union[str, List[str]]] = None,
949
+ num_images_per_prompt: Optional[int] = 1,
950
+ height: Optional[int] = None,
951
+ width: Optional[int] = None,
952
+ eta: float = 0.0,
953
+ generator: Optional[torch.Generator] = None,
954
+ latents: Optional[torch.FloatTensor] = None,
955
+ output_type: Optional[str] = "pil",
956
+ return_dict: bool = True,
957
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
958
+ callback_steps: Optional[int] = 1,
959
+ ):
960
+ """
961
+ Function for image generation using the OneDiffusionPipeline.
962
+ """
963
+ return self(
964
+ prompt=prompt,
965
+ num_inference_steps=num_inference_steps,
966
+ guidance_scale=guidance_scale,
967
+ negative_prompt=negative_prompt,
968
+ num_images_per_prompt=num_images_per_prompt,
969
+ height=height,
970
+ width=width,
971
+ eta=eta,
972
+ generator=generator,
973
+ latents=latents,
974
+ output_type=output_type,
975
+ return_dict=return_dict,
976
+ callback=callback,
977
+ callback_steps=callback_steps,
978
+ )
979
+
980
+ @staticmethod
981
+ def numpy_to_pil(images):
982
+ """
983
+ Convert a numpy image or a batch of images to a PIL image.
984
+ """
985
+ if images.ndim == 3:
986
+ images = images[None, ...]
987
+ images = (images * 255).round().astype("uint8")
988
+ if images.shape[-1] == 1:
989
+ # special case for grayscale (single channel) images
990
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
991
+ else:
992
+ pil_images = [Image.fromarray(image) for image in images]
993
+
994
+ return pil_images
995
+
996
+ @classmethod
997
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
998
+ model_path = pretrained_model_name_or_path
999
+ cache_dir = kwargs.pop("cache_dir", None)
1000
+ force_download = kwargs.pop("force_download", False)
1001
+ proxies = kwargs.pop("proxies", None)
1002
+ local_files_only = kwargs.pop("local_files_only", None)
1003
+ token = kwargs.pop("token", None)
1004
+ revision = kwargs.pop("revision", None)
1005
+ from_flax = kwargs.pop("from_flax", False)
1006
+ torch_dtype = kwargs.pop("torch_dtype", None)
1007
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
1008
+ custom_revision = kwargs.pop("custom_revision", None)
1009
+ provider = kwargs.pop("provider", None)
1010
+ sess_options = kwargs.pop("sess_options", None)
1011
+ device_map = kwargs.pop("device_map", None)
1012
+ max_memory = kwargs.pop("max_memory", None)
1013
+ offload_folder = kwargs.pop("offload_folder", None)
1014
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1015
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
1016
+ variant = kwargs.pop("variant", None)
1017
+ use_safetensors = kwargs.pop("use_safetensors", None)
1018
+ use_onnx = kwargs.pop("use_onnx", None)
1019
+ load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
1020
+
1021
+ if low_cpu_mem_usage and not is_accelerate_available():
1022
+ low_cpu_mem_usage = False
1023
+ logger.warning(
1024
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
1025
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
1026
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
1027
+ " install accelerate\n```\n."
1028
+ )
1029
+
1030
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
1031
+ raise NotImplementedError(
1032
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
1033
+ " `low_cpu_mem_usage=False`."
1034
+ )
1035
+
1036
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1037
+ raise NotImplementedError(
1038
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1039
+ " `device_map=None`."
1040
+ )
1041
+
1042
+ if device_map is not None and not is_accelerate_available():
1043
+ raise NotImplementedError(
1044
+ "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
1045
+ )
1046
+
1047
+ if device_map is not None and not isinstance(device_map, str):
1048
+ raise ValueError("`device_map` must be a string.")
1049
+
1050
+ if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
1051
+ raise NotImplementedError(
1052
+ f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
1053
+ )
1054
+
1055
+ if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
1056
+ if is_accelerate_version("<", "0.28.0"):
1057
+ raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
1058
+
1059
+ if low_cpu_mem_usage is False and device_map is not None:
1060
+ raise ValueError(
1061
+ f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
1062
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
1063
+ )
1064
+
1065
+ transformer = NextDiT.from_pretrained(f"{model_path}", subfolder="transformer", torch_dtype=torch.float32, cache_dir=cache_dir)
1066
+ vae = AutoencoderKL.from_pretrained(f"{model_path}", subfolder="vae", cache_dir=cache_dir)
1067
+ text_encoder = T5EncoderModel.from_pretrained(f"{model_path}", subfolder="text_encoder", torch_dtype=torch.float16, cache_dir=cache_dir)
1068
+ tokenizer = T5Tokenizer.from_pretrained(model_path, subfolder="tokenizer", cache_dir=cache_dir)
1069
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler", cache_dir=cache_dir)
1070
+
1071
+ pipeline = cls(
1072
+ transformer=transformer,
1073
+ vae=vae,
1074
+ text_encoder=text_encoder,
1075
+ tokenizer=tokenizer,
1076
+ scheduler=scheduler,
1077
+ **kwargs
1078
+ )
1079
+
1080
+ return pipeline
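The custom `from_pretrained` above pulls each component from a fixed subfolder, so a compatible checkpoint directory would look roughly like the sketch below (layout inferred from the `subfolder` arguments; the top-level directory name is an assumption):

```python
# Expected diffusers-style layout for OneDiffusionPipeline.from_pretrained (sketch):
#
#   onediffusion-checkpoint/
#   |-- transformer/    # NextDiT weights + config (loaded in float32)
#   |-- vae/            # AutoencoderKL
#   |-- text_encoder/   # T5EncoderModel (loaded in float16)
#   |-- tokenizer/      # T5Tokenizer files
#   `-- scheduler/      # FlowMatchEulerDiscreteScheduler config
#
pipe = OneDiffusionPipeline.from_pretrained("onediffusion-checkpoint")
pipe.enable_vae_slicing()  # optional: lower VRAM use when decoding
```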
onediffusion/models/denoiser/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import (
2
+ nextdit
3
+ )
onediffusion/models/denoiser/nextdit/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .modeling_nextdit import NextDiT
onediffusion/models/denoiser/nextdit/layers.py ADDED
@@ -0,0 +1,132 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from typing import Callable, Optional
6
+
7
+ import warnings
8
+
11
+
12
+ try:
13
+ from apex.normalization import FusedRMSNorm as RMSNorm
14
+ except ImportError:
15
+ warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
16
+
17
+
18
+ class RMSNorm(torch.nn.Module):
19
+ def __init__(self, dim: int, eps: float = 1e-6):
20
+ """
21
+ Initialize the RMSNorm normalization layer.
22
+ Args:
23
+ dim (int): The dimension of the input tensor.
24
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
25
+ Attributes:
26
+ eps (float): A small value added to the denominator for numerical stability.
27
+ weight (nn.Parameter): Learnable scaling parameter.
28
+ """
29
+ super().__init__()
30
+ self.eps = eps
31
+ self.weight = nn.Parameter(torch.ones(dim))
32
+
33
+ def _norm(self, x):
34
+ """
35
+ Apply the RMSNorm normalization to the input tensor.
36
+ Args:
37
+ x (torch.Tensor): The input tensor.
38
+ Returns:
39
+ torch.Tensor: The normalized tensor.
40
+ """
41
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
42
+
43
+ def forward(self, x):
44
+ """
45
+ Forward pass through the RMSNorm layer.
46
+ Args:
47
+ x (torch.Tensor): The input tensor.
48
+ Returns:
49
+ torch.Tensor: The output tensor after applying RMSNorm.
50
+ """
51
+ output = self._norm(x.float()).type_as(x)
52
+ return output * self.weight
53
+
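A standalone sketch checking the vanilla normalization above against a hand-rolled computation (with the default all-ones weight the two should match):

```python
import torch

x = torch.randn(2, 5, 8)                              # (batch, seq, dim)
norm = RMSNorm(dim=8)                                 # the vanilla class defined above

manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), manual, atol=1e-6))     # True (weight is initialised to ones)
```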
54
+
55
+ def modulate(x, scale):
56
+ return x * (1 + scale.unsqueeze(1))
57
+
58
+ class LLamaFeedForward(nn.Module):
59
+ """
60
+ Corresponds to the FeedForward layer in Next DiT.
61
+ """
62
+ def __init__(
63
+ self,
64
+ dim: int,
65
+ hidden_dim: int,
66
+ multiple_of: int,
67
+ ffn_dim_multiplier: Optional[float] = None,
68
+ zeros_initialize: bool = True,
69
+ dtype: torch.dtype = torch.float32,
70
+ ):
71
+ super().__init__()
72
+ self.dim = dim
73
+ self.hidden_dim = hidden_dim
74
+ self.multiple_of = multiple_of
75
+ self.ffn_dim_multiplier = ffn_dim_multiplier
76
+ self.zeros_initialize = zeros_initialize
77
+ self.dtype = dtype
78
+
79
+ # Compute hidden_dim based on the given formula
80
+ hidden_dim_calculated = int(2 * self.hidden_dim / 3)
81
+ if self.ffn_dim_multiplier is not None:
82
+ hidden_dim_calculated = int(self.ffn_dim_multiplier * hidden_dim_calculated)
83
+ hidden_dim_calculated = self.multiple_of * ((hidden_dim_calculated + self.multiple_of - 1) // self.multiple_of)
84
+
85
+ # Define linear layers
86
+ self.w1 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)
87
+ self.w2 = nn.Linear(hidden_dim_calculated, self.dim, bias=False)
88
+ self.w3 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)
89
+
90
+ # Initialize weights
91
+ if self.zeros_initialize:
92
+ nn.init.zeros_(self.w2.weight)
93
+ else:
94
+ nn.init.xavier_uniform_(self.w2.weight)
95
+ nn.init.xavier_uniform_(self.w1.weight)
96
+ nn.init.xavier_uniform_(self.w3.weight)
97
+
98
+ def _forward_silu_gating(self, x1, x3):
99
+ return F.silu(x1) * x3
100
+
101
+ def forward(self, x):
102
+ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
103
+
104
+ class FinalLayer(nn.Module):
105
+ """
106
+ The final layer of Next-DiT.
107
+ """
108
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
109
+ super().__init__()
110
+ self.hidden_size = hidden_size
111
+ self.patch_size = patch_size
112
+ self.out_channels = out_channels
113
+
114
+ # LayerNorm without learnable parameters (elementwise_affine=False)
115
+ self.norm_final = nn.LayerNorm(self.hidden_size, eps=1e-6, elementwise_affine=False)
116
+ self.linear = nn.Linear(self.hidden_size, np.prod(self.patch_size) * self.out_channels, bias=True)
117
+ nn.init.zeros_(self.linear.weight)
118
+ nn.init.zeros_(self.linear.bias)
119
+
120
+ self.adaLN_modulation = nn.Sequential(
121
+ nn.SiLU(),
122
+ nn.Linear(self.hidden_size, self.hidden_size),
123
+ )
124
+ # Initialize the last layer with zeros
125
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
126
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
127
+
128
+ def forward(self, x, c):
129
+ scale = self.adaLN_modulation(c)
130
+ x = modulate(self.norm_final(x), scale)
131
+ x = self.linear(x)
132
+ return x
onediffusion/models/denoiser/nextdit/modeling_nextdit.py ADDED
@@ -0,0 +1,571 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import einops
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from typing import Any, Tuple, Optional
10
+ from flash_attn import flash_attn_varlen_func
11
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
12
+
13
+ from .layers import LLamaFeedForward, RMSNorm
14
+
15
+ # import frasch
16
+
17
+
18
+ def modulate(x, scale):
19
+ return x * (1 + scale)
20
+
21
+ class TimestepEmbedder(nn.Module):
22
+ """
23
+ Embeds scalar timesteps into vector representations.
24
+ """
25
+ def __init__(self, hidden_size, frequency_embedding_size=256):
26
+ super().__init__()
27
+ self.hidden_size = hidden_size
28
+ self.frequency_embedding_size = frequency_embedding_size
29
+ self.mlp = nn.Sequential(
30
+ nn.Linear(self.frequency_embedding_size, self.hidden_size),
31
+ nn.SiLU(),
32
+ nn.Linear(self.hidden_size, self.hidden_size),
33
+ )
34
+
35
+ @staticmethod
36
+ def timestep_embedding(t, dim, max_period=10000):
37
+ """
38
+ Create sinusoidal timestep embeddings.
39
+ :param t: a 1-D Tensor of N indices, one per batch element.
40
+ :param dim: the dimension of the output.
41
+ :param max_period: controls the minimum frequency of the embeddings.
42
+ :return: an (N, D) Tensor of positional embeddings.
43
+ """
44
+ half = dim // 2
45
+ freqs = torch.exp(
46
+ -np.log(max_period) * torch.arange(0, half, dtype=t.dtype) / half
47
+ ).to(t.device)
48
+ args = t[:, :, None] * freqs[None, :]
49
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
50
+ if dim % 2:
51
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :, :1])], dim=-1)
52
+ return embedding
53
+
54
+ def forward(self, t):
55
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
56
+ t_freq = t_freq.to(self.mlp[0].weight.dtype)
57
+ return self.mlp(t_freq)
58
+
59
+ class FinalLayer(nn.Module):
60
+ def __init__(self, hidden_size, num_patches, out_channels):
61
+ super().__init__()
62
+ self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
63
+ self.linear = nn.Linear(hidden_size, num_patches * out_channels)
64
+ self.adaLN_modulation = nn.Sequential(
65
+ nn.SiLU(),
66
+ nn.Linear(min(hidden_size, 1024), hidden_size),
67
+ )
68
+
69
+ def forward(self, x, c):
70
+ scale = self.adaLN_modulation(c)
71
+ x = modulate(self.norm_final(x), scale)
72
+ x = self.linear(x)
73
+ return x
74
+
75
+ class Attention(nn.Module):
76
+ def __init__(
77
+ self,
78
+ dim,
79
+ n_heads,
80
+ n_kv_heads=None,
81
+ qk_norm=False,
82
+ y_dim=0,
83
+ base_seqlen=None,
84
+ proportional_attn=False,
85
+ attention_dropout=0.0,
86
+ max_position_embeddings=384,
87
+ ):
88
+ super().__init__()
89
+ self.dim = dim
90
+ self.n_heads = n_heads
91
+ self.n_kv_heads = n_kv_heads or n_heads
92
+ self.qk_norm = qk_norm
93
+ self.y_dim = y_dim
94
+ self.base_seqlen = base_seqlen
95
+ self.proportional_attn = proportional_attn
96
+ self.attention_dropout = attention_dropout
97
+ self.max_position_embeddings = max_position_embeddings
98
+
99
+ self.head_dim = dim // n_heads
100
+
101
+ self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
102
+ self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
103
+ self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
104
+
105
+ if y_dim > 0:
106
+ self.wk_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False)
107
+ self.wv_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False)
108
+ self.gate = nn.Parameter(torch.zeros(n_heads))
109
+
110
+ self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
111
+
112
+ if qk_norm:
113
+ self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
114
+ self.k_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim)
115
+ if y_dim > 0:
116
+ self.ky_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim, eps=1e-6)
117
+ else:
118
+ self.ky_norm = nn.Identity()
119
+ else:
120
+ self.q_norm = nn.Identity()
121
+ self.k_norm = nn.Identity()
122
+ self.ky_norm = nn.Identity()
123
+
124
+
125
+ @staticmethod
126
+ def apply_rotary_emb(xq, xk, freqs_cis):
127
+ # xq, xk: [batch_size, seq_len, n_heads, head_dim]
128
+ # freqs_cis: [1, seq_len, 1, head_dim]
129
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
130
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
131
+
132
+ xq_complex = torch.view_as_complex(xq_)
133
+ xk_complex = torch.view_as_complex(xk_)
134
+
135
+ freqs_cis = freqs_cis.unsqueeze(2)
136
+
137
+ # Apply freqs_cis
138
+ xq_out = xq_complex * freqs_cis
139
+ xk_out = xk_complex * freqs_cis
140
+
141
+ # Convert back to real numbers
142
+ xq_out = torch.view_as_real(xq_out).flatten(-2)
143
+ xk_out = torch.view_as_real(xk_out).flatten(-2)
144
+
145
+ return xq_out.type_as(xq), xk_out.type_as(xk)
146
+
147
+ # copied from huggingface modeling_llama.py
148
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
149
+ def _get_unpad_data(attention_mask):
150
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
151
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
152
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
153
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
154
+ return (
155
+ indices,
156
+ cu_seqlens,
157
+ max_seqlen_in_batch,
158
+ )
159
+
160
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
161
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
162
+
163
+ key_layer = index_first_axis(
164
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
165
+ indices_k,
166
+ )
167
+ value_layer = index_first_axis(
168
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
169
+ indices_k,
170
+ )
171
+ if query_length == kv_seq_len:
172
+ query_layer = index_first_axis(
173
+ query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim),
174
+ indices_k,
175
+ )
176
+ cu_seqlens_q = cu_seqlens_k
177
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
178
+ indices_q = indices_k
179
+ elif query_length == 1:
180
+ max_seqlen_in_batch_q = 1
181
+ cu_seqlens_q = torch.arange(
182
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
183
+ ) # There is a memcpy here, that is very bad.
184
+ indices_q = cu_seqlens_q[:-1]
185
+ query_layer = query_layer.squeeze(1)
186
+ else:
187
+ # The -q_len: slice assumes left padding.
188
+ attention_mask = attention_mask[:, -query_length:]
189
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
190
+
191
+ return (
192
+ query_layer,
193
+ key_layer,
194
+ value_layer,
195
+ indices_q,
196
+ (cu_seqlens_q, cu_seqlens_k),
197
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
198
+ )
199
+
200
+ def forward(
201
+ self,
202
+ x,
203
+ x_mask,
204
+ freqs_cis,
205
+ y=None,
206
+ y_mask=None,
207
+ init_cache=False,
208
+ ):
209
+ bsz, seqlen, _ = x.size()
210
+ xq = self.wq(x)
211
+ xk = self.wk(x)
212
+ xv = self.wv(x)
213
+
214
+ if x_mask is None:
215
+ x_mask = torch.ones(bsz, seqlen, dtype=torch.bool, device=x.device)
216
+ inp_dtype = xq.dtype
217
+
218
+ xq = self.q_norm(xq)
219
+ xk = self.k_norm(xk)
220
+
221
+ xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
222
+ xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
223
+ xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
224
+
225
+ if self.n_kv_heads != self.n_heads:
226
+ n_rep = self.n_heads // self.n_kv_heads
227
+ xk = xk.repeat_interleave(n_rep, dim=2)
228
+ xv = xv.repeat_interleave(n_rep, dim=2)
229
+
230
+ freqs_cis = freqs_cis.to(xq.device)
231
+ xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis)
232
+
233
+ if inp_dtype in [torch.float16, torch.bfloat16]:
234
+ # begin var_len flash attn
235
+ (
236
+ query_states,
237
+ key_states,
238
+ value_states,
239
+ indices_q,
240
+ cu_seq_lens,
241
+ max_seq_lens,
242
+ ) = self._upad_input(xq, xk, xv, x_mask, seqlen)
243
+
244
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
245
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
246
+
247
+ attn_output_unpad = flash_attn_varlen_func(
248
+ query_states.to(inp_dtype),
249
+ key_states.to(inp_dtype),
250
+ value_states.to(inp_dtype),
251
+ cu_seqlens_q=cu_seqlens_q,
252
+ cu_seqlens_k=cu_seqlens_k,
253
+ max_seqlen_q=max_seqlen_in_batch_q,
254
+ max_seqlen_k=max_seqlen_in_batch_k,
255
+ dropout_p=0.0,
256
+ causal=False,
257
+ softmax_scale=None,
258
+ softcap=30,
259
+ )
260
+ output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
261
+ else:
262
+ output = (
263
+ F.scaled_dot_product_attention(
264
+ xq.permute(0, 2, 1, 3),
265
+ xk.permute(0, 2, 1, 3),
266
+ xv.permute(0, 2, 1, 3),
267
+ attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_heads, seqlen, -1),
268
+ scale=None,
269
+ )
270
+ .permute(0, 2, 1, 3)
271
+ .to(inp_dtype)
272
+ ) #ok
273
+
274
+
275
+ if hasattr(self, "wk_y"):
276
+ yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_kv_heads, self.head_dim)
277
+ yv = self.wv_y(y).view(bsz, -1, self.n_kv_heads, self.head_dim)
278
+ n_rep = self.n_heads // self.n_kv_heads
279
+ # if n_rep >= 1:
280
+ # yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
281
+ # yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
282
+ if n_rep >= 1:
283
+ yk = einops.repeat(yk, "b l h d -> b l (repeat h) d", repeat=n_rep)
284
+ yv = einops.repeat(yv, "b l h d -> b l (repeat h) d", repeat=n_rep)
285
+ output_y = F.scaled_dot_product_attention(
286
+ xq.permute(0, 2, 1, 3),
287
+ yk.permute(0, 2, 1, 3),
288
+ yv.permute(0, 2, 1, 3),
289
+ y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_heads, seqlen, -1).to(torch.bool),
290
+ ).permute(0, 2, 1, 3)
291
+ output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
292
+ output = output + output_y
293
+
294
+ output = output.flatten(-2)
295
+ output = self.wo(output)
296
+
297
+ return output.to(inp_dtype)
298
+
299
+ class TransformerBlock(nn.Module):
+ """
+ Corresponds to the Transformer block in the JAX code.
+ """
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ y_dim,
+ max_position_embeddings,
+ ):
+ super().__init__()
+ self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim=y_dim, max_position_embeddings=max_position_embeddings)
+ self.feed_forward = LLamaFeedForward(
+ dim=dim,
+ hidden_dim=4 * dim,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ )
+ self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
+ self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(min(dim, 1024), 4 * dim),
+ )
+ self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps)
+
+ def forward(
+ self,
+ x,
+ x_mask,
+ freqs_cis,
+ y,
+ y_mask,
+ adaln_input=None,
+ ):
+ if adaln_input is not None:
+ scales_gates = self.adaLN_modulation(adaln_input)
+ # TODO: Duong - check the dimension of chunking
+ scale_msa, gate_msa, scale_mlp, gate_mlp = scales_gates.chunk(4, dim=-1)
+ x = x + torch.tanh(gate_msa) * self.attention_norm2(
+ self.attention(
+ modulate(self.attention_norm1(x), scale_msa),
+ x_mask,
+ freqs_cis,
+ self.attention_y_norm(y),
+ y_mask,
+ )
+ )
+ x = x + torch.tanh(gate_mlp) * self.ffn_norm2(
+ self.feed_forward(
+ modulate(self.ffn_norm1(x), scale_mlp),
+ )
+ )
+ else:
+ x = x + self.attention_norm2(
+ self.attention(
+ self.attention_norm1(x),
+ x_mask,
+ freqs_cis,
+ self.attention_y_norm(y),
+ y_mask,
+ )
+ )
+ x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
+ return x
+
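For reference, the adaLN path in TransformerBlock.forward splits one conditioning projection into four per-branch vectors and applies tanh-gated residual updates. A toy shape check of the chunking step (this sketch uses dim directly instead of min(dim, 1024) for brevity; all sizes are illustrative):

    import torch
    import torch.nn as nn

    dim = 8
    adaln_input = torch.randn(2, dim)  # per-sample conditioning (timestep + pooled caption)
    adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim))
    scale_msa, gate_msa, scale_mlp, gate_mlp = adaLN_modulation(adaln_input).chunk(4, dim=-1)
    print(scale_msa.shape)             # torch.Size([2, 8]): one scale/gate vector per branch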
+
+ class NextDiT(ModelMixin, ConfigMixin):
+ """
+ Diffusion model with a Transformer backbone for joint image-video training.
+ """
+ @register_to_config
+ def __init__(
+ self,
+ input_size=(1, 32, 32),
+ patch_size=(1, 2, 2),
+ in_channels=16,
+ hidden_size=4096,
+ depth=32,
+ num_heads=32,
+ num_kv_heads=None,
+ multiple_of=256,
+ ffn_dim_multiplier=None,
+ norm_eps=1e-5,
+ pred_sigma=False,
+ caption_channels=4096,
+ qk_norm=False,
+ norm_type="rms",
+ model_max_length=120,
+ rotary_max_length=384,
+ rotary_max_length_t=None
+ ):
+ super().__init__()
+ self.input_size = input_size
+ self.patch_size = patch_size
+ self.in_channels = in_channels
+ self.hidden_size = hidden_size
+ self.depth = depth
+ self.num_heads = num_heads
+ self.num_kv_heads = num_kv_heads or num_heads
+ self.multiple_of = multiple_of
+ self.ffn_dim_multiplier = ffn_dim_multiplier
+ self.norm_eps = norm_eps
+ self.pred_sigma = pred_sigma
+ self.caption_channels = caption_channels
+ self.qk_norm = qk_norm
+ self.norm_type = norm_type
+ self.model_max_length = model_max_length
+ self.rotary_max_length = rotary_max_length
+ self.rotary_max_length_t = rotary_max_length_t
+ self.out_channels = in_channels * 2 if pred_sigma else in_channels
+
+ self.x_embedder = nn.Linear(np.prod(self.patch_size) * in_channels, hidden_size)
+
+ self.t_embedder = TimestepEmbedder(min(hidden_size, 1024))
+ self.y_embedder = nn.Sequential(
+ nn.LayerNorm(caption_channels, eps=1e-6),
+ nn.Linear(caption_channels, min(hidden_size, 1024)),
+ )
+
+ self.layers = nn.ModuleList([
+ TransformerBlock(
+ dim=hidden_size,
+ n_heads=num_heads,
+ n_kv_heads=self.num_kv_heads,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ norm_eps=norm_eps,
+ qk_norm=qk_norm,
+ y_dim=caption_channels,
+ max_position_embeddings=rotary_max_length,
+ )
+ for _ in range(depth)
+ ])
+
+ self.final_layer = FinalLayer(
+ hidden_size=hidden_size,
+ num_patches=np.prod(patch_size),
+ out_channels=self.out_channels,
+ )
+
+ assert (hidden_size // num_heads) % 6 == 0, "3d rope needs head dim to be divisible by 6"
+
+ self.freqs_cis = self.precompute_freqs_cis(
+ hidden_size // num_heads,
+ self.rotary_max_length,
+ end_t=self.rotary_max_length_t
+ )
+
+ def to(self, *args, **kwargs):
+ self = super().to(*args, **kwargs)
+ # self.freqs_cis = self.freqs_cis.to(*args, **kwargs)
+ return self
+
+ @staticmethod
+ def precompute_freqs_cis(
+ dim: int,
+ end: int,
+ end_t: int = None,
+ theta: float = 10000.0,
+ scale_factor: float = 1.0,
+ scale_watershed: float = 1.0,
+ timestep: float = 1.0,
+ ):
+ if timestep < scale_watershed:
+ linear_factor = scale_factor
+ ntk_factor = 1.0
+ else:
+ linear_factor = 1.0
+ ntk_factor = scale_factor
+
+ theta = theta * ntk_factor
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor
+
+ timestep = torch.arange(end, dtype=torch.float32)
+ freqs = torch.outer(timestep, freqs).float()
+ freqs_cis = torch.exp(1j * freqs)
+
+ if end_t is not None:
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor
+ timestep_t = torch.arange(end_t, dtype=torch.float32)
+ freqs_t = torch.outer(timestep_t, freqs_t).float()
+ freqs_cis_t = torch.exp(1j * freqs_t)
+ freqs_cis_t = freqs_cis_t.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1)
+ else:
+ end_t = end
+ freqs_cis_t = freqs_cis.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1)
+
+ freqs_cis_h = freqs_cis.view(1, end, 1, dim // 6).repeat(end_t, 1, end, 1)
+ freqs_cis_w = freqs_cis.view(1, 1, end, dim // 6).repeat(end_t, end, 1, 1)
+ freqs_cis = torch.cat([freqs_cis_t, freqs_cis_h, freqs_cis_w], dim=-1).view(end_t, end, end, -1)
+ return freqs_cis
+
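Because precompute_freqs_cis is a @staticmethod, the layout of the 3D rotary table can be probed in isolation, assuming the NextDiT class above is defined or importable. Each of the t/h/w axes contributes dim // 6 complex frequencies, so the last dimension ends up at dim // 2. A quick shape check with small toy sizes (the numbers below are illustrative only):

    # head_dim = 12 (divisible by 6), spatial extent 4, temporal extent 2
    freqs = NextDiT.precompute_freqs_cis(dim=12, end=4, end_t=2)
    print(freqs.shape, freqs.dtype)  # torch.Size([2, 4, 4, 6]) torch.complex64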
+ def forward(
+ self,
+ samples,
+ timesteps,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ scale_factor: float = 1.0, # scale_factor for rotary embedding
+ scale_watershed: float = 1.0, # scale_watershed for rotary embedding
+ ):
+ if samples.ndim == 4: # B C H W
+ samples = samples[:, None, ...] # B F C H W
+
+ precomputed_freqs_cis = None
+ if scale_factor != 1 or scale_watershed != 1:
+ precomputed_freqs_cis = self.precompute_freqs_cis(
+ self.hidden_size // self.num_heads,
+ self.rotary_max_length,
+ end_t=self.rotary_max_length_t,
+ scale_factor=scale_factor,
+ scale_watershed=scale_watershed,
+ timestep=torch.max(timesteps.cpu()).item()
+ )
+
+ if len(timesteps.shape) == 5:
+ t, *_ = self.patchify(timesteps, precomputed_freqs_cis)
+ timesteps = t.mean(dim=-1)
+ elif len(timesteps.shape) == 1:
+ timesteps = timesteps[:, None, None, None, None].expand_as(samples)
+ t, *_ = self.patchify(timesteps, precomputed_freqs_cis)
+ timesteps = t.mean(dim=-1)
+ samples, T, H, W, freqs_cis = self.patchify(samples, precomputed_freqs_cis)
+ samples = self.x_embedder(samples)
+ t = self.t_embedder(timesteps)
+
+ encoder_attention_mask_float = encoder_attention_mask[..., None].float()
+ encoder_hidden_states_pool = (encoder_hidden_states * encoder_attention_mask_float).sum(dim=1) / (encoder_attention_mask_float.sum(dim=1) + 1e-8)
+ encoder_hidden_states_pool = encoder_hidden_states_pool.to(samples.dtype)
+ y = self.y_embedder(encoder_hidden_states_pool)
+ y = y.unsqueeze(1).expand(-1, samples.size(1), -1)
+
+ adaln_input = t + y
+
+ for block in self.layers:
+ samples = block(samples, None, freqs_cis, encoder_hidden_states, encoder_attention_mask, adaln_input)
+
+ samples = self.final_layer(samples, adaln_input)
+ samples = self.unpatchify(samples, T, H, W)
+
+ return samples
+
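The caption pooling in forward above is a masked mean over the valid caption tokens. A small standalone illustration of that step with toy shapes (purely illustrative, not tied to the model config):

    import torch

    encoder_hidden_states = torch.randn(2, 5, 4)               # (batch, tokens, caption_channels)
    encoder_attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                                           [1, 1, 1, 0, 0]])   # second sample has 2 padded tokens
    mask = encoder_attention_mask[..., None].float()
    pooled = (encoder_hidden_states * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)
    print(pooled.shape)                                        # torch.Size([2, 4]): one pooled caption vector per sample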
+ def patchify(self, x, precompute_freqs_cis=None):
+ # PyTorch tensors are channel-first: each frame is (C, H, W)
+ B, T, C, H, W = x.size()
+ pT, pH, pW = self.patch_size
+ x = x.view(B, T // pT, pT, C, H // pH, pH, W // pW, pW)
+ x = x.permute(0, 1, 4, 6, 2, 5, 7, 3)
+ x = x.reshape(B, -1, pT * pH * pW * C)
+ if precompute_freqs_cis is None:
+ freqs_cis = self.freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, * self.freqs_cis.shape[3:])[None].to(x.device)
+ else:
+ freqs_cis = precompute_freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, * precompute_freqs_cis.shape[3:])[None].to(x.device)
+ return x, T // pT, H // pH, W // pW, freqs_cis
+
+ def unpatchify(self, x, T, H, W):
+ B = x.size(0)
+ C = self.out_channels
+ pT, pH, pW = self.patch_size
+ x = x.view(B, T, H, W, pT, pH, pW, C)
+ x = x.permute(0, 1, 4, 7, 2, 5, 3, 6)
+ x = x.reshape(B, T * pT, C, H * pH, W * pW)
+ return x
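To make the patch bookkeeping concrete: patchify turns a (B, T, C, H, W) latent into a token sequence of length (T/pT)·(H/pH)·(W/pW) with pT·pH·pW·C features per token, and unpatchify inverts it. A standalone sketch of the same reshapes with the default (1, 2, 2) patch size (toy tensors only; this does not call the class):

    import torch

    B, T, C, H, W = 2, 1, 16, 32, 32
    pT, pH, pW = 1, 2, 2
    x = torch.randn(B, T, C, H, W)
    tokens = (x.view(B, T // pT, pT, C, H // pH, pH, W // pW, pW)
               .permute(0, 1, 4, 6, 2, 5, 7, 3)
               .reshape(B, -1, pT * pH * pW * C))
    print(tokens.shape)  # torch.Size([2, 256, 64]): 16 x 16 patches, each flattened to 2*2*16 values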
requirements.txt CHANGED
@@ -1,6 +1,27 @@
- transformers
- diffusers
- peft
- opencv-python
- protobuf
- sentencepiece
+ pytest
+ matplotlib
+ scikit-learn
+ scipy
+ spacy
+ numpy
+ einops
+ einsum
+ fvcore
+ h5py
+ twine
+ transformers==4.45.2
+ huggingface_hub==0.24
+ accelerate==0.34.2
+ diffusers==0.30.3
+ pillow==10.2.0
+ torch==2.3.1
+ torchvision==0.18.1
+ torchaudio==2.3.1
+ flash-attn==2.6.3
+ git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib
+ jaxtyping
+ mediapipe
+ gradio
+ git+https://github.com/facebookresearch/pytorch3d.git
+ opencv-python==4.5.5.64
+ opencv-python-headless==4.5.5.64