Upload 33 files
- .gitattributes +13 -0
- LICENSE +201 -0
- attn_ctrl/attention_control.py +285 -0
- checkpoints/.DS_Store +0 -0
- checkpoints/svd_reverse_motion_with_attnflip/unet/config.json +38 -0
- custom_diffusers/pipelines/pipeline_frame_interpolation_with_noise_injection.py +576 -0
- custom_diffusers/pipelines/pipeline_stable_video_diffusion_with_ref_attnmap.py +514 -0
- custom_diffusers/schedulers/scheduling_euler_discrete.py +466 -0
- dataset/stable_video_dataset.py +70 -0
- enviroment.yml +45 -0
- eval/val/0010.png +0 -0
- eval/val/0022.png +0 -0
- eval/val/0023.png +3 -0
- eval/val/turtle.png +3 -0
- examples/.gitignore +1 -0
- examples/example_001.gif +3 -0
- examples/example_001/frame1.png +3 -0
- examples/example_001/frame2.png +3 -0
- examples/example_002.gif +3 -0
- examples/example_002/frame1.png +3 -0
- examples/example_002/frame2.png +3 -0
- examples/example_003.gif +3 -0
- examples/example_003/frame1.png +0 -0
- examples/example_003/frame2.png +3 -0
- examples/example_004.gif +3 -0
- examples/example_004/frame1.png +3 -0
- examples/example_004/frame2.png +3 -0
- gradio_app.py +137 -0
- keyframe_interpolation.py +98 -0
- keyframe_interpolation.sh +26 -0
- requirements.txt +35 -0
- train_reverse_motion_with_attnflip.py +591 -0
- train_reverse_motion_with_attnflip.sh +20 -0
- utils/parse_args.py +224 -0
.gitattributes
CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+eval/val/0023.png filter=lfs diff=lfs merge=lfs -text
+eval/val/turtle.png filter=lfs diff=lfs merge=lfs -text
+examples/example_001.gif filter=lfs diff=lfs merge=lfs -text
+examples/example_001/frame1.png filter=lfs diff=lfs merge=lfs -text
+examples/example_001/frame2.png filter=lfs diff=lfs merge=lfs -text
+examples/example_002.gif filter=lfs diff=lfs merge=lfs -text
+examples/example_002/frame1.png filter=lfs diff=lfs merge=lfs -text
+examples/example_002/frame2.png filter=lfs diff=lfs merge=lfs -text
+examples/example_003.gif filter=lfs diff=lfs merge=lfs -text
+examples/example_003/frame2.png filter=lfs diff=lfs merge=lfs -text
+examples/example_004.gif filter=lfs diff=lfs merge=lfs -text
+examples/example_004/frame1.png filter=lfs diff=lfs merge=lfs -text
+examples/example_004/frame2.png filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
attn_ctrl/attention_control.py
ADDED
@@ -0,0 +1,285 @@
import abc
import torch
from typing import Tuple, List
from einops import rearrange

class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            self.forward(attn, is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        #if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
        self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    self.global_store = self.step_store
                else:
                    for key in self.global_store:
                        for i in range(len(self.global_store[key])):
                            self.global_store[key][i] += self.step_store[key][i].detach()
            self.step_store = self.get_empty_store()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = self.attention_store
        return average_attention

    def get_average_global_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, save_global_store=False):
        '''
        Initialize an empty AttentionStore
        :param step_index: used to visualize only a specific step in the diffusion process
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0

class AttentionStoreProcessor:

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        self.attnstore(rearrange(attention_probs, '(b h) i j -> b h i j', b=batch_size), False, self.place_in_unet)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttentionFlipCtrlProcessor:

    def __init__(self, attnstore, attnstore_ref, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.attnrstore_ref = attnstore_ref
        self.place_in_unet = place_in_unet

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        if self.place_in_unet == 'mid':
            cur_att_layer = self.attnstore.cur_att_layer-len(self.attnrstore_ref.attention_store["down_self"])
        elif self.place_in_unet == 'up':
            cur_att_layer = self.attnstore.cur_att_layer-(len(self.attnrstore_ref.attention_store["down_self"])+len(self.attnrstore_ref.attention_store["mid_self"]))
        else:
            cur_att_layer = self.attnstore.cur_att_layer

        attention_probs_ref = self.attnrstore_ref.attention_store[f"{self.place_in_unet}_{'self'}"][cur_att_layer]
        attention_probs_ref = rearrange(attention_probs_ref, 'b h i j -> (b h) i j')
        attention_probs = 0.0 * attention_probs + 1.0 * torch.flip(attention_probs_ref, dims=(-2, -1))

        self.attnstore(rearrange(attention_probs, '(b h) i j -> b h i j', b=batch_size), False, self.place_in_unet)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

def register_temporal_self_attention_control(unet, controller):

    attn_procs = {}
    temporal_self_att_count = 0
    for name in unet.attn_processors.keys():
        if name.endswith("temporal_transformer_blocks.0.attn1.processor"):
            if name.startswith("mid_block"):
                place_in_unet = "mid"
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                place_in_unet = "up"
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                place_in_unet = "down"
            else:
                continue

            temporal_self_att_count += 1
            attn_procs[name] = AttentionStoreProcessor(
                attnstore=controller, place_in_unet=place_in_unet
            )
        else:
            attn_procs[name] = unet.attn_processors[name]

    unet.set_attn_processor(attn_procs)
    controller.num_att_layers = temporal_self_att_count

def register_temporal_self_attention_flip_control(unet, controller, controller_ref):

    attn_procs = {}
    temporal_self_att_count = 0
    for name in unet.attn_processors.keys():
        if name.endswith("temporal_transformer_blocks.0.attn1.processor"):
            if name.startswith("mid_block"):
                place_in_unet = "mid"
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                place_in_unet = "up"
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                place_in_unet = "down"
            else:
                continue

            temporal_self_att_count += 1
            attn_procs[name] = AttentionFlipCtrlProcessor(
                attnstore=controller, attnstore_ref=controller_ref, place_in_unet=place_in_unet
            )
        else:
            attn_procs[name] = unet.attn_processors[name]

    unet.set_attn_processor(attn_procs)
    controller.num_att_layers = temporal_self_att_count
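As a rough wiring sketch (my own illustration, not code from this upload): `register_temporal_self_attention_control` attaches an `AttentionStore` that records every temporal self-attention map, while `register_temporal_self_attention_flip_control` makes a second UNet replace its own maps with the time-flipped reference maps. The checkpoint id and the use of a plain deep copy for the second UNet are assumptions; the repo's own scripts load a fine-tuned reverse-motion UNet instead.

```python
# Hypothetical wiring of the helpers above onto an SVD UNet (illustrative only).
import copy
import torch
from diffusers import UNetSpatioTemporalConditionModel

from attn_ctrl.attention_control import (
    AttentionStore,
    register_temporal_self_attention_control,
    register_temporal_self_attention_flip_control,
)

# Assumed base checkpoint; the repo's scripts swap in a fine-tuned reverse-motion UNet.
ref_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", subfolder="unet", torch_dtype=torch.float16
)
unet = copy.deepcopy(ref_unet)

# The reference controller records each temporal self-attention map it sees.
ref_controller = AttentionStore()
register_temporal_self_attention_control(ref_unet, ref_controller)

# The second UNet then reuses those maps, flipped along both frame axes.
controller = AttentionStore()
register_temporal_self_attention_flip_control(unet, controller, ref_controller)
```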
checkpoints/.DS_Store
ADDED
Binary file (6.15 kB)
checkpoints/svd_reverse_motion_with_attnflip/unet/config.json
ADDED
@@ -0,0 +1,38 @@
{
  "_class_name": "UNetSpatioTemporalConditionModel",
  "_diffusers_version": "0.27.0",
  "_name_or_path": "/gscratch/realitylab/xiaojwan/projects/video_narratives/stabilityai/stable-video-diffusion-img2vid",
  "addition_time_embed_dim": 256,
  "block_out_channels": [
    320,
    640,
    1280,
    1280
  ],
  "cross_attention_dim": 1024,
  "down_block_types": [
    "CrossAttnDownBlockSpatioTemporal",
    "CrossAttnDownBlockSpatioTemporal",
    "CrossAttnDownBlockSpatioTemporal",
    "DownBlockSpatioTemporal"
  ],
  "in_channels": 8,
  "layers_per_block": 2,
  "num_attention_heads": [
    5,
    10,
    20,
    20
  ],
  "num_frames": 14,
  "out_channels": 4,
  "projection_class_embeddings_input_dim": 768,
  "sample_size": 96,
  "transformer_layers_per_block": 1,
  "up_block_types": [
    "UpBlockSpatioTemporal",
    "CrossAttnUpBlockSpatioTemporal",
    "CrossAttnUpBlockSpatioTemporal",
    "CrossAttnUpBlockSpatioTemporal"
  ]
}
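Only this config.json appears under `checkpoints/` in the upload's file list, so it describes the architecture rather than shipping weights. A minimal sketch of instantiating that architecture with diffusers, assuming the local path from the file list:

```python
# Illustrative only: build the UNet architecture described by config.json above.
# This yields randomly initialized weights; the trained reverse-motion weights
# would still have to be loaded separately.
from diffusers import UNetSpatioTemporalConditionModel

config = UNetSpatioTemporalConditionModel.load_config(
    "checkpoints/svd_reverse_motion_with_attnflip/unet"
)
unet = UNetSpatioTemporalConditionModel.from_config(config)
print(unet.config.num_frames, unet.config.in_channels)  # 14, 8
```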
custom_diffusers/pipelines/pipeline_frame_interpolation_with_noise_injection.py
ADDED
@@ -0,0 +1,576 @@
# Adpated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_stable_video_diffusion.py
import inspect
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union
import copy
import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.utils import logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _append_dims,
    tensor2vid,
    _resize_with_antialiasing,
    StableVideoDiffusionPipelineOutput
)
from ..schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class FrameInterpolationWithNoiseInjectionPipeline(DiffusionPipeline):
    r"""
    Pipeline to generate video from an input image using Stable Video Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKLTemporalDecoder`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
        unet ([`UNetSpatioTemporalConditionModel`]):
            A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
        scheduler ([`EulerDiscreteScheduler`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images.
    """

    model_cpu_offload_seq = "image_encoder->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.ori_unet = copy.deepcopy(unet)

    def _encode_image(
        self,
        image: PipelineImageInput,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ) -> torch.FloatTensor:
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.image_processor.pil_to_numpy(image)
            image = self.image_processor.numpy_to_pt(image)

            # We normalize the image before resizing to match with the original implementation.
            # Then we unnormalize it after resizing.
            image = image * 2.0 - 1.0
            image = _resize_with_antialiasing(image, (224, 224))
            image = (image + 1.0) / 2.0

            # Normalize the image with for CLIP input
            image = self.feature_extractor(
                images=image,
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])

        return image_embeddings

    def _encode_vae_image(
        self,
        image: torch.Tensor,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ):
        image = image.to(device=device)
        image_latents = self.vae.encode(image).latent_dist.mode()

        if do_classifier_free_guidance:
            negative_image_latents = torch.zeros_like(image_latents)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_latents = torch.cat([negative_image_latents, image_latents])

        # duplicate image_latents for each generation per prompt, using mps friendly method
        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)

        return image_latents

    def _get_add_time_ids(
        self,
        fps: int,
        motion_bucket_id: int,
        noise_aug_strength: float,
        dtype: torch.dtype,
        batch_size: int,
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ):
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)

        if do_classifier_free_guidance:
            add_time_ids = torch.cat([add_time_ids, add_time_ids])

        return add_time_ids

    def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
        latents = latents.flatten(0, 1)

        latents = 1 / self.vae.config.scaling_factor * latents

        forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
        accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())

        # decode decode_chunk_size frames at a time to avoid OOM
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            decode_kwargs = {}
            if accepts_num_frames:
                # we only pass num_frames_in if it's expected
                decode_kwargs["num_frames"] = num_frames_in

            frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)

        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.float()
        return frames

    def check_inputs(self, image, height, width):
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    def prepare_latents(
        self,
        batch_size: int,
        num_frames: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
        generator: torch.Generator,
        latents: Optional[torch.FloatTensor] = None,
    ):
        shape = (
            batch_size,
            num_frames,
            num_channels_latents // 2,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        if isinstance(self.guidance_scale, (int, float)):
            return self.guidance_scale > 1
        return self.guidance_scale.max() > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps


    @torch.no_grad()
    def multidiffusion_step(self, latents, t,
                            image1_embeddings,
                            image2_embeddings,
                            image1_latents,
                            image2_latents,
                            added_time_ids,
                            avg_weight
    ):
        # expand the latents if we are doing classifier free guidance
        latents1 = latents
        latents2 = torch.flip(latents, (1,))
        latent_model_input1 = torch.cat([latents1] * 2) if self.do_classifier_free_guidance else latents1
        latent_model_input1 = self.scheduler.scale_model_input(latent_model_input1, t)

        latent_model_input2 = torch.cat([latents2] * 2) if self.do_classifier_free_guidance else latents2
        latent_model_input2 = self.scheduler.scale_model_input(latent_model_input2, t)


        # Concatenate image_latents over channels dimention
        latent_model_input1 = torch.cat([latent_model_input1, image1_latents], dim=2)
        latent_model_input2 = torch.cat([latent_model_input2, image2_latents], dim=2)

        # predict the noise residual
        noise_pred1 = self.ori_unet(
            latent_model_input1,
            t,
            encoder_hidden_states=image1_embeddings,
            added_time_ids=added_time_ids,
            return_dict=False,
        )[0]
        noise_pred2 = self.unet(
            latent_model_input2,
            t,
            encoder_hidden_states=image2_embeddings,
            added_time_ids=added_time_ids,
            return_dict=False,
        )[0]
        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond1, noise_pred_cond1 = noise_pred1.chunk(2)
            noise_pred1 = noise_pred_uncond1 + self.guidance_scale * (noise_pred_cond1 - noise_pred_uncond1)

            noise_pred_uncond2, noise_pred_cond2 = noise_pred2.chunk(2)
            noise_pred2 = noise_pred_uncond2 + self.guidance_scale * (noise_pred_cond2 - noise_pred_uncond2)

        noise_pred2 = torch.flip(noise_pred2, (1,))
        noise_pred = avg_weight*noise_pred1 + (1-avg_weight)*noise_pred2
        return noise_pred


    @torch.no_grad()
    def __call__(
        self,
        image1: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
        image2: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        weighted_average: bool = False,
        noise_injection_steps: int = 0,
        noise_injection_ratio: float = 0.0,
        return_dict: bool = True,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_frames (`int`, *optional*):
                The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            min_guidance_scale (`float`, *optional*, defaults to 1.0):
                The minimum guidance scale. Used for the classifier free guidance with first frame.
            max_guidance_scale (`float`, *optional*, defaults to 3.0):
                The maximum guidance scale. Used for the classifier free guidance with last frame.
            fps (`int`, *optional*, defaults to 7):
                Frames per second. The rate at which the generated images shall be exported to a video after generation.
                Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
            motion_bucket_id (`int`, *optional*, defaults to 127):
                The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
            noise_aug_strength (`float`, *optional*, defaults to 0.02):
                The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
            decode_chunk_size (`int`, *optional*):
                The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
                between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
                for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list of list with the generated frames.

        Examples:

        ```py
        from diffusers import StableVideoDiffusionPipeline
        from diffusers.utils import load_image, export_to_video

        pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
        pipe.to("cuda")

        image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
        image = image.resize((1024, 576))

        frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
        export_to_video(frames, "generated.mp4", fps=7)
        ```
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image1, height, width)
        self.check_inputs(image2, height, width)

        # 2. Define call parameters
        if isinstance(image1, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image1, list):
            batch_size = len(image1)
        else:
            batch_size = image1.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = max_guidance_scale

        # 3. Encode input image
        image1_embeddings = self._encode_image(image1, device, num_videos_per_prompt, self.do_classifier_free_guidance)
        image2_embeddings = self._encode_image(image2, device, num_videos_per_prompt, self.do_classifier_free_guidance)

        # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
        # is why it is reduced here.
        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        fps = fps - 1

        # 4. Encode input image using VAE
        image1 = self.image_processor.preprocess(image1, height=height, width=width).to(device)
        image2 = self.image_processor.preprocess(image2, height=height, width=width).to(device)
        noise = randn_tensor(image1.shape, generator=generator, device=image1.device, dtype=image1.dtype)
        image1 = image1 + noise_aug_strength * noise
        image2 = image2 + noise_aug_strength * noise

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)


        # Repeat the image latents for each frame so we can concatenate them with the noise
        # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
        image1_latent = self._encode_vae_image(image1, device, num_videos_per_prompt, self.do_classifier_free_guidance)
        image1_latent = image1_latent.to(image1_embeddings.dtype)
        image1_latents = image1_latent.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

        image2_latent = self._encode_vae_image(image2, device, num_videos_per_prompt, self.do_classifier_free_guidance)
        image2_latent = image2_latent.to(image2_embeddings.dtype)
        image2_latents = image2_latent.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image1_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image1_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare guidance scale
        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)

        if weighted_average:
            self._guidance_scale = guidance_scale
            w = torch.linspace(1, 0, num_frames).unsqueeze(0).to(device, latents.dtype)
            w = w.repeat(batch_size*num_videos_per_prompt, 1)
            w = _append_dims(w, latents.ndim)
        else:
            self._guidance_scale = (guidance_scale+torch.flip(guidance_scale, (1,)))*0.5
            w = 0.5

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        self.ori_unet = self.ori_unet.to(device)

        noise_injection_step_threshold = int(num_inference_steps*noise_injection_ratio)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):

                noise_pred = self.multidiffusion_step(latents, t,
                                                      image1_embeddings, image2_embeddings,
                                                      image1_latents, image2_latents, added_time_ids, w
                )
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
                if i < noise_injection_step_threshold and noise_injection_steps > 0:
                    sigma_t = self.scheduler.sigmas[self.scheduler.step_index]
                    sigma_tm1 = self.scheduler.sigmas[self.scheduler.step_index+1]
                    sigma = torch.sqrt(sigma_t**2-sigma_tm1**2)
                    for j in range(noise_injection_steps):
                        noise = randn_tensor(latents.shape, device=latents.device, dtype=latents.dtype)
                        noise = noise * sigma
                        latents = latents + noise
                        noise_pred = self.multidiffusion_step(latents, t,
                                                              image1_embeddings, image2_embeddings,
                                                              image1_latents, image2_latents, added_time_ids, w
                        )
                        # compute the previous noisy sample x_t -> x_t-1
                        latents = self.scheduler.step(noise_pred, t, latents).prev_sample
                self.scheduler._step_index += 1

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
            frames = tensor2vid(frames, self.image_processor, output_type=output_type)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames
|
575 |
+
|
576 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
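For orientation, here is a minimal, hedged usage sketch of the interpolation pipeline whose sampling loop ends above. The class name `FrameInterpolationWithNoiseInjectionPipeline`, the checkpoint id, and the example frame paths are assumptions (the real entry point in this repo is `keyframe_interpolation.py`, which also loads the fine-tuned reverse-motion UNet); only the call arguments `image1`, `image2`, `weighted_average`, `noise_injection_steps`, and `noise_injection_ratio` come from the code shown.

```python
# Hypothetical usage sketch -- class name and checkpoint id are assumptions, not taken from this hunk.
import torch
from diffusers.utils import load_image, export_to_video

from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import (
    FrameInterpolationWithNoiseInjectionPipeline,  # assumed export of this file
)

pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")

frame1 = load_image("examples/example_001/frame1.png").resize((1024, 576))
frame2 = load_image("examples/example_001/frame2.png").resize((1024, 576))

# image1/image2 are the two keyframes; noise_injection_* control the extra
# add-noise-and-redenoise passes performed in the early denoising steps above.
frames = pipe(
    image1=frame1,
    image2=frame2,
    num_inference_steps=50,
    weighted_average=True,
    noise_injection_steps=5,
    noise_injection_ratio=0.5,
    decode_chunk_size=8,
).frames[0]
export_to_video(frames, "interpolated.mp4", fps=7)
```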
custom_diffusers/pipelines/pipeline_stable_video_diffusion_with_ref_attnmap.py
ADDED
@@ -0,0 +1,514 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_stable_video_diffusion.py
import inspect
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union
import copy
import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.utils import logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _append_dims,
    tensor2vid,
    _resize_with_antialiasing,
    StableVideoDiffusionPipelineOutput,
)

from ..schedulers.scheduling_euler_discrete import EulerDiscreteScheduler

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class StableVideoDiffusionWithRefAttnMapPipeline(DiffusionPipeline):

    model_cpu_offload_seq = "image_encoder->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def _encode_image(
        self,
        image: PipelineImageInput,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ) -> torch.FloatTensor:
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.image_processor.pil_to_numpy(image)
            image = self.image_processor.numpy_to_pt(image)

            # We normalize the image before resizing to match with the original implementation.
            # Then we unnormalize it after resizing.
            image = image * 2.0 - 1.0
            image = _resize_with_antialiasing(image, (224, 224))
            image = (image + 1.0) / 2.0

            # Normalize the image for CLIP input
            image = self.feature_extractor(
                images=image,
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and conditional embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])

        return image_embeddings

    def _encode_vae_image(
        self,
        image: torch.Tensor,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ):
        image = image.to(device=device)
        image_latents = self.vae.encode(image).latent_dist.mode()

        if do_classifier_free_guidance:
            negative_image_latents = torch.zeros_like(image_latents)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and conditional latents into a single batch
            # to avoid doing two forward passes
            image_latents = torch.cat([negative_image_latents, image_latents])

        # duplicate image_latents for each generation per prompt, using mps friendly method
        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)

        return image_latents

    def _get_add_time_ids(
        self,
        fps: int,
        motion_bucket_id: int,
        noise_aug_strength: float,
        dtype: torch.dtype,
        batch_size: int,
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ):
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)

        if do_classifier_free_guidance:
            add_time_ids = torch.cat([add_time_ids, add_time_ids])

        return add_time_ids

    def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
        latents = latents.flatten(0, 1)

        latents = 1 / self.vae.config.scaling_factor * latents

        forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
        accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())

        # decode decode_chunk_size frames at a time to avoid OOM
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            decode_kwargs = {}
            if accepts_num_frames:
                # we only pass num_frames_in if it's expected
                decode_kwargs["num_frames"] = num_frames_in

            frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)

        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.float()
        return frames

    def check_inputs(self, image, height, width):
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    def prepare_latents(
        self,
        batch_size: int,
        num_frames: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
        generator: torch.Generator,
        latents: Optional[torch.FloatTensor] = None,
    ):
        shape = (
            batch_size,
            num_frames,
            num_channels_latents // 2,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        if isinstance(self.guidance_scale, (int, float)):
            return self.guidance_scale > 1
        return self.guidance_scale.max() > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @torch.no_grad()
    def __call__(
        self,
        ref_unet: UNetSpatioTemporalConditionModel,
        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
        ref_image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        return_dict: bool = True,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
                [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_frames (`int`, *optional*):
                The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            min_guidance_scale (`float`, *optional*, defaults to 1.0):
                The minimum guidance scale. Used for the classifier free guidance with the first frame.
            max_guidance_scale (`float`, *optional*, defaults to 3.0):
                The maximum guidance scale. Used for the classifier free guidance with the last frame.
            fps (`int`, *optional*, defaults to 7):
                Frames per second. The rate at which the generated images shall be exported to a video after generation.
                Note that Stable Video Diffusion's UNet was micro-conditioned on fps-1 during training.
            motion_bucket_id (`int`, *optional*, defaults to 127):
                The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
            noise_aug_strength (`float`, *optional*, defaults to 0.02):
                The amount of noise added to the init image; the higher it is the less the video will look like the init image. Increase it for more motion.
            decode_chunk_size (`int`, *optional*):
                The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
                between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
                for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list of lists with the generated frames.

        Examples:

        ```py
        import torch
        from diffusers import StableVideoDiffusionPipeline
        from diffusers.utils import load_image, export_to_video

        pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
        pipe.to("cuda")

        image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
        image = image.resize((1024, 576))

        frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
        export_to_video(frames, "generated.mp4", fps=7)
        ```
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width)
        self.check_inputs(ref_image, height, width)

        # 2. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = max_guidance_scale

        # 3. Encode input images
        image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
        ref_image_embeddings = self._encode_image(ref_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)

        # NOTE: Stable Video Diffusion was conditioned on fps - 1, which
        # is why it is reduced here.
        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        fps = fps - 1

        # 4. Encode input images using VAE
        image = self.image_processor.preprocess(image, height=height, width=width).to(device)
        ref_image = self.image_processor.preprocess(ref_image, height=height, width=width).to(device)
        noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
        image = image + noise_aug_strength * noise
        ref_image = ref_image + noise_aug_strength * noise

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        # Repeat the image latents for each frame so we can concatenate them with the noise
        # image_latents [batch, channels, height, width] -> [batch, num_frames, channels, height, width]
        image_latent = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
        image_latent = image_latent.to(image_embeddings.dtype)
        image_latents = image_latent.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

        ref_image_latent = self._encode_vae_image(ref_image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
        ref_image_latent = ref_image_latent.to(ref_image_embeddings.dtype)
        ref_image_latents = ref_image_latent.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 8. Prepare guidance scale
        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)
        self._guidance_scale = guidance_scale

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        ref_unet = ref_unet.to(device)
        ref_latents = latents.clone()
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                ref_latent_model_input = torch.cat([ref_latents] * 2) if self.do_classifier_free_guidance else ref_latents
                ref_latent_model_input = self.scheduler.scale_model_input(ref_latent_model_input, t)

                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Concatenate image_latents over the channels dimension
                ref_latent_model_input = torch.cat([ref_latent_model_input, ref_image_latents], dim=2)
                latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

                # predict the noise residual
                noise_pred_ref = ref_unet(
                    ref_latent_model_input,
                    t,
                    encoder_hidden_states=ref_image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]
                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond_ref, noise_pred_cond_ref = noise_pred_ref.chunk(2)
                    noise_pred_ref = noise_pred_uncond_ref + self.guidance_scale * (noise_pred_cond_ref - noise_pred_uncond_ref)

                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                ref_latents = self.scheduler.step(noise_pred_ref, t, ref_latents).prev_sample
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
                self.scheduler._step_index += 1

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
            frames = tensor2vid(frames, self.image_processor, output_type=output_type)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)
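A hedged usage sketch of this pipeline follows. Loading the fine-tuned reverse-motion UNet and registering the attention control (see `attn_ctrl/attention_control.py` and `keyframe_interpolation.py`) happen elsewhere in the repo; only the call signature `(ref_unet, image, ref_image, ...)` is taken from the code above, and the checkpoint id and image paths are assumptions.

```python
# Hypothetical usage sketch -- checkpoint id, image paths, and the deepcopy'd ref_unet are assumptions.
import copy
import torch
from diffusers.utils import load_image

from custom_diffusers.pipelines.pipeline_stable_video_diffusion_with_ref_attnmap import (
    StableVideoDiffusionWithRefAttnMapPipeline,
)

pipe = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# Second UNet that denoises the reference branch in lockstep with the main branch;
# in the real script this is the reverse-motion UNet from checkpoints/.
ref_unet = copy.deepcopy(pipe.unet)

image = load_image("examples/example_001/frame1.png").resize((1024, 576))
ref_image = load_image("examples/example_001/frame2.png").resize((1024, 576))

frames = pipe(
    ref_unet,
    image=image,
    ref_image=ref_image,
    num_frames=14,
    decode_chunk_size=8,
    generator=torch.Generator("cuda").manual_seed(0),
).frames[0]
```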
custom_diffusers/schedulers/scheduling_euler_discrete.py
ADDED
@@ -0,0 +1,466 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

from diffusers.schedulers.scheduling_euler_discrete import (
    EulerDiscreteSchedulerOutput,
    betas_for_alpha_bar,
    rescale_zero_terminal_snr,
)


class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        interpolation_type (`str`, defaults to `"linear"`, *optional*):
            The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
            `"linear"` or `"log_linear"`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: Optional[bool] = False,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        timestep_spacing: str = "linspace",
        timestep_type: str = "discrete",  # can be "discrete" or "continuous"
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr:
            # Close to 0 without being 0 so first sigma is not inf
            # FP16 smallest positive subnormal works well here
            self.alphas_cumprod[-1] = 2**-24

        sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # setable values
        self.num_inference_steps = None

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if timestep_type == "continuous" and prediction_type == "v_prediction":
            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
        else:
            self.timesteps = timesteps

        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.is_scale_input_called = False
        self.use_karras_sigmas = use_karras_sigmas

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return max_sigma

        return (max_sigma**2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It will increase by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)

        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        if self.config.interpolation_type == "linear":
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        elif self.config.interpolation_type == "log_linear":
            sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
        else:
            raise ValueError(
                f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
                " 'linear' or 'log_linear'"
            )

        if self.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        # TODO: Support the full EDM scalings for all prediction types and timestep types
        if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
        else:
            self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility
        if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma_hat * model_output
        elif self.config.prediction_type == "v_prediction":
            # denoised = model_output * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # NOTE: unlike the stock diffusers scheduler, `step` does not advance `_step_index` here;
        # the calling pipeline increments it manually so `step` can be re-invoked at the same sigma
        # (e.g. for noise re-injection).
        # if increment_step_idx:
        #     # upon completion increase step index by one
        #     self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when the scheduler is used for training, or the pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
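The key behavioral difference from the upstream diffusers scheduler is that `step` leaves `_step_index` untouched, so the caller controls when the sigma index advances. The sketch below mirrors the loop pattern used by the pipelines in this upload; `unet_forward` is a stand-in for the real UNet call, not a function from this repo.

```python
# Minimal sketch of how the pipelines above drive this scheduler.
def denoise(scheduler, latents, unet_forward, num_inference_steps=25, device="cuda"):
    scheduler.set_timesteps(num_inference_steps, device=device)
    for t in scheduler.timesteps:
        # scale the input for the current sigma
        model_input = scheduler.scale_model_input(latents, t)
        noise_pred = unet_forward(model_input, t)
        # step() can be called repeatedly at the same sigma (noise re-injection)
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        # advance manually -- the custom step() does not increment the index itself
        scheduler._step_index += 1
    return latents
```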
dataset/stable_video_dataset.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import torch
|
7 |
+
from torchvision import transforms
|
8 |
+
from torch.utils.data.dataset import Dataset
|
9 |
+
|
10 |
+
class StableVideoDataset(Dataset):
|
11 |
+
def __init__(self,
|
12 |
+
video_data_dir,
|
13 |
+
max_num_videos=None,
|
14 |
+
frame_hight=576, frame_width=1024, num_frames=14,
|
15 |
+
is_reverse_video=True,
|
16 |
+
random_seed=42,
|
17 |
+
double_sampling_rate=False,
|
18 |
+
):
|
19 |
+
self.video_data_dir = video_data_dir
|
20 |
+
video_names = sorted([video for video in os.listdir(video_data_dir)
|
21 |
+
if os.path.isdir(os.path.join(video_data_dir, video))])
|
22 |
+
|
23 |
+
self.length = min(len(video_names), max_num_videos) if max_num_videos is not None else len(video_names)
|
24 |
+
|
25 |
+
self.video_names = video_names[:self.length]
|
26 |
+
if double_sampling_rate:
|
27 |
+
self.sample_frames = num_frames*2-1
|
28 |
+
self.sample_stride = 2
|
29 |
+
else:
|
30 |
+
self.sample_frames = num_frames
|
31 |
+
self.sample_stride = 1
|
32 |
+
|
33 |
+
self.frame_width = frame_width
|
34 |
+
self.frame_height = frame_hight
|
35 |
+
self.pixel_transforms = transforms.Compose([
|
36 |
+
transforms.Resize((self.frame_height, self.frame_width), interpolation=transforms.InterpolationMode.BILINEAR),
|
37 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
38 |
+
])
|
39 |
+
self.is_reverse_video=is_reverse_video
|
40 |
+
np.random.seed(random_seed)
|
41 |
+
|
42 |
+
def get_batch(self, idx):
|
43 |
+
video_name = self.video_names[idx]
|
44 |
+
video_frame_paths = sorted(glob(os.path.join(self.video_data_dir, video_name, '*.png')))
|
45 |
+
start_idx = np.random.randint(len(video_frame_paths)-self.sample_frames+1)
|
46 |
+
video_frame_paths = video_frame_paths[start_idx:start_idx+self.sample_frames:self.sample_stride]
|
47 |
+
video_frames = [np.asarray(Image.open(frame_path).convert('RGB')).astype(np.float32)/255.0 for frame_path in video_frame_paths]
|
48 |
+
video_frames = np.stack(video_frames, axis=0)
|
49 |
+
pixel_values = torch.from_numpy(video_frames.transpose(0, 3, 1, 2))
|
50 |
+
return pixel_values
|
51 |
+
|
52 |
+
def __len__(self):
|
53 |
+
return self.length
|
54 |
+
|
55 |
+
def __getitem__(self, idx):
|
56 |
+
while True:
|
57 |
+
try:
|
58 |
+
pixel_values = self.get_batch(idx)
|
59 |
+
break
|
60 |
+
|
61 |
+
except Exception as e:
|
62 |
+
idx = random.randint(0, self.length-1)
|
63 |
+
|
64 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
65 |
+
conditions = pixel_values[-1]
|
66 |
+
if self.is_reverse_video:
|
67 |
+
pixel_values = torch.flip(pixel_values, (0,))
|
68 |
+
|
69 |
+
sample = dict(pixel_values=pixel_values, conditions=conditions)
|
70 |
+
return sample
|
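For context, a minimal usage sketch of the dataset class above (the directory path and batch size are placeholders, not values from the repository). Each item is a dict whose `pixel_values` tensor has shape (num_frames, 3, H, W) in [-1, 1], temporally flipped when `is_reverse_video=True`, and whose `conditions` tensor holds the frame used as the image condition (the last frame before the optional flip).

from torch.utils.data import DataLoader
from dataset.stable_video_dataset import StableVideoDataset

# hypothetical local folder containing one sub-directory of *.png frames per video
dataset = StableVideoDataset(video_data_dir="path/to/videos_frames", num_frames=14, is_reverse_video=True)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

batch = next(iter(loader))
print(batch["pixel_values"].shape)  # (1, 14, 3, 576, 1024)
print(batch["conditions"].shape)    # (1, 3, 576, 1024)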
enviroment.yml ADDED

name: diffusers-0-27-0
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.8.5
  - pip=20.3
  - cudatoolkit=11.8
  - pytorch=2.0.1
  - torchvision=0.15.2
  - numpy=1.23.1
  - pip:
    - diffusers==0.27.0
    - albumentations==0.4.3
    - opencv-python==4.6.0.66
    - pudb==2019.2
    - imageio==2.9.0
    - imageio-ffmpeg==0.4.2
    - omegaconf==2.1.1
    - test-tube>=0.7.5
    - einops==0.3.0
    - torch-fidelity==0.3.0
    - torchmetrics==0.11.0
    - transformers==4.36.0
    - webdataset==0.2.5
    - open-clip-torch==2.7.0
    - invisible-watermark>=0.1.5
    - accelerate==0.25.0
    - xformers==0.0.23
    - peft==0.7.0
    - torch-ema==0.3
    - moviepy
    - tensorboard
    - Jinja2
    - ftfy
    - datasets
    - wandb
    - pytorch-fid
    - notebook
    - matplotlib
    - kornia==0.7.2
    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
    - -e git+https://github.com/Stability-AI/stablediffusion.git@main#egg=stable-diffusion
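The environment above can typically be created with `conda env create -f enviroment.yml` (the file keeps its original spelling), which produces the `diffusers-0-27-0` environment; its pip section largely mirrors the requirements.txt added further below.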
eval/val/0010.png ADDED
eval/val/0022.png ADDED
eval/val/0023.png ADDED (stored with Git LFS)
eval/val/turtle.png ADDED (stored with Git LFS)
examples/.gitignore ADDED
    results/
examples/example_001.gif ADDED (stored with Git LFS)
examples/example_001/frame1.png ADDED (stored with Git LFS)
examples/example_001/frame2.png ADDED (stored with Git LFS)
examples/example_002.gif ADDED (stored with Git LFS)
examples/example_002/frame1.png ADDED (stored with Git LFS)
examples/example_002/frame2.png ADDED (stored with Git LFS)
examples/example_003.gif ADDED (stored with Git LFS)
examples/example_003/frame1.png ADDED
examples/example_003/frame2.png ADDED (stored with Git LFS)
examples/example_004.gif ADDED (stored with Git LFS)
examples/example_004/frame1.png ADDED (stored with Git LFS)
examples/example_004/frame2.png ADDED (stored with Git LFS)
gradio_app.py ADDED

import os
import shutil  # used by check_outputs_folder to remove result directories
import gradio as gr
import torch

# import argparse

checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"

from diffusers.utils import load_image, export_to_video
from diffusers import UNetSpatioTemporalConditionModel
from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (AttentionStore,
                                         register_temporal_self_attention_control,
                                         register_temporal_self_attention_flip_control,
                                         )


pretrained_model_name_or_path = "stabilityai/stable-video-diffusion-img2vid-xt"
noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    scheduler=noise_scheduler,
    variant="fp16",
    torch_dtype=torch.float16,
)
ref_unet = pipe.ori_unet

state_dict = pipe.unet.state_dict()
# computing delta w
finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    checkpoint_dir,
    subfolder="unet",
    torch_dtype=torch.float16,
)
assert finetuned_unet.config.num_frames == 14
ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    subfolder="unet",
    variant='fp16',
    torch_dtype=torch.float16,
)

finetuned_state_dict = finetuned_unet.state_dict()
ori_state_dict = ori_unet.state_dict()
for name, param in finetuned_state_dict.items():
    if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
        delta_w = param - ori_state_dict[name]
        state_dict[name] = state_dict[name] + delta_w
pipe.unet.load_state_dict(state_dict)

controller_ref = AttentionStore()
register_temporal_self_attention_control(ref_unet, controller_ref)

controller = AttentionStore()
register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)

device = "cuda"
pipe = pipe.to(device)

def check_outputs_folder(folder_path):
    # Check if the folder exists
    if os.path.exists(folder_path) and os.path.isdir(folder_path):
        # Delete all contents inside the folder
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove file or link
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove directory
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        print(f'The folder {folder_path} does not exist.')

def infer(frame1_path, frame2_path):

    seed = 42
    num_inference_steps = 25
    noise_injection_steps = 0
    noise_injection_ratio = 0.5
    weighted_average = True

    generator = torch.Generator(device)
    if seed is not None:
        generator = generator.manual_seed(seed)

    frame1 = load_image(frame1_path)
    frame1 = frame1.resize((1024, 576))

    frame2 = load_image(frame2_path)
    frame2 = frame2.resize((1024, 576))

    frames = pipe(image1=frame1, image2=frame2,
                  num_inference_steps=num_inference_steps,  # 50
                  generator=generator,
                  weighted_average=weighted_average,  # True
                  noise_injection_steps=noise_injection_steps,  # 0
                  noise_injection_ratio=noise_injection_ratio,  # 0.5
                  ).frames[0]

    out_dir = "result"

    check_outputs_folder(out_dir)
    os.makedirs(out_dir, exist_ok=True)
    out_path = "result/video_result.mp4"

    if out_path.endswith('.gif'):
        frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
    else:
        export_to_video(frames, out_path, fps=7)

    return out_path

with gr.Blocks() as demo:

    with gr.Column():
        gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
        with gr.Row():
            with gr.Column():
                image_input1 = gr.Image(type="filepath")
                image_input2 = gr.Image(type="filepath")
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output = gr.Video()

        submit_btn.click(
            fn=infer,
            inputs=[image_input1, image_input2],
            outputs=[output],
            show_api=False
        )

demo.queue().launch(show_api=False, show_error=True)
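Both the Gradio app above and keyframe_interpolation.py below transfer the fine-tuned reverse-motion weights onto the 25-frame XT UNet by adding a weight difference ("delta w") for the selected temporal self-attention projections. A small sketch of that idea in isolation, on generic state dicts (the tensor names here are placeholders, not the actual UNet keys):

import torch

def apply_delta(base_sd, finetuned_sd, original_sd, key_filter):
    # base_sd: the model to update (e.g. the XT UNet)
    # finetuned_sd / original_sd: fine-tuned and original 14-frame UNets
    for name, finetuned_param in finetuned_sd.items():
        if key_filter(name):
            delta_w = finetuned_param - original_sd[name]
            base_sd[name] = base_sd[name] + delta_w
    return base_sd

# toy example with made-up keys
base = {"attn.to_v.weight": torch.ones(2, 2)}
orig = {"attn.to_v.weight": torch.zeros(2, 2)}
tuned = {"attn.to_v.weight": torch.full((2, 2), 0.5)}
print(apply_delta(base, tuned, orig, lambda n: "to_v" in n))  # entries become 1.5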
keyframe_interpolation.py ADDED

import os
import torch
import argparse
import copy
from diffusers.utils import load_image, export_to_video
from diffusers import UNetSpatioTemporalConditionModel
from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (AttentionStore,
                                         register_temporal_self_attention_control,
                                         register_temporal_self_attention_flip_control,
                                         )

def main(args):

    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        scheduler=noise_scheduler,
        variant="fp16",
        torch_dtype=torch.float16,
    )
    ref_unet = pipe.ori_unet

    state_dict = pipe.unet.state_dict()
    # computing delta w
    finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.checkpoint_dir,
        subfolder="unet",
        torch_dtype=torch.float16,
    )
    assert finetuned_unet.config.num_frames == 14
    ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid",
        subfolder="unet",
        variant='fp16',
        torch_dtype=torch.float16,
    )

    finetuned_state_dict = finetuned_unet.state_dict()
    ori_state_dict = ori_unet.state_dict()
    for name, param in finetuned_state_dict.items():
        if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
            delta_w = param - ori_state_dict[name]
            state_dict[name] = state_dict[name] + delta_w
    pipe.unet.load_state_dict(state_dict)

    controller_ref = AttentionStore()
    register_temporal_self_attention_control(ref_unet, controller_ref)

    controller = AttentionStore()
    register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)

    pipe = pipe.to(args.device)

    # run inference
    generator = torch.Generator(device=args.device)
    if args.seed is not None:
        generator = generator.manual_seed(args.seed)

    frame1 = load_image(args.frame1_path)
    frame1 = frame1.resize((1024, 576))

    frame2 = load_image(args.frame2_path)
    frame2 = frame2.resize((1024, 576))

    frames = pipe(image1=frame1, image2=frame2,
                  num_inference_steps=args.num_inference_steps,
                  generator=generator,
                  weighted_average=args.weighted_average,
                  noise_injection_steps=args.noise_injection_steps,
                  noise_injection_ratio=args.noise_injection_ratio,
                  ).frames[0]

    if args.out_path.endswith('.gif'):
        frames[0].save(args.out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
    else:
        export_to_video(frames, args.out_path, fps=7)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="stabilityai/stable-video-diffusion-img2vid-xt")
    parser.add_argument("--checkpoint_dir", type=str, required=True)
    parser.add_argument('--frame1_path', type=str, required=True)
    parser.add_argument('--frame2_path', type=str, required=True)
    parser.add_argument('--out_path', type=str, required=True)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_inference_steps', type=int, default=50)
    parser.add_argument('--weighted_average', action='store_true')
    parser.add_argument('--noise_injection_steps', type=int, default=0)
    parser.add_argument('--noise_injection_ratio', type=float, default=0.5)
    parser.add_argument('--device', type=str, default='cuda:0')
    args = parser.parse_args()
    out_dir = os.path.dirname(args.out_path)
    os.makedirs(out_dir, exist_ok=True)
    main(args)
keyframe_interpolation.sh ADDED

#!/bin/bash
noise_injection_steps=5
noise_injection_ratio=0.5
EVAL_DIR=examples
CHECKPOINT_DIR=checkpoints/svd_reverse_motion_with_attnflip
MODEL_NAME=stabilityai/stable-video-diffusion-img2vid-xt
OUT_DIR=results

mkdir -p $OUT_DIR
for example_dir in $(ls -d $EVAL_DIR/*)
do
    example_name=$(basename $example_dir)
    echo $example_name

    out_fn=$OUT_DIR/$example_name'.gif'
    python keyframe_interpolation.py \
        --frame1_path=$example_dir/frame1.png \
        --frame2_path=$example_dir/frame2.png \
        --pretrained_model_name_or_path=$MODEL_NAME \
        --checkpoint_dir=$CHECKPOINT_DIR \
        --noise_injection_steps=$noise_injection_steps \
        --noise_injection_ratio=$noise_injection_ratio \
        --out_path=$out_fn
done
requirements.txt ADDED

torch==2.0.1
torchvision==0.15.2
numpy==1.23.1
diffusers==0.27.0
albumentations==0.4.3
opencv-python==4.6.0.66
pudb==2019.2
imageio==2.9.0
imageio-ffmpeg==0.4.2
omegaconf==2.1.1
test-tube>=0.7.5
einops==0.3.0
torch-fidelity==0.3.0
torchmetrics==0.11.0
transformers==4.36.0
webdataset==0.2.5
open-clip-torch==2.7.0
invisible-watermark>=0.1.5
accelerate==0.25.0
xformers==0.0.23
peft==0.7.0
torch-ema==0.3
moviepy
tensorboard
Jinja2
ftfy
datasets
wandb
pytorch-fid
notebook
matplotlib
kornia==0.7.2
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
-e git+https://github.com/Stability-AI/stablediffusion.git@main#egg=stable-diffusion
train_reverse_motion_with_attnflip.py ADDED

"""Fine-tuning script for Stable Video Diffusion for image2video with support for LoRA."""
import logging
import math
import os
import shutil
from glob import glob
from pathlib import Path
from PIL import Image

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from einops import rearrange
import transformers
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from tqdm.auto import tqdm
import copy

import diffusers
from diffusers import AutoencoderKLTemporalDecoder
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing


from custom_diffusers.pipelines.pipeline_stable_video_diffusion_with_ref_attnmap import StableVideoDiffusionWithRefAttnMapPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (AttentionStore,
                                         register_temporal_self_attention_control,
                                         register_temporal_self_attention_flip_control,
                                         )
from utils.parse_args import parse_args
from dataset.stable_video_dataset import StableVideoDataset

logger = get_logger(__name__, log_level="INFO")

def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
    return torch.distributions.Normal(loc, scale).icdf(u).exp()

def main():
    args = parse_args()

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    feature_extractor = CLIPImageProcessor.from_pretrained(args.pretrained_model_name_or_path, subfolder="feature_extractor")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="image_encoder", variant=args.variant
    )
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", variant=args.variant
    )
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", low_cpu_mem_usage=True, variant=args.variant
    )
    ref_unet = copy.deepcopy(unet)

    # register customized attn processors
    controller_ref = AttentionStore()
    register_temporal_self_attention_control(ref_unet, controller_ref)

    controller = AttentionStore()
    register_temporal_self_attention_flip_control(unet, controller, controller_ref)

    # freeze parameters of models to save more memory
    ref_unet.requires_grad_(False)
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and image_encoder to device and cast to weight_dtype
    # unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    ref_unet.to(accelerator.device, dtype=weight_dtype)

    unet_train_params_list = []
    # Customize the parameters that need to be trained; if necessary, you can uncomment them yourself.
    for name, para in unet.named_parameters():
        if 'temporal_transformer_blocks.0.attn1.to_v.weight' in name or 'temporal_transformer_blocks.0.attn1.to_out.0.weight' in name:
            unet_train_params_list.append(para)
            para.requires_grad = True
        else:
            para.requires_grad = False

    if args.mixed_precision == "fp16":
        # only upcast trainable parameters into fp32
        cast_training_params(unet, dtype=torch.float32)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                for i, model in enumerate(models):
                    model.save_pretrained(os.path.join(output_dir, "unet"))

                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

        def load_model_hook(models, input_dir):
            for _ in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNetSpatioTemporalConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if accelerator.is_main_process:
        rec_txt1 = open('frozen_param.txt', 'w')
        rec_txt2 = open('train_param.txt', 'w')
        for name, para in unet.named_parameters():
            if para.requires_grad is False:
                rec_txt1.write(f'{name}\n')
            else:
                rec_txt2.write(f'{name}\n')
        rec_txt1.close()
        rec_txt2.close()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        unet_train_params_list,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    train_dataset = StableVideoDataset(video_data_dir=args.train_data_dir,
                                       max_num_videos=args.max_train_samples,
                                       num_frames=args.num_frames,
                                       is_reverse_video=True,
                                       double_sampling_rate=args.double_sampling_rate)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        conditions = torch.stack([example["conditions"] for example in examples])
        conditions = conditions.to(memory_format=torch.contiguous_format).float()
        return {"pixel_values": pixel_values, "conditions": conditions}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Validation data
    if args.validation_data_dir is not None:
        validation_image_paths = sorted(glob(os.path.join(args.validation_data_dir, '*.png')))
        num_validation_images = min(args.num_validation_images, len(validation_image_paths))
        validation_image_paths = validation_image_paths[:num_validation_images]
        validation_images = [Image.open(image_path).convert('RGB').resize((1024, 576)) for image_path in validation_image_paths]

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("image2video-reverse-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # default motion param setting
    def _get_add_time_ids(
        dtype,
        batch_size,
        fps=6,
        motion_bucket_id=127,
        noise_aug_strength=0.02,
    ):
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
        passed_add_embed_dim = unet.module.config.addition_time_embed_dim * \
            len(add_time_ids)
        expected_add_embed_dim = unet.module.add_embedding.linear_1.in_features
        assert (expected_add_embed_dim == passed_add_embed_dim)

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size, 1)
        return add_time_ids

    def compute_image_embeddings(image):
        image = _resize_with_antialiasing(image, (224, 224))
        image = (image + 1.0) / 2.0
        # Normalize the image for CLIP input
        image = feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(accelerator.device).to(dtype=weight_dtype)
        image_embeddings = image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)
        return image_embeddings

    noise_aug_strength = 0.02
    fps = 7
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Get the image embedding for conditioning
                encoder_hidden_states = compute_image_embeddings(batch["conditions"])
                encoder_hidden_states_ref = compute_image_embeddings(batch["pixel_values"][:, -1])

                batch["conditions"] = batch["conditions"].to(accelerator.device).to(dtype=weight_dtype)
                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device).to(dtype=weight_dtype)

                # Get the image latent for input conditioning
                noise = torch.randn_like(batch["conditions"])
                conditions = batch["conditions"] + noise_aug_strength * noise
                conditions_latent = vae.encode(conditions).latent_dist.mode()
                conditions_latent = conditions_latent.unsqueeze(1).repeat(1, args.num_frames, 1, 1, 1)

                conditions_ref = batch["pixel_values"][:, -1] + noise_aug_strength * noise
                conditions_latent_ref = vae.encode(conditions_ref).latent_dist.mode()
                conditions_latent_ref = conditions_latent_ref.unsqueeze(1).repeat(1, args.num_frames, 1, 1, 1)

                # Convert frames to latent space
                pixel_values = rearrange(batch["pixel_values"], "b f c h w -> (b f) c h w")
                latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                latents = rearrange(latents, "(b f) c h w -> b f c h w", f=args.num_frames)
                latents_ref = torch.flip(latents, dims=(1,))

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], latents.shape[2], 1, 1), device=latents.device
                    )

                bsz = latents.shape[0]
                # Sample a random timestep for each image
                # P_mean=0.7 P_std=1.6
                sigmas = rand_log_normal(shape=[bsz,], loc=0.7, scale=1.6).to(latents.device)
                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                sigmas = sigmas[:, None, None, None, None]
                timesteps = torch.Tensor(
                    [0.25 * sigma.log() for sigma in sigmas]).to(accelerator.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = latents + noise * sigmas
                noisy_latents_inp = noisy_latents / ((sigmas**2 + 1) ** 0.5)
                noisy_latents_inp = torch.cat([noisy_latents_inp, conditions_latent], dim=2)

                noisy_latents_ref = latents_ref + torch.flip(noise, dims=(1,)) * sigmas
                noisy_latents_ref_inp = noisy_latents_ref / ((sigmas**2 + 1) ** 0.5)
                noisy_latents_ref_inp = torch.cat([noisy_latents_ref_inp, conditions_latent_ref], dim=2)

                # Get the target for loss depending on the prediction type
                target = latents
                # Predict the noise residual and compute loss
                added_time_ids = _get_add_time_ids(encoder_hidden_states.dtype, bsz).to(accelerator.device)
                ref_model_pred = ref_unet(noisy_latents_ref_inp.to(weight_dtype), timesteps.to(weight_dtype),
                                          encoder_hidden_states=encoder_hidden_states_ref,
                                          added_time_ids=added_time_ids,
                                          return_dict=False)[0]
                model_pred = unet(noisy_latents_inp, timesteps,
                                  encoder_hidden_states=encoder_hidden_states,
                                  added_time_ids=added_time_ids,
                                  return_dict=False)[0]  # v-prediction
                # Denoise the latents
                c_out = -sigmas / ((sigmas**2 + 1)**0.5)
                c_skip = 1 / (sigmas**2 + 1)
                denoised_latents = model_pred * c_out + c_skip * noisy_latents
                weighing = (1 + sigmas ** 2) * (sigmas**-2.0)

                # MSE loss
                loss = torch.mean(
                    (weighing.float() * (denoised_latents.float() -
                     target.float()) ** 2).reshape(target.shape[0], -1),
                    dim=1,
                )
                loss = loss.mean()
                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = unet_train_params_list
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_data_dir is not None and epoch % args.validation_epochs == 0:
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_data_dir}."
                )
                # create pipeline
                pipeline = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    scheduler=noise_scheduler,
                    unet=unwrap_model(unet),
                    variant=args.variant,
                    torch_dtype=weight_dtype,
                )
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference
                generator = torch.Generator(device=accelerator.device)
                if args.seed is not None:
                    generator = generator.manual_seed(args.seed)
                videos = []
                with torch.cuda.amp.autocast():
                    for val_idx in range(num_validation_images):
                        val_img = validation_images[val_idx]
                        videos.append(
                            pipeline(ref_unet=ref_unet, image=val_img, ref_image=val_img, num_inference_steps=50, generator=generator, output_type='pt').frames[0]
                        )

                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        videos = torch.stack(videos)
                        tracker.writer.add_video("validation", videos, epoch, fps=fps)

                del pipeline
                torch.cuda.empty_cache()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)

        unwrapped_unet = unwrap_model(unet)
        pipeline = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            scheduler=noise_scheduler,
            unet=unwrapped_unet,
            variant=args.variant,
        )
        pipeline.save_pretrained(args.output_dir)
        # Final inference
        # Load previous pipeline
        if args.validation_data_dir is not None:
            pipeline = pipeline.to(accelerator.device)
            pipeline.torch_dtype = weight_dtype
            # run inference
            generator = torch.Generator(device=accelerator.device)
            if args.seed is not None:
                generator = generator.manual_seed(args.seed)
            videos = []
            with torch.cuda.amp.autocast():
                for val_idx in range(num_validation_images):
                    val_img = validation_images[val_idx]
                    videos.append(
                        pipeline(ref_unet=ref_unet, image=val_img, ref_image=val_img, num_inference_steps=50, generator=generator, output_type='pt').frames[0]
                    )

            for tracker in accelerator.trackers:
                if len(videos) != 0:
                    if tracker.name == "tensorboard":
                        videos = torch.stack(videos)
                        tracker.writer.add_video("validation", videos, epoch, fps=fps)

    accelerator.end_training()


if __name__ == "__main__":
    main()
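The training loop above uses EDM-style preconditioning: the noisy latent is scaled by 1/sqrt(sigma^2 + 1) before the UNet, the prediction is recombined with the noisy latent via c_out and c_skip to form a denoised estimate, and the squared error against the clean latent is weighted by (1 + sigma^2)/sigma^2. A small numeric sketch of those terms for a single sigma (the value is chosen for illustration only):

import torch

sigma = torch.tensor(1.5)                 # in the script, sampled from the lognormal with loc=0.7, scale=1.6
c_in = 1 / (sigma**2 + 1) ** 0.5          # scaling applied to the noisy latent before the UNet
c_out = -sigma / (sigma**2 + 1) ** 0.5
c_skip = 1 / (sigma**2 + 1)
weighing = (1 + sigma**2) * sigma**-2.0

print(c_in.item(), c_out.item(), c_skip.item(), weighing.item())
# denoised = model_pred * c_out + c_skip * noisy_latents, as in the loss computation above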
train_reverse_motion_with_attnflip.sh ADDED

MODEL_NAME=stabilityai/stable-video-diffusion-img2vid
TRAIN_DIR=../keyframe_interpolation_data/synthetic_videos_frames
VALIDATION_DIR=eval/val
accelerate launch --mixed_precision="fp16" train_reverse_motion_with_attnflip.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --variant "fp16" \
  --num_frames 14 \
  --train_data_dir=$TRAIN_DIR \
  --validation_data_dir=$VALIDATION_DIR \
  --max_train_samples=100 \
  --train_batch_size=1 \
  --gradient_accumulation_steps 1 \
  --num_train_epochs=1000 --checkpointing_steps=2000 \
  --validation_epochs=50 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --double_sampling_rate \
  --output_dir="checkpoints/svd_reverse_motion_with_attnflip" \
  --cache_dir="checkpoints/svd_reverse_motion_with_attnflip_cache" \
  --report_to="tensorboard"
utils/parse_args.py ADDED

import os
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=25,
        help="Number of frames that should be generated in the video.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        required=True,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--double_sampling_rate",
        action="store_true",
        help=(
            "Whether or not to sample training frames at double rate."
        ),
    )
    parser.add_argument(
        "--validation_data_dir", type=str, default=None, help="A folder of validation images used for inference during training."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--prediction_type",
        type=str,
        default=None,
        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
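For a quick check of the parser without launching training, the arguments can be supplied programmatically; a minimal sketch covering only the two required options (the training-data path below is a placeholder):

import sys
from utils.parse_args import parse_args

# hypothetical minimal invocation; real runs go through train_reverse_motion_with_attnflip.sh
sys.argv = [
    "train_reverse_motion_with_attnflip.py",
    "--pretrained_model_name_or_path", "stabilityai/stable-video-diffusion-img2vid",
    "--train_data_dir", "path/to/videos_frames",
]
args = parse_args()
print(args.num_frames, args.learning_rate)  # defaults: 25 and 1e-4; the shell script overrides --num_frames to 14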