Spaces: Running on Zero

Initial commit
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
- .gitattributes +3 -0
- README.md +11 -7
- __init__.py +0 -0
- config.yaml +316 -0
- dataloader/dataset_factory.py +13 -0
- dataloader/single_image_dataset.py +16 -0
- dataloader/video_data_module.py +32 -0
- diffusion_trainer/abstract_trainer.py +108 -0
- diffusion_trainer/streaming_svd.py +508 -0
- gradio_demo.py +214 -0
- i2v_enhance/i2v_enhance_interface.py +128 -0
- i2v_enhance/pipeline_i2vgen_xl.py +988 -0
- i2v_enhance/thirdparty/VFI/Trainer.py +168 -0
- i2v_enhance/thirdparty/VFI/ckpt/Put ours.pkl files here.txt +1 -0
- i2v_enhance/thirdparty/VFI/ckpt/__init__.py +0 -0
- i2v_enhance/thirdparty/VFI/config.py +49 -0
- i2v_enhance/thirdparty/VFI/dataset.py +93 -0
- i2v_enhance/thirdparty/VFI/model/__init__.py +5 -0
- i2v_enhance/thirdparty/VFI/model/feature_extractor.py +516 -0
- i2v_enhance/thirdparty/VFI/model/flow_estimation.py +141 -0
- i2v_enhance/thirdparty/VFI/model/loss.py +95 -0
- i2v_enhance/thirdparty/VFI/model/refine.py +71 -0
- i2v_enhance/thirdparty/VFI/model/warplayer.py +21 -0
- i2v_enhance/thirdparty/VFI/train.py +105 -0
- lib/__init__.py +0 -0
- lib/farancia/__init__.py +4 -0
- lib/farancia/animation.py +43 -0
- lib/farancia/config.py +1 -0
- lib/farancia/libimage/__init__.py +45 -0
- lib/farancia/libimage/iimage.py +511 -0
- lib/farancia/libimage/utils.py +8 -0
- models/cam/conditioning.py +150 -0
- models/control/controlnet.py +581 -0
- models/diffusion/discretizer.py +33 -0
- models/diffusion/video_model.py +574 -0
- models/diffusion/wrappers.py +78 -0
- models/svd/sgm/__init__.py +4 -0
- models/svd/sgm/data/__init__.py +1 -0
- models/svd/sgm/data/cifar10.py +67 -0
- models/svd/sgm/data/dataset.py +80 -0
- models/svd/sgm/data/mnist.py +85 -0
- models/svd/sgm/inference/api.py +385 -0
- models/svd/sgm/inference/helpers.py +305 -0
- models/svd/sgm/lr_scheduler.py +135 -0
- models/svd/sgm/models/__init__.py +2 -0
- models/svd/sgm/models/autoencoder.py +615 -0
- models/svd/sgm/models/diffusion.py +341 -0
- models/svd/sgm/modules/__init__.py +6 -0
- models/svd/sgm/modules/attention.py +809 -0
- models/svd/sgm/modules/autoencoding/__init__.py +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,13 +1,17 @@
 ---
 title: StreamingSVD
-emoji:
-colorFrom:
-colorTo:
+emoji: 🎥
+colorFrom: yellow
+colorTo: green
 sdk: gradio
 sdk_version: 4.43.0
+suggested_hardware: a100-large
+suggested_storage: large
 app_file: app.py
-pinned: false
 license: mit
-
-
-
+tags:
+- StreamingSVD
+- long-video-generation
+- PAIR
+short_description: Image-to-Video
+disable_embedding: false
__init__.py
ADDED
File without changes
config.yaml
ADDED
@@ -0,0 +1,316 @@
# pytorch_lightning==2.2.2
seed_everything: 33
trainer:
  accelerator: auto
  strategy: auto
  devices: '1'
  num_nodes: 1
  precision: 16-mixed
  logger: False
model:
  class_path: diffusion_trainer.streaming_svd.StreamingSVD
  init_args:
    vfi:
      class_path: modules.params.vfi.VFIParams
      init_args:
        ckpt_path_local: checkpoint/VFI/ours.pkl
        ckpt_path_global: https://drive.google.com/file/d/1XCNoyhA1RX3m8W-XJK8H8inH47l36kxP/view?usp=sharing
    i2v_enhance:
      class_path: modules.params.i2v_enhance.I2VEnhanceParams
      init_args:
        ckpt_path_local: checkpoint/i2v_enhance/
        ckpt_path_global: ali-vilab/i2vgen-xl
    module_loader:
      class_path: modules.loader.module_loader.GenericModuleLoader
      init_args:
        pipeline_repo: stabilityai/stable-video-diffusion-img2vid-xt
        pipeline_obj: streamingt2v_pipeline
        set_prediction_type: ''
        module_names:
        - network_config
        - model
        - controlnet
        - denoiser
        - conditioner
        - first_stage_model
        - sampler
        - svd_pipeline
        module_config:
          controlnet:
            class_path: modules.loader.module_loader_config.ModuleLoaderConfig
            init_args:
              loader_cls_path: models.control.controlnet.ControlNet
              cls_func: from_unet
              cls_func_fast_dev_run: ''
              kwargs_diffusers: null
              model_params:
                merging_mode: addition
                zero_conv_mode: Identity
                frame_expansion: none
                downsample_controlnet_cond: true
                use_image_encoder_normalization: true
                use_controlnet_mask: false
                condition_encoder: ''
                conditioning_embedding_out_channels:
                - 32
                - 96
                - 256
                - 512
              kwargs_diff_trainer_params: null
              args: []
              dependent_modules:
                model: model
              dependent_modules_cloned: null
              state_dict_path: ''
              strict_loading: true
              state_dict_filters: []
          network_config:
            class_path: models.diffusion.video_model.VideoUNet
            init_args:
              in_channels: 8
              model_channels: 320
              out_channels: 4
              num_res_blocks: 2
              num_conditional_frames: null
              attention_resolutions:
              - 4
              - 2
              - 1
              dropout: 0.0
              channel_mult:
              - 1
              - 2
              - 4
              - 4
              conv_resample: true
              dims: 2
              num_classes: sequential
              use_checkpoint: False
              num_heads: -1
              num_head_channels: 64
              num_heads_upsample: -1
              use_scale_shift_norm: false
              resblock_updown: false
              transformer_depth: 1
              transformer_depth_middle: null
              context_dim: 1024
              time_downup: false
              time_context_dim: null
              extra_ff_mix_layer: true
              use_spatial_context: true
              merge_strategy: learned_with_images
              merge_factor: 0.5
              spatial_transformer_attn_type: softmax-xformers
              video_kernel_size:
              - 3
              - 1
              - 1
              use_linear_in_transformer: true
              adm_in_channels: 768
              disable_temporal_crossattention: false
              max_ddpm_temb_period: 10000
              merging_mode: attention_cross_attention
              controlnet_mode: true
              use_apm: false
          model:
            class_path: modules.loader.module_loader_config.ModuleLoaderConfig
            init_args:
              loader_cls_path: models.svd.sgm.modules.diffusionmodules.wrappers.OpenAIWrapper
              cls_func: ''
              cls_func_fast_dev_run: ''
              kwargs_diffusers:
                compile_model: false
              model_params: null
              model_params_fast_dev_run: null
              kwargs_diff_trainer_params: null
              args: []
              dependent_modules:
                diffusion_model: network_config
              dependent_modules_cloned: null
              state_dict_path: ''
              strict_loading: true
              state_dict_filters: []
          denoiser:
            class_path: models.svd.sgm.modules.diffusionmodules.denoiser.Denoiser
            init_args:
              scaling_config:
                target: models.svd.sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
          sampler:
            class_path: models.svd.sgm.modules.diffusionmodules.sampling.EulerEDMSampler
            init_args:
              s_churn: 0.0
              s_tmin: 0.0
              s_tmax: .inf
              s_noise: 1.0
              discretization_config:
                target: models.diffusion.discretizer.AlignYourSteps
                params:
                  sigma_max: 700.0
                  num_steps: 30
              guider_config:
                target: models.svd.sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
                params:
                  max_scale: 3.0
                  min_scale: 1.5
                  num_frames: 25
              verbose: false
              device: cuda
          conditioner:
            class_path: models.svd.sgm.modules.GeneralConditioner
            init_args:
              emb_models:
              - is_trainable: false
                input_key: cond_frames_without_noise
                target: models.svd.sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
                params:
                  n_cond_frames: 1
                  n_copies: 1
                  open_clip_embedding_config:
                    target: models.svd.sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
                    params:
                      freeze: true
              - input_key: fps_id
                is_trainable: false
                target: models.svd.sgm.modules.encoders.modules.ConcatTimestepEmbedderND
                params:
                  outdim: 256
              - input_key: motion_bucket_id
                is_trainable: false
                target: models.svd.sgm.modules.encoders.modules.ConcatTimestepEmbedderND
                params:
                  outdim: 256
              - input_key: cond_frames
                is_trainable: false
                target: models.svd.sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
                params:
                  disable_encoder_autocast: true
                  n_cond_frames: 1
                  n_copies: 1
                  is_ae: true
                  encoder_config:
                    target: models.svd.sgm.models.autoencoder.AutoencoderKLModeOnly
                    params:
                      embed_dim: 4
                      monitor: val/rec_loss
                      ddconfig:
                        attn_type: vanilla-xformers
                        double_z: true
                        z_channels: 4
                        resolution: 256
                        in_channels: 3
                        out_ch: 3
                        ch: 128
                        ch_mult:
                        - 1
                        - 2
                        - 4
                        - 4
                        num_res_blocks: 2
                        attn_resolutions: []
                        dropout: 0.0
                      lossconfig:
                        target: torch.nn.Identity
              - input_key: cond_aug
                is_trainable: false
                target: models.svd.sgm.modules.encoders.modules.ConcatTimestepEmbedderND
                params:
                  outdim: 256
          first_stage_model:
            class_path: models.svd.sgm.AutoencodingEngine
            init_args:
              encoder_config:
                target: models.svd.sgm.modules.diffusionmodules.model.Encoder
                params:
                  attn_type: vanilla
                  double_z: true
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult:
                  - 1
                  - 2
                  - 4
                  - 4
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
              decoder_config:
                target: models.svd.sgm.modules.autoencoding.temporal_ae.VideoDecoder
                params:
                  attn_type: vanilla
                  double_z: true
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult:
                  - 1
                  - 2
                  - 4
                  - 4
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                  video_kernel_size:
                  - 3
                  - 1
                  - 1
              loss_config:
                target: torch.nn.Identity
              regularizer_config:
                target: models.svd.sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
              optimizer_config: null
              lr_g_factor: 1.0
              trainable_ae_params: null
              ae_optimizer_args: null
              trainable_disc_params: null
              disc_optimizer_args: null
              disc_start_iter: 0
              diff_boost_factor: 3.0
              ckpt_engine: null
              ckpt_path: null
              additional_decode_keys: null
              ema_decay: null
              monitor: null
              input_key: jpg
          svd_pipeline:
            class_path: modules.loader.module_loader_config.ModuleLoaderConfig
            init_args:
              loader_cls_path: diffusers.StableVideoDiffusionPipeline
              cls_func: from_pretrained
              cls_func_fast_dev_run: ''
              kwargs_diffusers:
                torch_dtype: torch.float16
                variant: fp16
                use_safetensors: true
              model_params: null
              model_params_fast_dev_run: null
              kwargs_diff_trainer_params: null
              args:
              - stabilityai/stable-video-diffusion-img2vid-xt
              dependent_modules: null
              dependent_modules_cloned: null
              state_dict_path: ''
              strict_loading: true
              state_dict_filters: []
        root_cls: null
    diff_trainer_params:
      class_path: modules.params.diffusion_trainer.params_streaming_diff_trainer.DiffusionTrainerParams
      init_args:
        scale_factor: 0.18215
        streamingsvd_ckpt:
          class_path: modules.params.diffusion_trainer.params_streaming_diff_trainer.CheckpointDescriptor
          init_args:
            ckpt_path_local: checkpoint/StreamingSVD/model.safetensors
            ckpt_path_global: PAIR/StreamingSVD/resolve/main/model.safetensors
        disable_first_stage_autocast: true
    inference_params:
      class_path: modules.params.diffusion.inference_params.T2VInferenceParams
      init_args:
        n_autoregressive_generations: 2 # Number of autoregression for StreamingSVD
        num_conditional_frames: 7 # is this used?
        anchor_frames: '6' # Take the (Number+1)th frame as CLIP encoding for StreamingSVD
        reset_seed_per_generation: true # If true, the seed is reset on every generation
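The config above follows the PyTorch Lightning CLI layout (`class_path`/`init_args` pairs, plus `trainer` and `seed_everything` sections). The repository's actual entry point is `app.py` (listed in the README but not shown in this 50-file view); the snippet below is only a hedged sketch of how such a config is typically consumed with `LightningCLI`, using the model and data-module classes added in this commit.

# Minimal sketch (assumption, not the repo's app.py): drive prediction from a
# class_path/init_args config such as config.yaml via PyTorch Lightning's CLI.
from pytorch_lightning.cli import LightningCLI

from dataloader.video_data_module import VideoDataModule
from diffusion_trainer.streaming_svd import StreamingSVD


def main():
    # Invoked as: python this_script.py predict --config config.yaml
    # LightningCLI parses the trainer/model sections, instantiates StreamingSVD
    # from its nested init_args and calls trainer.predict(model, datamodule).
    LightningCLI(
        model_class=StreamingSVD,
        datamodule_class=VideoDataModule,
        subclass_mode_model=True,  # accept the class_path/init_args form used in config.yaml
    )


if __name__ == "__main__":
    main()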
dataloader/dataset_factory.py
ADDED
@@ -0,0 +1,13 @@
from pathlib import Path
from torch.utils.data import Dataset

from dataloader.single_image_dataset import SingleImageDataset


class SingleImageDatasetFactory():

    def __init__(self, file: Path):
        self.data_path = file

    def get_dataset(self, max_samples: int = None) -> Dataset:
        return SingleImageDataset(file=self.data_path)
dataloader/single_image_dataset.py
ADDED
@@ -0,0 +1,16 @@
import torch
import numpy as np
from torch.utils.data import Dataset


class SingleImageDataset(Dataset):

    def __init__(self, file: np.ndarray):
        super().__init__()
        self.images = [file]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return {"image": self.images[index], "sample_id": torch.tensor(index, dtype=torch.int64)}
dataloader/video_data_module.py
ADDED
@@ -0,0 +1,32 @@
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import (EVAL_DATALOADERS)
from dataloader.dataset_factory import SingleImageDatasetFactory


class VideoDataModule(pl.LightningDataModule):

    def __init__(self,
                 workers: int,
                 predict_dataset_factory: SingleImageDatasetFactory = None,
                 ) -> None:
        super().__init__()
        self.num_workers = workers

        self.video_data_module = {}
        # TODO read size from loaded unet via unet.sample_sizes
        self.predict_dataset_factory = predict_dataset_factory

    def setup(self, stage: str) -> None:
        if stage == "predict":
            self.video_data_module["predict"] = self.predict_dataset_factory.get_dataset()

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        return torch.utils.data.DataLoader(self.video_data_module["predict"],
                                           batch_size=1,
                                           pin_memory=True,
                                           num_workers=self.num_workers,
                                           collate_fn=None,
                                           shuffle=False,
                                           drop_last=False)
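Taken together, the three dataloader files wrap a single start image into a one-element predict dataloader. A minimal usage sketch follows (not part of the repository; note that `SingleImageDatasetFactory` annotates `file` as a `Path`, while in practice the object stored and returned is whatever the caller passes in, here a NumPy array).

# Usage sketch (assumption): wire the factory into the data module and pull one batch.
import numpy as np

from dataloader.dataset_factory import SingleImageDatasetFactory
from dataloader.video_data_module import VideoDataModule

image = np.zeros((576, 1024, 3), dtype=np.uint8)  # placeholder H x W x C start frame

data_module = VideoDataModule(
    workers=0,
    predict_dataset_factory=SingleImageDatasetFactory(file=image),
)
data_module.setup(stage="predict")

# One batch only: the start image (collated to a tensor) and its sample id.
batch = next(iter(data_module.predict_dataloader()))
print(batch["image"].shape, batch["sample_id"])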
diffusion_trainer/abstract_trainer.py
ADDED
@@ -0,0 +1,108 @@
import os

import pytorch_lightning as pl
import torch

from typing import Any

from modules.params.diffusion.inference_params import InferenceParams
from modules.loader.module_loader import GenericModuleLoader
from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams


class AbstractTrainer(pl.LightningModule):

    def __init__(self,
                 inference_params: Any,
                 diff_trainer_params: DiffusionTrainerParams,
                 module_loader: GenericModuleLoader,
                 ):

        super().__init__()

        self.inference_params = inference_params
        self.diff_trainer_params = diff_trainer_params
        self.module_loader = module_loader

        self.on_start_once_called = False
        self._setup_methods = []

        module_loader(
            trainer=self,
            diff_trainer_params=diff_trainer_params)

    # ------ IMPLEMENTATION HOOKS -------

    def post_init(self, batch):
        '''
        Called after the LightningDataModule and LightningModule are created, but before any training/validation/prediction.
        First possible access to the 'trainer' object (e.g. to get 'device').
        '''

    def generate_output(self, batch, batch_idx, inference_params: InferenceParams):
        '''
        Called during validation to generate an output for each batch.
        Returns meta information about the produced result (where the results were stored).
        This is used for the metric evaluation.
        '''

    # ------- HELPER FUNCTIONS -------

    def _reset_random_generator(self):
        '''
        Reset the random generator to the same seed across all workers. The generator is used only for inference.
        '''
        if not hasattr(self, "random_generator"):
            self.random_generator = torch.Generator(device=self.device)
            # set seed according to 'seed_everything' in config
            seed = int(os.environ.get("PL_GLOBAL_SEED", 42))
        else:
            seed = self.random_generator.initial_seed()
        self.random_generator.manual_seed(seed)

    # ----- PREDICT HOOKS ------

    def on_predict_start(self):
        self.on_start()

    def predict_step(self, batch, batch_idx):
        self.on_inference_step(batch=batch, batch_idx=batch_idx)

    def on_predict_epoch_start(self):
        self.on_inference_epoch_start()

    # ----- CUSTOM HOOKS -----

    # Global Hooks (Called by Training, Validation and Prediction)

    def _on_start_once(self):
        '''
        Called only once by on_start, i.e. on the first call of train, validation or prediction.
        '''
        if self.on_start_once_called:
            return
        else:
            self.on_start_once_called = True
        self.post_init()

    def on_start(self):
        '''
        Called at the beginning of training, validation and prediction.
        '''
        self._on_start_once()

    # ----- Inference Hooks (called by 'validation' and 'predict') ------

    def on_inference_epoch_start(self):
        # reset seed at every inference
        self._reset_random_generator()

    def on_inference_step(self, batch, batch_idx):
        if self.inference_params.reset_seed_per_generation:
            self._reset_random_generator()
        self.generate_output(
            batch=batch, inference_params=self.inference_params, batch_idx=batch_idx)
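`AbstractTrainer` funnels Lightning's predict hooks into a small custom hook set: `on_predict_start` triggers the one-time `post_init`, `on_predict_epoch_start` reseeds the inference generator, and `predict_step` forwards each batch to `generate_output`. The toy subclass below is illustrative only (it is not in the repository); it bypasses `GenericModuleLoader` with a no-op callable and uses a hypothetical `_Params` object that carries only the attribute `on_inference_step` actually reads.

# Illustrative toy subclass (assumption): makes the hook chain visible.
import torch

from diffusion_trainer.abstract_trainer import AbstractTrainer


class EchoTrainer(AbstractTrainer):
    def post_init(self):
        # first place where self.device is guaranteed to be meaningful
        print("post_init on", self.device)

    def generate_output(self, batch, batch_idx, inference_params):
        # draw noise from the reseeded generator so repeated predict runs are reproducible
        noise = torch.randn(4, generator=self.random_generator, device=self.device)
        print(f"batch {batch_idx}: noise sum {noise.sum().item():.4f}")


class _Params:
    # hypothetical stand-in for T2VInferenceParams; only this flag is read by the hooks
    reset_seed_per_generation = True


trainer_module = EchoTrainer(
    inference_params=_Params(),
    diff_trainer_params=None,               # not used by the toy subclass
    module_loader=lambda **kwargs: None,    # no-op stand-in for GenericModuleLoader
)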
diffusion_trainer/streaming_svd.py
ADDED
@@ -0,0 +1,508 @@
from modules.loader.module_loader import GenericModuleLoader
from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams
import torch
from modules.params.diffusion.inference_params import InferenceParams
from utils import result_processor
from tqdm import tqdm
from PIL import Image, ImageFilter
from utils.inference_utils import resize_and_crop, get_padding_for_aspect_ratio
import numpy as np
from safetensors.torch import load_file as load_safetensors
import math
from einops import repeat, rearrange
from torchvision.transforms import ToTensor
from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder
import PIL
from modules.params.vfi import VFIParams
from modules.params.i2v_enhance import I2VEnhanceParams
from typing import List, Union
from models.diffusion.wrappers import StreamingWrapper
from diffusion_trainer.abstract_trainer import AbstractTrainer
from utils.loader import download_ckpt
import torchvision.transforms.functional as TF
from diffusers import AutoPipelineForInpainting, DEISMultistepScheduler
from transformers import BlipProcessor, BlipForConditionalGeneration


class StreamingSVD(AbstractTrainer):
    def __init__(self,
                 module_loader: GenericModuleLoader,
                 diff_trainer_params: DiffusionTrainerParams,
                 inference_params: InferenceParams,
                 vfi: VFIParams,
                 i2v_enhance: I2VEnhanceParams,
                 ):
        super().__init__(inference_params=inference_params,
                         diff_trainer_params=diff_trainer_params,
                         module_loader=module_loader,
                         )

        # network_config is wrapped by OpenAIWrapper, so we don't need a direct reference anymore.
        # This corresponds to the config yaml defined at model.module_loader.module_config.model.dependent_modules.
        del self.network_config
        self.diff_trainer_params: DiffusionTrainerParams
        self.vfi = vfi
        self.i2v_enhance = i2v_enhance

    def on_inference_epoch_start(self):
        super().on_inference_epoch_start()

        # For StreamingSVD we use a model wrapper that combines the base SVD model and the control model.
        self.inference_model = StreamingWrapper(
            diffusion_model=self.model.diffusion_model,
            controlnet=self.controlnet,
            num_frame_conditioning=self.inference_params.num_conditional_frames
        )

    def post_init(self):
        self.svd_pipeline.set_progress_bar_config(disable=True)
        if self.device.type != "cpu":
            self.svd_pipeline.enable_model_cpu_offload(gpu_id=self.device.index)

        # re-use the OpenCLIP model already loaded for the image conditioner as image_encoder_apm
        embedders = self.conditioner.embedders
        for embedder in embedders:
            if hasattr(embedder, "input_key") and embedder.input_key == "cond_frames_without_noise":
                self.image_encoder_apm = embedder.open_clip
        self.first_stage_model.to("cpu")
        self.conditioner.embedders[3].encoder.to("cpu")
        self.conditioner.embedders[0].open_clip.to("cpu")

        pipe = AutoPipelineForInpainting.from_pretrained(
            'Lykon/dreamshaper-8-inpainting', torch_dtype=torch.float16, variant="fp16", safety_checker=None, requires_safety_checker=False)

        pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(self.device)
        pipe.enable_model_cpu_offload(gpu_id=self.device.index)
        self.inpaint_pipe = pipe

        processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large")

        model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to(self.device)

        def blip(x): return processor.decode(model.generate(**processor(x,
                                             return_tensors='pt').to("cuda", torch.float16))[0], skip_special_tokens=True)
        self.blip = blip

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def get_unique_embedder_keys_from_conditioner(self, conditioner):
        return list(set([x.input_key for x in conditioner.embedders]))

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def get_batch_sgm(self, keys, value_dict, N, T, device):
        batch = {}
        batch_uc = {}

        for key in keys:
            if key == "fps_id":
                batch[key] = (
                    torch.tensor([value_dict["fps_id"]])
                    .to(device)
                    .repeat(int(math.prod(N)))
                )
            elif key == "motion_bucket_id":
                batch[key] = (
                    torch.tensor([value_dict["motion_bucket_id"]])
                    .to(device)
                    .repeat(int(math.prod(N)))
                )
            elif key == "cond_aug":
                batch[key] = repeat(
                    torch.tensor([value_dict["cond_aug"]]).to(device),
                    "1 -> b",
                    b=math.prod(N),
                )
            elif key == "cond_frames":
                batch[key] = repeat(value_dict["cond_frames"],
                                    "1 ... -> b ...", b=N[0])
            elif key == "cond_frames_without_noise":
                batch[key] = repeat(
                    value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
                )
            else:
                batch[key] = value_dict[key]

        if T is not None:
            batch["num_video_frames"] = T

        for key in batch.keys():
            if key not in batch_uc and isinstance(batch[key], torch.Tensor):
                batch_uc[key] = torch.clone(batch[key])
        return batch, batch_uc

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/models/diffusion.py
    @torch.no_grad()
    def decode_first_stage(self, z):
        self.first_stage_model.to(self.device)

        z = 1.0 / self.diff_trainer_params.scale_factor * z
        n_samples = min(z.shape[0], 8)

        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.diff_trainer_params.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(
                        z[n * n_samples: (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(
                    z[n * n_samples: (n + 1) * n_samples], **kwargs
                )
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        self.first_stage_model.to("cpu")
        return out

    # Adapted from https://github.com/Stability-AI/generative-models/blob/main/scripts/sampling/simple_video_sample.py
    def _generate_conditional_output(self, svd_input_frame, inference_params: InferenceParams, **params):
        C = 4
        F = 8  # spatial compression TODO read from model

        H = svd_input_frame.shape[-2]
        W = svd_input_frame.shape[-1]
        num_frames = self.sampler.guider.num_frames

        shape = (num_frames, C, H // F, W // F)
        batch_size = 1

        image = svd_input_frame[None, :]
        cond_aug = 0.02

        value_dict = {}
        value_dict["motion_bucket_id"] = 127
        value_dict["fps_id"] = 6
        value_dict["cond_aug"] = cond_aug
        value_dict["cond_frames_without_noise"] = image
        value_dict["cond_frames"] = image + cond_aug * torch.rand_like(image)

        batch, batch_uc = self.get_batch_sgm(
            self.get_unique_embedder_keys_from_conditioner(
                self.conditioner),
            value_dict,
            [1, num_frames],
            T=num_frames,
            device=self.device,
        )

        self.conditioner.embedders[3].encoder.to(self.device)
        self.conditioner.embedders[0].open_clip.to(self.device)
        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=[
                "cond_frames",
                "cond_frames_without_noise",
            ],
        )
        self.conditioner.embedders[3].encoder.to("cpu")
        self.conditioner.embedders[0].open_clip.to("cpu")

        for k in ["crossattn", "concat"]:
            uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
            uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
            c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
            c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)

        randn = torch.randn(shape, device=self.device)

        additional_model_inputs = {}
        additional_model_inputs["image_only_indicator"] = torch.zeros(2 * batch_size, num_frames).to(self.device)
        additional_model_inputs["num_video_frames"] = batch["num_video_frames"]

        # StreamingSVD inputs
        additional_model_inputs["batch_size"] = 2 * batch_size
        additional_model_inputs["num_conditional_frames"] = self.inference_params.num_conditional_frames
        additional_model_inputs["ctrl_frames"] = params["ctrl_frames"]

        self.inference_model.diffusion_model = self.inference_model.diffusion_model.to(
            self.device)
        self.inference_model.controlnet = self.inference_model.controlnet.to(
            self.device)

        c["vector"] = c["vector"].to(randn.dtype)
        uc["vector"] = uc["vector"].to(randn.dtype)

        def denoiser(input, sigma, c):
            return self.denoiser(self.inference_model, input, sigma, c, **additional_model_inputs)
        samples_z = self.sampler(denoiser, randn, cond=c, uc=uc)

        self.inference_model.diffusion_model = self.inference_model.diffusion_model.to("cpu")
        self.inference_model.controlnet = self.inference_model.controlnet.to("cpu")
        samples_x = self.decode_first_stage(samples_z)

        samples = torch.clamp(samples_x, min=-1.0, max=1.0)
        return samples

    def extract_anchor_frames(self, video, input_range, inference_params: InferenceParams):
        """
        Extracts anchor frames from the input video based on the provided inference parameters.

        Parameters:
        - video: torch.Tensor
            The input video tensor.
        - input_range: list
            The pixel value range of the input video.
        - inference_params: InferenceParams
            An object containing inference parameters. Its anchor_frames field specifies how the anchor frames are encoded:
            either a single number selecting the frame used as the anchor frame, or a range in the format "a:b" indicating
            that frames from index a up to index b (inclusive) are used as anchor frames.

        Returns:
        - torch.Tensor
            The anchor frames extracted from the input video.
        """
        video = result_processor.convert_range(video=video.clone(), input_range=input_range, output_range=[-1, 1])

        if video.shape[1] == 3 and video.shape[0] > 3:
            video = rearrange(video, "F C W H -> 1 F C W H")
        elif video.shape[0] > 3 and video.shape[-1] == 3:
            video = rearrange(video, "F W H C -> 1 F C W H")
        else:
            raise NotImplementedError(f"Unexpected video input format: {video.shape}")

        if ":" in inference_params.anchor_frames:
            anchor_frames = inference_params.anchor_frames.split(":")
            anchor_frames = [int(anchor_frame) for anchor_frame in anchor_frames]
            assert len(anchor_frames) == 2, "Anchor frames encoding wrong."
            anchor = video[:, anchor_frames[0]:anchor_frames[1]]
        else:
            anchor_frame = int(inference_params.anchor_frames)
            anchor = video[:, anchor_frame].unsqueeze(0)

        return anchor

    def extract_ctrl_frames(self, video: torch.FloatType, input_range: List[int], inference_params: InferenceParams):
        """
        Extracts control frames from the input video.

        Parameters:
        - video: torch.Tensor
            The input video tensor.
        - input_range: list
            The pixel value range of the input video.
        - inference_params: InferenceParams
            An object containing inference parameters.

        Returns:
        - torch.Tensor
            The control image encoding frames extracted from the input video.
        """
        video = result_processor.convert_range(video=video.clone(), input_range=input_range, output_range=[-1, 1])
        if video.shape[1] == 3 and video.shape[0] > 3:
            video = rearrange(video, "F C W H -> 1 F C W H")
        elif video.shape[0] > 3 and video.shape[-1] == 3:
            video = rearrange(video, "F W H C -> 1 F C W H")
        else:
            raise NotImplementedError(
                f"Unexpected video input format: {video.shape}")

        # return the last num_conditional_frames frames
        video = video[:, -inference_params.num_conditional_frames:]
        return video

    def _autoregressive_generation(self, initial_generation: Union[torch.FloatType, List[torch.FloatType]], inference_params: InferenceParams):
        """
        Perform autoregressive generation of video chunks based on the initial generation and inference parameters.

        Parameters:
        - initial_generation: torch.Tensor or list of torch.Tensor
            The initial generation or list of initial generation video chunks.
        - inference_params: InferenceParams
            An object containing inference parameters.

        Returns:
        - torch.Tensor
            The video resulting from autoregressive generation.
        """

        # input is [-1,1] float
        result_chunks = initial_generation
        if not isinstance(result_chunks, list):
            result_chunks = [result_chunks]

        # make sure the chunk layout is [F, C, W, H]
        if (result_chunks[0].shape[1] > 3) and (result_chunks[0].shape[-1] == 3):
            result_chunks = [rearrange(result_chunks[0], "F W H C -> F C W H")]

        # generate each chunk by conditioning on the previous chunks
        for _ in tqdm(list(range(inference_params.n_autoregressive_generations)), desc="StreamingSVD"):

            # extract anchor frames based on the entire video generated so far
            # note that we do not use an anchor frame in StreamingSVD (apart from the anchor frame already used by SVD)
            anchor_frames = self.extract_anchor_frames(
                video=torch.cat(result_chunks),
                inference_params=inference_params,
                input_range=[-1, 1],
            )

            # extract control frames based on the last generated chunk
            ctrl_frames = self.extract_ctrl_frames(
                video=result_chunks[-1],
                input_range=[-1, 1],
                inference_params=inference_params,
            )

            # select the anchor frame for SVD
            svd_input_frame = result_chunks[0][int(inference_params.anchor_frames)]

            # generate the next chunk
            # result is [F, C, H, W], range is [-1,1] float.
            result = self._generate_conditional_output(
                svd_input_frame=svd_input_frame,
                inference_params=inference_params,
                anchor_frames=anchor_frames,
                ctrl_frames=ctrl_frames,
            )

            # from each generation, we keep all frames except for the first <num_conditional_frames> frames
            result = result[inference_params.num_conditional_frames:]
            result_chunks.append(result)
            torch.cuda.empty_cache()

        # concat all chunks into one long video
        result_chunks = [result_processor.convert_range(chunk, output_range=[0, 255], input_range=[-1, 1]) for chunk in result_chunks]
        result = result_processor.concat_chunks(result_chunks)
        torch.cuda.empty_cache()
        return result

    def ensure_image_ratio(self, source_image: PIL.Image.Image, target_aspect_ratio=16/9):

        if source_image.width / source_image.height == target_aspect_ratio:
            return source_image, None

        image = source_image.copy().convert("RGBA")
        mask = image.split()[-1]
        image = image.convert("RGB")
        padding = get_padding_for_aspect_ratio(image)

        # mask
        mask_padded = TF.pad(mask, padding)
        mask_padded_size = mask_padded.size
        mask_padded_resized = TF.resize(mask_padded, (512, 512),
                                        interpolation=TF.InterpolationMode.NEAREST)
        mask_padded_resized = TF.invert(mask_padded_resized)

        # image
        padded_input_image = TF.pad(image, padding, padding_mode="reflect")
        resized_image = TF.resize(padded_input_image, (512, 512))

        image_tensor = (self.inpaint_pipe.image_processor.preprocess(
            resized_image).cuda().half())
        latent_tensor = self.inpaint_pipe._encode_vae_image(image_tensor, None)
        self.inpaint_pipe.scheduler.set_timesteps(999)
        noisy_latent_tensor = self.inpaint_pipe.scheduler.add_noise(
            latent_tensor,
            torch.randn_like(latent_tensor),
            self.inpaint_pipe.scheduler.timesteps[:1],
        )

        prompt = self.blip(source_image)
        if prompt.startswith("there is "):
            prompt = prompt[len("there is "):]

        output_image_normalized_size = self.inpaint_pipe(
            prompt=prompt,
            image=resized_image,
            mask_image=mask_padded_resized,
            latents=noisy_latent_tensor,
        ).images[0]

        output_image_extended_size = TF.resize(
            output_image_normalized_size, mask_padded_size[::-1])

        blured_outpainting_mask = TF.invert(mask_padded).filter(
            ImageFilter.GaussianBlur(radius=5))

        final_image = Image.composite(
            output_image_extended_size, padded_input_image, blured_outpainting_mask)
        return final_image, TF.invert(mask_padded)

    def image_to_video(self, batch, inference_params: InferenceParams, batch_idx):
        """
        Performs image-to-video generation based on the input batch and inference parameters.
        It runs SVD-XT once to generate the first chunk, then auto-regressively applies StreamingSVD.

        Parameters:
        - batch: dict
            The input batch containing the start image for generating the video.
        - inference_params: InferenceParams
            An object containing inference parameters.
        - batch_idx: int
            The index of the batch.

        Returns:
        - torch.Tensor
            The video generated from the input image.
        """
        batch_key = "image"
        assert batch_key == "image", f"Generating video from {batch_key} not implemented."
        input_image = PIL.Image.fromarray(batch[batch_key][0].cpu().numpy())
        # TODO remove conversion forth and back

        outpainted_image, _ = self.ensure_image_ratio(input_image)

        scaled_outpainted_image, expanded_size = resize_and_crop(outpainted_image)
        assert scaled_outpainted_image.width == 1024 and scaled_outpainted_image.height == 576, f"Wrong shape for file {batch[batch_key]} with shape {scaled_outpainted_image.width}:{scaled_outpainted_image.height}."

        # Generating the first chunk
        with torch.autocast(device_type="cuda", enabled=False):
            video_chunks = self.svd_pipeline(
                scaled_outpainted_image, decode_chunk_size=8).frames[0]

        video_chunks = torch.stack([ToTensor()(frame) for frame in video_chunks])
        video_chunks = video_chunks * 2.0 - 1  # [-1,1], float

        video_chunks = video_chunks.to(self.device)

        video = self._autoregressive_generation(
            initial_generation=video_chunks,
            inference_params=inference_params)

        return video, scaled_outpainted_image, expanded_size

    def generate_output(self, batch, batch_idx, inference_params: InferenceParams):
        """
        Generate the output video based on the input batch and inference parameters.

        Parameters:
        - batch: dict
            The input batch containing data for generating the output video.
        - batch_idx: int
            The index of the batch.
        - inference_params: InferenceParams
            An object containing inference parameters.

        Returns:
        - torch.Tensor
            The generated video. Note the result is also accessible via self.trainer.generated_video.
        """

        sample_id = batch["sample_id"].item()
        video, scaled_outpainted_image, expanded_size = self.image_to_video(
            batch, inference_params=inference_params, batch_idx=sample_id)

        self.trainer.generated_video = video.numpy()
        self.trainer.expanded_size = expanded_size
        self.trainer.scaled_outpainted_image = scaled_outpainted_image
        return video
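The numbers in config.yaml make the length of one StreamingSVD pass easy to check: each SVD chunk holds 25 frames (guider_config num_frames), every autoregressive step re-uses 7 of them as conditioning (num_conditional_frames) and keeps the rest, and n_autoregressive_generations controls how many extra chunks are produced before interpolation and enhancement. A quick back-of-the-envelope check, assuming those default values:

# Frame budget of the loop in _autoregressive_generation, using the defaults from
# config.yaml (assumed here; they can be changed at run time).
chunk_len = 25                  # frames produced per SVD / StreamingSVD pass
num_conditional_frames = 7      # frames re-used as conditioning and then dropped
n_autoregressive_generations = 2

total = chunk_len + n_autoregressive_generations * (chunk_len - num_conditional_frames)
print(total)  # 25 + 2 * 18 = 61 frames before VFI interpolation and enhancement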
gradio_demo.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
from utils.gradio_utils import *
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
GRADIO_CACHE = ""
|
7 |
+
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument('--public_access', action='store_true')
|
10 |
+
args = parser.parse_args()
|
11 |
+
|
12 |
+
streaming_svd = StreamingSVD(load_argv=False)
|
13 |
+
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
|
14 |
+
|
15 |
+
examples = [
|
16 |
+
["Experience the dance of jellyfish: float through mesmerizing swarms of jellyfish, pulsating with otherworldly grace and beauty.",
|
17 |
+
"200 - frames (recommended)", 33, None, None],
|
18 |
+
["Dive into the depths of the ocean: explore vibrant coral reefs, mysterious underwater caves, and the mesmerizing creatures that call the sea home.",
|
19 |
+
"200 - frames (recommended)", 33, None, None],
|
20 |
+
["A cute cat.",
|
21 |
+
"200 - frames (recommended)", 33, None, None],
|
22 |
+
["",
|
23 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test1.jpg", None],
|
24 |
+
["",
|
25 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test2.jpg", None],
|
26 |
+
["",
|
27 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test3.png", None],
|
28 |
+
["",
|
29 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test4.png", None],
|
30 |
+
["",
|
31 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test5.jpg", None],
|
32 |
+
["",
|
33 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test6.png", None],
|
34 |
+
["",
|
35 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test7.jpg", None],
|
36 |
+
["",
|
37 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test8.jpg", None],
|
38 |
+
["",
|
39 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test9.jpg", None],
|
40 |
+
["",
|
41 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test10.jpg", None],
|
42 |
+
["",
|
43 |
+
"200 - frames (recommended)", 33, "__assets__/gradio_cached_examples/test11.jpg", None],
|
44 |
+
]
|
45 |
+
|
46 |
+
def generate(prompt, num_frames, seed, image: np.ndarray):
|
47 |
+
if num_frames == [] or num_frames is None:
|
48 |
+
num_frames = 50
|
49 |
+
else:
|
50 |
+
num_frames = int(num_frames.split(" ")[0])
|
51 |
+
if num_frames > 200: # and on_huggingspace:
|
52 |
+
num_frames = 200
|
53 |
+
|
54 |
+
if image is None:
|
55 |
+
image = text_to_image_gradio(
|
56 |
+
prompt=prompt, streaming_svd=streaming_svd, seed=seed)
|
57 |
+
|
58 |
+
video_file_stage_one = image_to_video_vfi_gradio(
|
59 |
+
img=image, num_frames=num_frames, streaming_svd=streaming_svd, seed=seed, gradio_cache=GRADIO_CACHE)
|
60 |
+
|
61 |
+
expanded_size, orig_size, scaled_outpainted_image = retrieve_intermediate_data(video_file_stage_one)
|
62 |
+
|
63 |
+
video_file_stage_two = enhance_video_vfi_gradio(
|
64 |
+
img=scaled_outpainted_image, video=video_file_stage_one.replace("__cropped__", "__expanded__"), num_frames=24, streaming_svd=streaming_svd, seed=seed, expanded_size=expanded_size, orig_size=orig_size, gradio_cache=GRADIO_CACHE)
|
65 |
+
|
66 |
+
return image, video_file_stage_one, video_file_stage_two
|
67 |
+
|
68 |
+
|
69 |
+
def enhance(prompt, num_frames, seed, image: np.ndarray, video:str):
|
70 |
+
if num_frames == [] or num_frames is None:
|
71 |
+
num_frames = 50
|
72 |
+
else:
|
73 |
+
num_frames = int(num_frames.split(" ")[0])
|
74 |
+
if num_frames > 200: # and on_huggingspace:
|
75 |
+
num_frames = 200
|
76 |
+
|
77 |
+
# User directly applied Long Video Generation (without preview) with Flux.
|
78 |
+
if image is None:
|
79 |
+
image = text_to_image_gradio(
|
80 |
+
prompt=prompt, streaming_svd=streaming_svd, seed=seed)
|
81 |
+
|
82 |
+
# User directly applied Long Video Generation (without preview) with or without Flux.
|
83 |
+
if video is None:
|
84 |
+
video = image_to_video_gradio(
|
85 |
+
img=image, num_frames=(num_frames+1) // 2, streaming_svd=streaming_svd, seed=seed, gradio_cache=GRADIO_CACHE)
|
86 |
+
expanded_size, orig_size, scaled_outpainted_image = retrieve_intermediate_data(video)
|
87 |
+
|
88 |
+
# Here the video is path and image is numpy array
|
89 |
+
video_file_stage_two = enhance_video_vfi_gradio(
|
90 |
+
img=scaled_outpainted_image, video=video.replace("__cropped__", "__expanded__"), num_frames=num_frames, streaming_svd=streaming_svd, seed=seed, expanded_size=expanded_size, orig_size=orig_size, gradio_cache=GRADIO_CACHE)
|
91 |
+
|
92 |
+
return image, video_file_stage_two
|
93 |
+
|
94 |
+
|
95 |
+
with gr.Blocks() as demo:
|
96 |
+
GRADIO_CACHE = demo.GRADIO_CACHE
|
97 |
+
gr.HTML("""
|
98 |
+
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
|
99 |
+
<h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
|
100 |
+
<a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">StreamingSVD</a>
|
101 |
+
</h1>
|
102 |
+
<h2 style="font-weight: 650; font-size: 2rem; margin: 0rem">
|
103 |
+
A StreamingT2V method for high-quality long video generation
|
104 |
+
</h2>
|
105 |
+
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
|
106 |
+
Roberto Henschel<sup>1*</sup>, Levon Khachatryan<sup>1*</sup>, Daniil Hayrapetyan<sup>1*</sup>, Hayk Poghosyan<sup>1</sup>, Vahram Tadevosyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>, <a href="https://www.humphreyshi.com/" style="color:blue;">Humphrey Shi</a><sup>1,3</sup>
|
107 |
+
</h2>
|
108 |
+
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
|
109 |
+
<sup>1</sup>Picsart AI Resarch (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>SHI Labs @ Georgia Tech, Oregon & UIUC
|
110 |
+
</h2>
|
111 |
+
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
|
112 |
+
*Equal Contribution
|
113 |
+
</h2>
|
114 |
+
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
|
115 |
+
[<a href="https://arxiv.org/abs/2403.14773" style="color:blue;">arXiv</a>]
|
116 |
+
[<a href="https://github.com/Picsart-AI-Research/StreamingT2V" style="color:blue;">GitHub</a>]
|
117 |
+
</h2>
|
118 |
+
<h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
|
119 |
+
<b>StreamingSVD</b> is an advanced autoregressive technique for text-to-video and image-to-video generation,
|
120 |
+
generating long hiqh-quality videos with rich motion dynamics, turning SVD into a long video generator.
|
121 |
+
Our method ensures temporal consistency throughout the video, aligns closely to the input text/image,
|
122 |
+
and maintains high frame-level image quality. Our demonstrations include successful examples of videos
|
123 |
+
up to 200 frames, spanning 8 seconds, and can be extended for even longer durations.
|
124 |
+
</h2>
|
125 |
+
</div>
|
126 |
+
""")
|
127 |
+
|
128 |
+
if on_huggingspace:
|
129 |
+
gr.HTML("""
|
130 |
+
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
|
131 |
+
<br/>
|
132 |
+
<a href="https://huggingface.co/spaces/PAIR/StreamingT2V?duplicate=true">
|
133 |
+
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
134 |
+
</p>""")
|
135 |
+
|
136 |
+
with gr.Row():
|
137 |
+
with gr.Column(scale=1):
|
138 |
+
with gr.Row():
|
139 |
+
with gr.Column():
|
140 |
+
with gr.Row():
|
141 |
+
num_frames = gr.Dropdown(["50 - frames (recommended)", "80 - frames (recommended)", "140 - frames (recommended)", "200 - frames (recommended)", "500 - frames", "1000 - frames", "10000 - frames"],
|
142 |
+
label="Number of Video Frames", info="For >200 frames use local workstation!", value="50 - frames (recommended)")
|
143 |
+
with gr.Row():
|
144 |
+
prompt_stage1 = gr.Textbox(label='Text-to-Video (Enter text prompt here)',
|
145 |
+
interactive=True, max_lines=1)
|
146 |
+
with gr.Row():
|
147 |
+
image_stage1 = gr.Image(label='Image-to-Video (Upload Image here, text prompt will be ignored for I2V if entered)',
|
148 |
+
show_label=True, show_download_button=True, interactive=True, height=250)
|
149 |
+
with gr.Column():
|
150 |
+
video_stage1 = gr.Video(label='Long Video Preview', show_label=True,
|
151 |
+
interactive=False, show_download_button=True, height=203)
|
152 |
+
with gr.Row():
|
153 |
+
run_button_stage1 = gr.Button("Long Video Generation (faster preview)")
|
154 |
+
with gr.Row():
|
155 |
+
with gr.Column():
|
156 |
+
with gr.Accordion('Advanced options', open=False):
|
157 |
+
seed = gr.Slider(label='Seed', minimum=0,
|
158 |
+
maximum=65536, value=33, step=1,)
|
159 |
+
|
160 |
+
with gr.Column(scale=3):
|
161 |
+
with gr.Row():
|
162 |
+
video_stage2 = gr.Video(label='High-Quality Long Video (Preview or Full)', show_label=True,
|
163 |
+
interactive=False, show_download_button=True, height=700)
|
164 |
+
with gr.Row():
|
165 |
+
run_button_stage2 = gr.Button("Long Video Generation (full high-quality)")
|
166 |
+
|
167 |
+
inputs_t2v = [prompt_stage1, num_frames,
|
168 |
+
seed, image_stage1]
|
169 |
+
inputs_v2v = [prompt_stage1, num_frames, seed,
|
170 |
+
image_stage1, video_stage1]
|
171 |
+
|
172 |
+
run_button_stage1.click(fn=generate, inputs=inputs_t2v,
|
173 |
+
outputs=[image_stage1, video_stage1, video_stage2])
|
174 |
+
run_button_stage2.click(fn=enhance, inputs=inputs_v2v,
|
175 |
+
outputs=[image_stage1, video_stage2])
|
176 |
+
|
177 |
+
|
178 |
+
gr.Examples(examples=examples,
|
179 |
+
inputs=inputs_v2v,
|
180 |
+
outputs=[image_stage1, video_stage2],
|
181 |
+
fn=enhance,
|
182 |
+
cache_examples=True,
|
183 |
+
run_on_click=False,
|
184 |
+
)
|
185 |
+
|
186 |
+
|
187 |
+
'''
|
188 |
+
'''
|
189 |
+
gr.HTML("""
|
190 |
+
<div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
|
191 |
+
<h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
|
192 |
+
<b>Version: v1.0</b>
|
193 |
+
</h3>
|
194 |
+
<h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
|
195 |
+
<b>Caution</b>:
|
196 |
+
We would like to raise users' awareness of the potential issues and concerns of this demo.
|
197 |
+
Like previous large foundation models, StreamingSVD could be problematic in some cases, partly because we use the pretrained ModelScope model, so StreamingSVD can inherit its imperfections.
|
198 |
+
So far, we keep all features available for research testing both to show the great potential of the StreamingSVD framework and to collect important feedback to improve the model in the future.
|
199 |
+
We welcome researchers and users to report issues via the HuggingFace community discussion feature or by emailing the authors.
|
200 |
+
</h3>
|
201 |
+
<h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
|
202 |
+
<b>Biases and content acknowledgement</b>:
|
203 |
+
Beware that StreamingSVD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
|
204 |
+
StreamingSVD in this demo is meant only for research purposes.
|
205 |
+
</h3>
|
206 |
+
</div>
|
207 |
+
""")
|
208 |
+
|
209 |
+
|
210 |
+
if on_huggingspace:
|
211 |
+
demo.queue(max_size=20)
|
212 |
+
demo.launch(debug=True)
|
213 |
+
else:
|
214 |
+
demo.queue(api_open=False).launch(share=args.public_access)
|
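For reference, the two buttons above wire a two-stage flow: `generate` produces the long low-resolution preview, and `enhance` runs the chunk-wise high-quality pass on it. A minimal headless sketch of the same flow, assuming `generate`/`enhance` from earlier in this file are in scope and keeping the positional argument order of `inputs_t2v`/`inputs_v2v` (the prompt string below is a made-up example value):

# Headless sketch of the demo's two-stage flow (illustrative only).
prompt = "A surfer riding a wave at sunset"   # example prompt, not from the repo
num_frames = "80 - frames (recommended)"      # same dropdown string the UI passes through
seed = 33
image = None                                  # or an uploaded image for image-to-video

# Stage 1: fast low-resolution preview of the long video.
image_out, video_preview, _ = generate(prompt, num_frames, seed, image)
# Stage 2: chunk-wise high-quality enhancement of the preview.
image_out, video_hq = enhance(prompt, num_frames, seed, image_out, video_preview)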
i2v_enhance/i2v_enhance_interface.py
ADDED
@@ -0,0 +1,128 @@
1 |
+
import torch
|
2 |
+
from i2v_enhance.pipeline_i2vgen_xl import I2VGenXLPipeline
|
3 |
+
from tqdm import tqdm
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
from einops import rearrange
|
7 |
+
import i2v_enhance.thirdparty.VFI.config as cfg
|
8 |
+
from i2v_enhance.thirdparty.VFI.Trainer import Model as VFI
|
9 |
+
from pathlib import Path
|
10 |
+
from modules.params.vfi import VFIParams
|
11 |
+
from modules.params.i2v_enhance import I2VEnhanceParams
|
12 |
+
from utils.loader import download_ckpt
|
13 |
+
|
14 |
+
|
15 |
+
def vfi_init(ckpt_cfg: VFIParams, device_id=0):
|
16 |
+
cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(F=32, depth=[
|
17 |
+
2, 2, 2, 4, 4])
|
18 |
+
vfi = VFI(-1)
|
19 |
+
|
20 |
+
ckpt_file = Path(download_ckpt(
|
21 |
+
local_path=ckpt_cfg.ckpt_path_local, global_path=ckpt_cfg.ckpt_path_global))
|
22 |
+
|
23 |
+
vfi.load_model(ckpt_file.as_posix())
|
24 |
+
vfi.eval()
|
25 |
+
vfi.device()
|
26 |
+
assert device_id == 0, "VFI on rank!=0 not implemented yet."
|
27 |
+
return vfi
|
28 |
+
|
29 |
+
|
30 |
+
def vfi_process(video, vfi, video_len):
|
31 |
+
video = video[:(video_len//2+1)]
|
32 |
+
|
33 |
+
video = [i[:, :, :3]/255. for i in video]
|
34 |
+
video = [i[:, :, ::-1] for i in video]
|
35 |
+
video = np.stack(video, axis=0)
|
36 |
+
video = rearrange(torch.from_numpy(video),
|
37 |
+
'b h w c -> b c h w').to("cuda", torch.float32)
|
38 |
+
|
39 |
+
frames = []
|
40 |
+
for i in tqdm(range(video.shape[0]-1), desc="VFI"):
|
41 |
+
I0_ = video[i:i+1, ...]
|
42 |
+
I2_ = video[i+1:i+2, ...]
|
43 |
+
frames.append((I0_[0].detach().cpu().numpy().transpose(
|
44 |
+
1, 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1])
|
45 |
+
|
46 |
+
mid = (vfi.inference(I0_, I2_, TTA=True, fast_TTA=True)[
|
47 |
+
0].detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)
|
48 |
+
frames.append(mid[:, :, ::-1])
|
49 |
+
|
50 |
+
frames.append((video[-1].detach().cpu().numpy().transpose(1,
|
51 |
+
2, 0) * 255.0).astype(np.uint8)[:, :, ::-1])
|
52 |
+
if video_len % 2 == 0:
|
53 |
+
frames.append((video[-1].detach().cpu().numpy().transpose(1,
|
54 |
+
2, 0) * 255.0).astype(np.uint8)[:, :, ::-1])
|
55 |
+
|
56 |
+
del vfi
|
57 |
+
del video
|
58 |
+
torch.cuda.empty_cache()
|
59 |
+
|
60 |
+
video = [Image.fromarray(frame).resize((1280, 720)) for frame in frames]
|
61 |
+
del frames
|
62 |
+
return video
|
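A note on the bookkeeping above: `vfi_process` keeps only the first `video_len//2 + 1` input frames, inserts one interpolated frame between every adjacent pair, appends the last frame once more when `video_len` is even, and resizes everything to 1280x720. A small helper that mirrors that frame count (illustrative only, not part of the repo):

def vfi_output_length(video_len: int) -> int:
    # Mirrors the frame bookkeeping of vfi_process above (illustrative helper).
    kept = video_len // 2 + 1      # low-fps frames taken from the start of the input
    out = 2 * (kept - 1) + 1       # one original + one interpolated frame per pair, plus the final frame
    if video_len % 2 == 0:
        out += 1                   # the final frame is appended twice for even video_len
    return out

# e.g. vfi_output_length(7) == 7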
63 |
+
|
64 |
+
|
65 |
+
def i2v_enhance_init(i2vgen_cfg: I2VEnhanceParams):
|
66 |
+
generator = torch.manual_seed(8888)
|
67 |
+
try:
|
68 |
+
pipeline = I2VGenXLPipeline.from_pretrained(
|
69 |
+
i2vgen_cfg.ckpt_path_local, torch_dtype=torch.float16, variant="fp16")
|
70 |
+
except Exception as e:
|
71 |
+
pipeline = I2VGenXLPipeline.from_pretrained(
|
72 |
+
i2vgen_cfg.ckpt_path_global, torch_dtype=torch.float16, variant="fp16")
|
73 |
+
pipeline.save_pretrained(i2vgen_cfg.ckpt_path_local)
|
74 |
+
pipeline.enable_model_cpu_offload()
|
75 |
+
return pipeline, generator
|
76 |
+
|
77 |
+
|
78 |
+
def i2v_enhance_process(image, video, pipeline, generator, overlap_size, strength, chunk_size=38, use_randomized_blending=False):
|
79 |
+
prompt = "High Quality, HQ, detailed."
|
80 |
+
negative_prompt = "Distorted, blurry, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
|
81 |
+
|
82 |
+
if use_randomized_blending:
|
83 |
+
# We first need to enhance key-frames (the 1st frame of each chunk)
|
84 |
+
video_chunks = [video[i:i+chunk_size] for i in range(0, len(
|
85 |
+
video), chunk_size-overlap_size) if len(video[i:i+chunk_size]) == chunk_size]
|
86 |
+
video_short = [chunk[0] for chunk in video_chunks]
|
87 |
+
|
88 |
+
# If randomized blending is used, we must have a list of starting images (one for each chunk)
|
89 |
+
image = pipeline(
|
90 |
+
prompt=prompt,
|
91 |
+
height=720,
|
92 |
+
width=1280,
|
93 |
+
image=image,
|
94 |
+
video=video_short,
|
95 |
+
strength=strength,
|
96 |
+
overlap_size=0,
|
97 |
+
chunk_size=len(video_short),
|
98 |
+
num_frames=len(video_short),
|
99 |
+
num_inference_steps=30,
|
100 |
+
decode_chunk_size=1,
|
101 |
+
negative_prompt=negative_prompt,
|
102 |
+
guidance_scale=9.0,
|
103 |
+
generator=generator,
|
104 |
+
).frames[0]
|
105 |
+
|
106 |
+
# Remove the last few frames (< chunk_size) of the video that do not fit into one chunk.
|
107 |
+
max_idx = (chunk_size - overlap_size) * \
|
108 |
+
(len(video_chunks) - 1) + chunk_size
|
109 |
+
video = video[:max_idx]
|
110 |
+
|
111 |
+
frames = pipeline(
|
112 |
+
prompt=prompt,
|
113 |
+
height=720,
|
114 |
+
width=1280,
|
115 |
+
image=image,
|
116 |
+
video=video,
|
117 |
+
strength=strength,
|
118 |
+
overlap_size=overlap_size,
|
119 |
+
chunk_size=chunk_size,
|
120 |
+
num_frames=chunk_size,
|
121 |
+
num_inference_steps=30,
|
122 |
+
decode_chunk_size=1,
|
123 |
+
negative_prompt=negative_prompt,
|
124 |
+
guidance_scale=9.0,
|
125 |
+
generator=generator,
|
126 |
+
).frames[0]
|
127 |
+
|
128 |
+
return frames
|
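To make the chunking above concrete: with randomized blending the chunk start indices advance by `chunk_size - overlap_size`, only complete chunks are kept, and `max_idx` trims the tail that does not fill a chunk. A short sketch of that arithmetic (`overlap_size=10` is an assumed example value; `chunk_size=38` is the function default):

def blending_chunks(n_frames: int, chunk_size: int = 38, overlap_size: int = 10):
    # Reproduces the chunk bookkeeping of i2v_enhance_process (illustrative helper).
    step = chunk_size - overlap_size
    starts = [i for i in range(0, n_frames, step) if i + chunk_size <= n_frames]
    max_idx = step * (len(starts) - 1) + chunk_size   # number of frames actually enhanced
    return starts, max_idx

# Example: 100 input frames -> chunk starts [0, 28, 56], max_idx = 94 (frames 94..99 are dropped).
starts, max_idx = blending_chunks(100)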
i2v_enhance/pipeline_i2vgen_xl.py
ADDED
@@ -0,0 +1,988 @@
1 |
+
# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import PIL
|
21 |
+
import torch
|
22 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
23 |
+
|
24 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
25 |
+
from diffusers.models import AutoencoderKL
|
26 |
+
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
|
27 |
+
from diffusers.schedulers import DDIMScheduler
|
28 |
+
from diffusers.utils import (
|
29 |
+
BaseOutput,
|
30 |
+
logging,
|
31 |
+
replace_example_docstring,
|
32 |
+
)
|
33 |
+
from diffusers.utils.torch_utils import randn_tensor
|
34 |
+
from diffusers.video_processor import VideoProcessor
|
35 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
36 |
+
import random
|
37 |
+
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
40 |
+
|
41 |
+
EXAMPLE_DOC_STRING = """
|
42 |
+
Examples:
|
43 |
+
```py
|
44 |
+
>>> import torch
|
45 |
+
>>> from diffusers import I2VGenXLPipeline
|
46 |
+
>>> from diffusers.utils import export_to_gif, load_image
|
47 |
+
|
48 |
+
>>> pipeline = I2VGenXLPipeline.from_pretrained(
|
49 |
+
... "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16"
|
50 |
+
... )
|
51 |
+
>>> pipeline.enable_model_cpu_offload()
|
52 |
+
|
53 |
+
>>> image_url = (
|
54 |
+
... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
|
55 |
+
... )
|
56 |
+
>>> image = load_image(image_url).convert("RGB")
|
57 |
+
|
58 |
+
>>> prompt = "Papers were floating in the air on a table in the library"
|
59 |
+
>>> negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
|
60 |
+
>>> generator = torch.manual_seed(8888)
|
61 |
+
|
62 |
+
>>> frames = pipeline(
|
63 |
+
... prompt=prompt,
|
64 |
+
... image=image,
|
65 |
+
... num_inference_steps=50,
|
66 |
+
... negative_prompt=negative_prompt,
|
67 |
+
... guidance_scale=9.0,
|
68 |
+
... generator=generator,
|
69 |
+
... ).frames[0]
|
70 |
+
>>> video_path = export_to_gif(frames, "i2v.gif")
|
71 |
+
```
|
72 |
+
"""
|
73 |
+
|
74 |
+
|
75 |
+
@dataclass
|
76 |
+
class I2VGenXLPipelineOutput(BaseOutput):
|
77 |
+
r"""
|
78 |
+
Output class for image-to-video pipeline.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
82 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
83 |
+
denoised
|
84 |
+
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
85 |
+
`(batch_size, num_frames, channels, height, width)`
|
86 |
+
"""
|
87 |
+
|
88 |
+
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
|
89 |
+
|
90 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
91 |
+
|
92 |
+
|
93 |
+
def retrieve_latents(
|
94 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
95 |
+
):
|
96 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
97 |
+
return encoder_output.latent_dist.sample(generator)
|
98 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
99 |
+
return encoder_output.latent_dist.mode()
|
100 |
+
elif hasattr(encoder_output, "latents"):
|
101 |
+
return encoder_output.latents
|
102 |
+
else:
|
103 |
+
raise AttributeError(
|
104 |
+
"Could not access latents of provided encoder_output")
|
105 |
+
|
106 |
+
|
107 |
+
class I2VGenXLPipeline(
|
108 |
+
DiffusionPipeline,
|
109 |
+
StableDiffusionMixin,
|
110 |
+
):
|
111 |
+
r"""
|
112 |
+
Pipeline for image-to-video generation as proposed in [I2VGenXL](https://i2vgen-xl.github.io/).
|
113 |
+
|
114 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
115 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
116 |
+
|
117 |
+
Args:
|
118 |
+
vae ([`AutoencoderKL`]):
|
119 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
120 |
+
text_encoder ([`CLIPTextModel`]):
|
121 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
122 |
+
tokenizer (`CLIPTokenizer`):
|
123 |
+
A [`~transformers.CLIPTokenizer`] to tokenize text.
|
124 |
+
unet ([`I2VGenXLUNet`]):
|
125 |
+
A [`I2VGenXLUNet`] to denoise the encoded video latents.
|
126 |
+
scheduler ([`DDIMScheduler`]):
|
127 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
128 |
+
"""
|
129 |
+
|
130 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
131 |
+
|
132 |
+
def __init__(
|
133 |
+
self,
|
134 |
+
vae: AutoencoderKL,
|
135 |
+
text_encoder: CLIPTextModel,
|
136 |
+
tokenizer: CLIPTokenizer,
|
137 |
+
image_encoder: CLIPVisionModelWithProjection,
|
138 |
+
feature_extractor: CLIPImageProcessor,
|
139 |
+
unet: I2VGenXLUNet,
|
140 |
+
scheduler: DDIMScheduler,
|
141 |
+
):
|
142 |
+
super().__init__()
|
143 |
+
|
144 |
+
self.register_modules(
|
145 |
+
vae=vae,
|
146 |
+
text_encoder=text_encoder,
|
147 |
+
tokenizer=tokenizer,
|
148 |
+
image_encoder=image_encoder,
|
149 |
+
feature_extractor=feature_extractor,
|
150 |
+
unet=unet,
|
151 |
+
scheduler=scheduler,
|
152 |
+
)
|
153 |
+
self.vae_scale_factor = 2 ** (
|
154 |
+
len(self.vae.config.block_out_channels) - 1)
|
155 |
+
# `do_resize=False` as we do custom resizing.
|
156 |
+
self.video_processor = VideoProcessor(
|
157 |
+
vae_scale_factor=self.vae_scale_factor, do_resize=False)
|
158 |
+
|
159 |
+
@property
|
160 |
+
def guidance_scale(self):
|
161 |
+
return self._guidance_scale
|
162 |
+
|
163 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
164 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
165 |
+
# corresponds to doing no classifier free guidance.
|
166 |
+
@property
|
167 |
+
def do_classifier_free_guidance(self):
|
168 |
+
return self._guidance_scale > 1
|
169 |
+
|
170 |
+
def encode_prompt(
|
171 |
+
self,
|
172 |
+
prompt,
|
173 |
+
device,
|
174 |
+
num_videos_per_prompt,
|
175 |
+
negative_prompt=None,
|
176 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
177 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
178 |
+
clip_skip: Optional[int] = None,
|
179 |
+
):
|
180 |
+
r"""
|
181 |
+
Encodes the prompt into text encoder hidden states.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
prompt (`str` or `List[str]`, *optional*):
|
185 |
+
prompt to be encoded
|
186 |
+
device: (`torch.device`):
|
187 |
+
torch device
|
188 |
+
num_videos_per_prompt (`int`):
|
189 |
+
number of images that should be generated per prompt
|
190 |
+
do_classifier_free_guidance (`bool`):
|
191 |
+
whether to use classifier free guidance or not
|
192 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
193 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
194 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
195 |
+
less than `1`).
|
196 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
197 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
198 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
199 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
200 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
201 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
202 |
+
argument.
|
203 |
+
clip_skip (`int`, *optional*):
|
204 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
205 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
206 |
+
"""
|
207 |
+
if prompt is not None and isinstance(prompt, str):
|
208 |
+
batch_size = 1
|
209 |
+
elif prompt is not None and isinstance(prompt, list):
|
210 |
+
batch_size = len(prompt)
|
211 |
+
else:
|
212 |
+
batch_size = prompt_embeds.shape[0]
|
213 |
+
|
214 |
+
if prompt_embeds is None:
|
215 |
+
text_inputs = self.tokenizer(
|
216 |
+
prompt,
|
217 |
+
padding="max_length",
|
218 |
+
max_length=self.tokenizer.model_max_length,
|
219 |
+
truncation=True,
|
220 |
+
return_tensors="pt",
|
221 |
+
)
|
222 |
+
text_input_ids = text_inputs.input_ids
|
223 |
+
untruncated_ids = self.tokenizer(
|
224 |
+
prompt, padding="longest", return_tensors="pt").input_ids
|
225 |
+
|
226 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
227 |
+
text_input_ids, untruncated_ids
|
228 |
+
):
|
229 |
+
removed_text = self.tokenizer.batch_decode(
|
230 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
|
231 |
+
)
|
232 |
+
logger.warning(
|
233 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
234 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
235 |
+
)
|
236 |
+
|
237 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
238 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
239 |
+
else:
|
240 |
+
attention_mask = None
|
241 |
+
|
242 |
+
if clip_skip is None:
|
243 |
+
prompt_embeds = self.text_encoder(
|
244 |
+
text_input_ids.to(device), attention_mask=attention_mask)
|
245 |
+
prompt_embeds = prompt_embeds[0]
|
246 |
+
else:
|
247 |
+
prompt_embeds = self.text_encoder(
|
248 |
+
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
249 |
+
)
|
250 |
+
# Access the `hidden_states` first, that contains a tuple of
|
251 |
+
# all the hidden states from the encoder layers. Then index into
|
252 |
+
# the tuple to access the hidden states from the desired layer.
|
253 |
+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
254 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
255 |
+
# representations. The `last_hidden_states` that we typically use for
|
256 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
257 |
+
# layer.
|
258 |
+
prompt_embeds = self.text_encoder.text_model.final_layer_norm(
|
259 |
+
prompt_embeds)
|
260 |
+
|
261 |
+
if self.text_encoder is not None:
|
262 |
+
prompt_embeds_dtype = self.text_encoder.dtype
|
263 |
+
elif self.unet is not None:
|
264 |
+
prompt_embeds_dtype = self.unet.dtype
|
265 |
+
else:
|
266 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
267 |
+
|
268 |
+
prompt_embeds = prompt_embeds.to(
|
269 |
+
dtype=prompt_embeds_dtype, device=device)
|
270 |
+
|
271 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
272 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
273 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
274 |
+
prompt_embeds = prompt_embeds.view(
|
275 |
+
bs_embed * num_videos_per_prompt, seq_len, -1)
|
276 |
+
|
277 |
+
# get unconditional embeddings for classifier free guidance
|
278 |
+
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
279 |
+
uncond_tokens: List[str]
|
280 |
+
if negative_prompt is None:
|
281 |
+
uncond_tokens = [""] * batch_size
|
282 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
283 |
+
raise TypeError(
|
284 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
285 |
+
f" {type(prompt)}."
|
286 |
+
)
|
287 |
+
elif isinstance(negative_prompt, str):
|
288 |
+
uncond_tokens = [negative_prompt]
|
289 |
+
elif batch_size != len(negative_prompt):
|
290 |
+
raise ValueError(
|
291 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
292 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
293 |
+
" the batch size of `prompt`."
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
uncond_tokens = negative_prompt
|
297 |
+
|
298 |
+
max_length = prompt_embeds.shape[1]
|
299 |
+
uncond_input = self.tokenizer(
|
300 |
+
uncond_tokens,
|
301 |
+
padding="max_length",
|
302 |
+
max_length=max_length,
|
303 |
+
truncation=True,
|
304 |
+
return_tensors="pt",
|
305 |
+
)
|
306 |
+
|
307 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
308 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
309 |
+
else:
|
310 |
+
attention_mask = None
|
311 |
+
|
312 |
+
# Apply clip_skip to negative prompt embeds
|
313 |
+
if clip_skip is None:
|
314 |
+
negative_prompt_embeds = self.text_encoder(
|
315 |
+
uncond_input.input_ids.to(device),
|
316 |
+
attention_mask=attention_mask,
|
317 |
+
)
|
318 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
319 |
+
else:
|
320 |
+
negative_prompt_embeds = self.text_encoder(
|
321 |
+
uncond_input.input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
322 |
+
)
|
323 |
+
# Access the `hidden_states` first, that contains a tuple of
|
324 |
+
# all the hidden states from the encoder layers. Then index into
|
325 |
+
# the tuple to access the hidden states from the desired layer.
|
326 |
+
negative_prompt_embeds = negative_prompt_embeds[-1][-(
|
327 |
+
clip_skip + 1)]
|
328 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
329 |
+
# representations. The `last_hidden_states` that we typically use for
|
330 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
331 |
+
# layer.
|
332 |
+
negative_prompt_embeds = self.text_encoder.text_model.final_layer_norm(
|
333 |
+
negative_prompt_embeds)
|
334 |
+
|
335 |
+
if self.do_classifier_free_guidance:
|
336 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
337 |
+
seq_len = negative_prompt_embeds.shape[1]
|
338 |
+
|
339 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
340 |
+
dtype=prompt_embeds_dtype, device=device)
|
341 |
+
|
342 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
343 |
+
1, num_videos_per_prompt, 1)
|
344 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
345 |
+
batch_size * num_videos_per_prompt, seq_len, -1)
|
346 |
+
|
347 |
+
return prompt_embeds, negative_prompt_embeds
|
348 |
+
|
349 |
+
def _encode_image(self, image, device, num_videos_per_prompt):
|
350 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
351 |
+
|
352 |
+
if not isinstance(image, torch.Tensor):
|
353 |
+
image = self.video_processor.pil_to_numpy(image)
|
354 |
+
image = self.video_processor.numpy_to_pt(image)
|
355 |
+
|
356 |
+
# Normalize the image with CLIP training stats.
|
357 |
+
image = self.feature_extractor(
|
358 |
+
images=image,
|
359 |
+
do_normalize=True,
|
360 |
+
do_center_crop=False,
|
361 |
+
do_resize=False,
|
362 |
+
do_rescale=False,
|
363 |
+
return_tensors="pt",
|
364 |
+
).pixel_values
|
365 |
+
|
366 |
+
image = image.to(device=device, dtype=dtype)
|
367 |
+
image_embeddings = self.image_encoder(image).image_embeds
|
368 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
369 |
+
|
370 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
371 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
372 |
+
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
|
373 |
+
image_embeddings = image_embeddings.view(
|
374 |
+
bs_embed * num_videos_per_prompt, seq_len, -1)
|
375 |
+
|
376 |
+
if self.do_classifier_free_guidance:
|
377 |
+
negative_image_embeddings = torch.zeros_like(image_embeddings)
|
378 |
+
image_embeddings = torch.cat(
|
379 |
+
[negative_image_embeddings, image_embeddings])
|
380 |
+
|
381 |
+
return image_embeddings
|
382 |
+
|
383 |
+
def decode_latents(self, latents, decode_chunk_size=None):
|
384 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
385 |
+
|
386 |
+
batch_size, channels, num_frames, height, width = latents.shape
|
387 |
+
latents = latents.permute(0, 2, 1, 3, 4).reshape(
|
388 |
+
batch_size * num_frames, channels, height, width)
|
389 |
+
|
390 |
+
if decode_chunk_size is not None:
|
391 |
+
frames = []
|
392 |
+
for i in range(0, latents.shape[0], decode_chunk_size):
|
393 |
+
frame = self.vae.decode(
|
394 |
+
latents[i: i + decode_chunk_size]).sample
|
395 |
+
frames.append(frame)
|
396 |
+
image = torch.cat(frames, dim=0)
|
397 |
+
else:
|
398 |
+
image = self.vae.decode(latents).sample
|
399 |
+
|
400 |
+
decode_shape = (batch_size, num_frames, -1) + image.shape[2:]
|
401 |
+
video = image[None, :].reshape(decode_shape).permute(0, 2, 1, 3, 4)
|
402 |
+
|
403 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
404 |
+
video = video.float()
|
405 |
+
return video
|
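`decode_latents` flattens the batch and frame dimensions and, when `decode_chunk_size` is set, pushes only that many frames through the VAE decoder at a time (the enhancement code calls the pipeline with `decode_chunk_size=1`), trading speed for a much smaller memory peak. A minimal sketch of the same slicing pattern with a dummy decoder (illustrative only):

import torch
import torch.nn.functional as F

def decode_in_chunks(decode_fn, latents: torch.Tensor, decode_chunk_size: int) -> torch.Tensor:
    # Decode a (num_frames, C, h, w) latent batch a few frames at a time (illustrative helper).
    frames = [decode_fn(latents[i:i + decode_chunk_size])
              for i in range(0, latents.shape[0], decode_chunk_size)]
    return torch.cat(frames, dim=0)

# Dummy "decoder" that just upsamples by the VAE scale factor of 8 used here.
fake_decode = lambda z: F.interpolate(z, scale_factor=8)
video = decode_in_chunks(fake_decode, torch.randn(38, 4, 90, 160), decode_chunk_size=1)  # -> (38, 4, 720, 1280)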
406 |
+
|
407 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
408 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
409 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
410 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
411 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
412 |
+
# and should be between [0, 1]
|
413 |
+
|
414 |
+
accepts_eta = "eta" in set(inspect.signature(
|
415 |
+
self.scheduler.step).parameters.keys())
|
416 |
+
extra_step_kwargs = {}
|
417 |
+
if accepts_eta:
|
418 |
+
extra_step_kwargs["eta"] = eta
|
419 |
+
|
420 |
+
# check if the scheduler accepts generator
|
421 |
+
accepts_generator = "generator" in set(
|
422 |
+
inspect.signature(self.scheduler.step).parameters.keys())
|
423 |
+
if accepts_generator:
|
424 |
+
extra_step_kwargs["generator"] = generator
|
425 |
+
return extra_step_kwargs
|
426 |
+
|
427 |
+
def check_inputs(
|
428 |
+
self,
|
429 |
+
prompt,
|
430 |
+
image,
|
431 |
+
height,
|
432 |
+
width,
|
433 |
+
negative_prompt=None,
|
434 |
+
prompt_embeds=None,
|
435 |
+
negative_prompt_embeds=None,
|
436 |
+
):
|
437 |
+
if height % 8 != 0 or width % 8 != 0:
|
438 |
+
raise ValueError(
|
439 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
440 |
+
|
441 |
+
if prompt is not None and prompt_embeds is not None:
|
442 |
+
raise ValueError(
|
443 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
444 |
+
" only forward one of the two."
|
445 |
+
)
|
446 |
+
elif prompt is None and prompt_embeds is None:
|
447 |
+
raise ValueError(
|
448 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
449 |
+
)
|
450 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
451 |
+
raise ValueError(
|
452 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
453 |
+
|
454 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
455 |
+
raise ValueError(
|
456 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
457 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
458 |
+
)
|
459 |
+
|
460 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
461 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
462 |
+
raise ValueError(
|
463 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
464 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
465 |
+
f" {negative_prompt_embeds.shape}."
|
466 |
+
)
|
467 |
+
|
468 |
+
if (
|
469 |
+
not isinstance(image, torch.Tensor)
|
470 |
+
and not isinstance(image, PIL.Image.Image)
|
471 |
+
and not isinstance(image, list)
|
472 |
+
):
|
473 |
+
raise ValueError(
|
474 |
+
"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
475 |
+
f" {type(image)}"
|
476 |
+
)
|
477 |
+
|
478 |
+
def prepare_image_latents(
|
479 |
+
self,
|
480 |
+
image,
|
481 |
+
device,
|
482 |
+
num_frames,
|
483 |
+
num_videos_per_prompt,
|
484 |
+
):
|
485 |
+
image = image.to(device=device)
|
486 |
+
image_latents = self.vae.encode(image).latent_dist.sample()
|
487 |
+
image_latents = image_latents * self.vae.config.scaling_factor
|
488 |
+
|
489 |
+
# Add frames dimension to image latents
|
490 |
+
image_latents = image_latents.unsqueeze(2)
|
491 |
+
|
492 |
+
# Append a position mask for each subsequent frame
|
493 |
+
# after the initial image latent frame
|
494 |
+
frame_position_mask = []
|
495 |
+
for frame_idx in range(num_frames - 1):
|
496 |
+
scale = (frame_idx + 1) / (num_frames - 1)
|
497 |
+
frame_position_mask.append(
|
498 |
+
torch.ones_like(image_latents[:, :, :1]) * scale)
|
499 |
+
if frame_position_mask:
|
500 |
+
frame_position_mask = torch.cat(frame_position_mask, dim=2)
|
501 |
+
image_latents = torch.cat(
|
502 |
+
[image_latents, frame_position_mask], dim=2)
|
503 |
+
|
504 |
+
# duplicate image_latents for each generation per prompt, using mps friendly method
|
505 |
+
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1, 1)
|
506 |
+
|
507 |
+
if self.do_classifier_free_guidance:
|
508 |
+
image_latents = torch.cat([image_latents] * 2)
|
509 |
+
|
510 |
+
return image_latents
|
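The loop above appends one constant "position mask" frame per generated frame after the first, with values growing linearly from `1/(num_frames - 1)` to `1`, so the UNet can tell how far each frame sits from the conditioning image. A tiny numeric illustration:

# Scale values appended after the image latent, as computed in the loop above.
num_frames = 5
scales = [(frame_idx + 1) / (num_frames - 1) for frame_idx in range(num_frames - 1)]
# scales == [0.25, 0.5, 0.75, 1.0]; each value fills one latent-shaped mask frame.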
511 |
+
|
512 |
+
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
513 |
+
def prepare_latents(
|
514 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
515 |
+
):
|
516 |
+
shape = (
|
517 |
+
batch_size,
|
518 |
+
num_channels_latents,
|
519 |
+
num_frames,
|
520 |
+
height // self.vae_scale_factor,
|
521 |
+
width // self.vae_scale_factor,
|
522 |
+
)
|
523 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
524 |
+
raise ValueError(
|
525 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
526 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
527 |
+
)
|
528 |
+
|
529 |
+
if latents is None:
|
530 |
+
latents = randn_tensor(
|
531 |
+
shape, generator=generator, device=device, dtype=dtype)
|
532 |
+
else:
|
533 |
+
latents = latents.to(device)
|
534 |
+
|
535 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
536 |
+
latents = latents * self.scheduler.init_noise_sigma
|
537 |
+
return latents
|
538 |
+
|
539 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
540 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
541 |
+
# get the original timestep using init_timestep
|
542 |
+
init_timestep = min(
|
543 |
+
int(num_inference_steps * strength), num_inference_steps)
|
544 |
+
|
545 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
546 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
|
547 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
548 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
549 |
+
|
550 |
+
return timesteps, num_inference_steps - t_start
|
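`get_timesteps` implements the partial-denoising ("modified SDEdit") schedule used here: the closer `strength` is to 1, the more of the schedule is run and the less of the input video survives. A worked example with the `num_inference_steps=30` passed by `i2v_enhance_process` and the default `strength=0.97`:

# Worked example of the timestep truncation (values as used by the enhancement code).
num_inference_steps, strength = 30, 0.97
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # = 29
t_start = max(num_inference_steps - init_timestep, 0)                          # = 1
# The first step is skipped and 29 denoising steps run, starting from the
# input-video latents noised to the first remaining timestep.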
551 |
+
|
552 |
+
# Similar to image, we need to prepare the latents for the video.
|
553 |
+
def prepare_video_latents(
|
554 |
+
self, video, timestep, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
555 |
+
):
|
556 |
+
video = video.to(device=device, dtype=dtype)
|
557 |
+
is_long = video.shape[2] > 16
|
558 |
+
|
559 |
+
# change from (b, c, f, h, w) -> (b * f, c, w, h)
|
560 |
+
bsz, channel, frames, width, height = video.shape
|
561 |
+
video = video.permute(0, 2, 1, 3, 4).reshape(
|
562 |
+
bsz * frames, channel, width, height)
|
563 |
+
|
564 |
+
if video.shape[1] == 4:
|
565 |
+
init_latents = video
|
566 |
+
else:
|
567 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
568 |
+
raise ValueError(
|
569 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
570 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
571 |
+
)
|
572 |
+
elif isinstance(generator, list):
|
573 |
+
init_latents = [
|
574 |
+
retrieve_latents(self.vae.encode(
|
575 |
+
video[i: i + 1]), generator=generator[i])
|
576 |
+
for i in range(batch_size)
|
577 |
+
]
|
578 |
+
init_latents = torch.cat(init_latents, dim=0)
|
579 |
+
else:
|
580 |
+
if not is_long:
|
581 |
+
# 1 step encoding
|
582 |
+
init_latents = retrieve_latents(
|
583 |
+
self.vae.encode(video), generator=generator)
|
584 |
+
else:
|
585 |
+
# chunk by chunk encoding. for low-memory consumption.
|
586 |
+
video_list = torch.chunk(
|
587 |
+
video, video.shape[0] // 16, dim=0)
|
588 |
+
with torch.no_grad():
|
589 |
+
init_latents = []
|
590 |
+
for video_chunk in video_list:
|
591 |
+
video_chunk = retrieve_latents(
|
592 |
+
self.vae.encode(video_chunk), generator=generator)
|
593 |
+
init_latents.append(video_chunk)
|
594 |
+
init_latents = torch.cat(init_latents, dim=0)
|
595 |
+
# torch.cuda.empty_cache()
|
596 |
+
|
597 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
598 |
+
|
599 |
+
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
600 |
+
raise ValueError(
|
601 |
+
f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
602 |
+
)
|
603 |
+
else:
|
604 |
+
init_latents = torch.cat([init_latents], dim=0)
|
605 |
+
|
606 |
+
shape = init_latents.shape
|
607 |
+
noise = randn_tensor(shape, generator=generator,
|
608 |
+
device=device, dtype=dtype)
|
609 |
+
|
610 |
+
latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
611 |
+
latents = latents[None, :].reshape(
|
612 |
+
(bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4)
|
613 |
+
|
614 |
+
return latents
|
615 |
+
|
616 |
+
@torch.no_grad()
|
617 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
618 |
+
def __call__(
|
619 |
+
self,
|
620 |
+
prompt: Union[str, List[str]] = None,
|
621 |
+
# Now image can be either a single image or a list of images (when randomized blending is enabled).
|
622 |
+
image: Union[List[PipelineImageInput], PipelineImageInput] = None,
|
623 |
+
video: Union[List[np.ndarray], torch.Tensor] = None,
|
624 |
+
strength: float = 0.97,
|
625 |
+
overlap_size: int = 0,
|
626 |
+
chunk_size: int = 38,
|
627 |
+
height: Optional[int] = 720,
|
628 |
+
width: Optional[int] = 1280,
|
629 |
+
target_fps: Optional[int] = 38,
|
630 |
+
num_frames: int = 38,
|
631 |
+
num_inference_steps: int = 50,
|
632 |
+
guidance_scale: float = 9.0,
|
633 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
634 |
+
eta: float = 0.0,
|
635 |
+
num_videos_per_prompt: Optional[int] = 1,
|
636 |
+
decode_chunk_size: Optional[int] = 1,
|
637 |
+
generator: Optional[Union[torch.Generator,
|
638 |
+
List[torch.Generator]]] = None,
|
639 |
+
latents: Optional[torch.Tensor] = None,
|
640 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
641 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
642 |
+
output_type: Optional[str] = "pil",
|
643 |
+
return_dict: bool = True,
|
644 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
645 |
+
clip_skip: Optional[int] = 1,
|
646 |
+
):
|
647 |
+
r"""
|
648 |
+
The call function to the pipeline for image-to-video generation with [`I2VGenXLPipeline`].
|
649 |
+
|
650 |
+
Args:
|
651 |
+
prompt (`str` or `List[str]`, *optional*):
|
652 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
653 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
|
654 |
+
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
655 |
+
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
|
656 |
+
video (`List[np.ndarray]` or `torch.Tensor`):
|
657 |
+
Video to guide video enhancement.
|
658 |
+
strength (`float`, *optional*, defaults to 0.97):
|
659 |
+
Indicates extent to transform the reference `video`. Must be between 0 and 1. `image` is used as a
|
660 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
661 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
662 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
663 |
+
essentially ignores `image`.
|
664 |
+
overlap_size (`int`, *optional*, defaults to 0):
|
665 |
+
This parameter is used in randomized blending, when it is enabled.
|
666 |
+
It defines the size of overlap between neighbouring chunks.
|
667 |
+
chunk_size (`int`, *optional*, defaults to 38):
|
668 |
+
This parameter is used in randomized blending, when it is enabled.
|
669 |
+
It defines the number of frames we will enhance during each chunk of randomized blending.
|
670 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
671 |
+
The height in pixels of the generated image.
|
672 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
673 |
+
The width in pixels of the generated image.
|
674 |
+
target_fps (`int`, *optional*):
|
675 |
+
Frames per second. The rate at which the generated images shall be exported to a video after
|
676 |
+
generation. This is also used as a "micro-condition" while generation.
|
677 |
+
num_frames (`int`, *optional*):
|
678 |
+
The number of video frames to generate.
|
679 |
+
num_inference_steps (`int`, *optional*):
|
680 |
+
The number of denoising steps.
|
681 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
682 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
683 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
684 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
685 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
686 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
687 |
+
eta (`float`, *optional*):
|
688 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
689 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
690 |
+
num_videos_per_prompt (`int`, *optional*):
|
691 |
+
The number of images to generate per prompt.
|
692 |
+
decode_chunk_size (`int`, *optional*):
|
693 |
+
The number of frames to decode at a time. The higher the chunk size, the higher the temporal
|
694 |
+
consistency between frames, but also the higher the memory consumption. By default, the decoder will
|
695 |
+
decode all frames at once for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
|
696 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
697 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
698 |
+
generation deterministic.
|
699 |
+
latents (`torch.Tensor`, *optional*):
|
700 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
701 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
702 |
+
tensor is generated by sampling using the supplied random `generator`.
|
703 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
704 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
705 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
706 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
707 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
708 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
709 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
710 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
711 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
712 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
713 |
+
plain tuple.
|
714 |
+
cross_attention_kwargs (`dict`, *optional*):
|
715 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
716 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
717 |
+
clip_skip (`int`, *optional*):
|
718 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
719 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
720 |
+
|
721 |
+
Examples:
|
722 |
+
|
723 |
+
Returns:
|
724 |
+
[`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] or `tuple`:
|
725 |
+
If `return_dict` is `True`, [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] is
|
726 |
+
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
727 |
+
"""
|
728 |
+
# 0. Default height and width to unet
|
729 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
730 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
731 |
+
|
732 |
+
# 1. Check inputs. Raise error if not correct
|
733 |
+
self.check_inputs(prompt, image, height, width,
|
734 |
+
negative_prompt, prompt_embeds, negative_prompt_embeds)
|
735 |
+
|
736 |
+
# 2. Define call parameters
|
737 |
+
if prompt is not None and isinstance(prompt, str):
|
738 |
+
batch_size = 1
|
739 |
+
elif prompt is not None and isinstance(prompt, list):
|
740 |
+
batch_size = len(prompt)
|
741 |
+
else:
|
742 |
+
batch_size = prompt_embeds.shape[0]
|
743 |
+
|
744 |
+
device = self._execution_device
|
745 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
746 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
747 |
+
# corresponds to doing no classifier free guidance.
|
748 |
+
self._guidance_scale = guidance_scale
|
749 |
+
|
750 |
+
# 3.1 Encode input text prompt
|
751 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
752 |
+
prompt,
|
753 |
+
device,
|
754 |
+
num_videos_per_prompt,
|
755 |
+
negative_prompt,
|
756 |
+
prompt_embeds=prompt_embeds,
|
757 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
758 |
+
clip_skip=clip_skip,
|
759 |
+
)
|
760 |
+
# For classifier free guidance, we need to do two forward passes.
|
761 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
762 |
+
# to avoid doing two forward passes
|
763 |
+
if self.do_classifier_free_guidance:
|
764 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
765 |
+
|
766 |
+
# 3.2 Encode image prompt
|
767 |
+
# 3.2.1 Image encodings.
|
768 |
+
# https://github.com/ali-vilab/i2vgen-xl/blob/2539c9262ff8a2a22fa9daecbfd13f0a2dbc32d0/tools/inferences/inference_i2vgen_entrance.py#L114
|
769 |
+
# Since we can now have a list of images (when randomized blending is used), we encode each image separately as before.
|
770 |
+
image_embeddings_list = []
|
771 |
+
for img in image:
|
772 |
+
cropped_image = _center_crop_wide(img, (width, width))
|
773 |
+
cropped_image = _resize_bilinear(
|
774 |
+
cropped_image, (self.feature_extractor.crop_size["width"],
|
775 |
+
self.feature_extractor.crop_size["height"])
|
776 |
+
)
|
777 |
+
image_embeddings = self._encode_image(
|
778 |
+
cropped_image, device, num_videos_per_prompt)
|
779 |
+
image_embeddings_list.append(image_embeddings)
|
780 |
+
|
781 |
+
# 3.2.2 Image latents.
|
782 |
+
# Since we can now have a list of images (when randomized blending is used), we encode each image separately as before.
|
783 |
+
image_latents_list = []
|
784 |
+
for img in image:
|
785 |
+
resized_image = _center_crop_wide(img, (width, height))
|
786 |
+
img = self.video_processor.preprocess(resized_image).to(
|
787 |
+
device=device, dtype=image_embeddings_list[0].dtype)
|
788 |
+
image_latents = self.prepare_image_latents(
|
789 |
+
img,
|
790 |
+
device=device,
|
791 |
+
num_frames=num_frames,
|
792 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
793 |
+
)
|
794 |
+
image_latents_list.append(image_latents)
|
795 |
+
|
796 |
+
# 3.3 Prepare additional conditions for the UNet.
|
797 |
+
if self.do_classifier_free_guidance:
|
798 |
+
fps_tensor = torch.tensor([target_fps, target_fps]).to(device)
|
799 |
+
else:
|
800 |
+
fps_tensor = torch.tensor([target_fps]).to(device)
|
801 |
+
fps_tensor = fps_tensor.repeat(
|
802 |
+
batch_size * num_videos_per_prompt, 1).ravel()
|
803 |
+
|
804 |
+
# 3.4 Preprocess video, similar to images.
|
805 |
+
video = self.video_processor.preprocess_video(video).to(
|
806 |
+
device=device, dtype=image_embeddings_list[0].dtype)
|
807 |
+
num_images_per_prompt = 1
|
808 |
+
|
809 |
+
# 4. Prepare timesteps. This will be used for modified SDEdit approach.
|
810 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
811 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
812 |
+
num_inference_steps, strength, device)
|
813 |
+
latent_timestep = timesteps[:1].repeat(
|
814 |
+
batch_size * num_images_per_prompt)
|
815 |
+
|
816 |
+
# 5. Prepare latent variables. Now we get latents for input video.
|
817 |
+
num_channels_latents = self.unet.config.in_channels
|
818 |
+
latents = self.prepare_video_latents(
|
819 |
+
video,
|
820 |
+
latent_timestep,
|
821 |
+
batch_size * num_videos_per_prompt,
|
822 |
+
num_channels_latents,
|
823 |
+
num_frames,
|
824 |
+
height,
|
825 |
+
width,
|
826 |
+
prompt_embeds.dtype,
|
827 |
+
device,
|
828 |
+
generator,
|
829 |
+
latents,
|
830 |
+
)
|
831 |
+
|
832 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
833 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
834 |
+
|
835 |
+
# 7. Denoising loop
|
836 |
+
num_warmup_steps = len(timesteps) - \
|
837 |
+
num_inference_steps * self.scheduler.order
|
838 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
839 |
+
for i, t in enumerate(timesteps):
|
840 |
+
latents_denoised = torch.empty_like(latents)
|
841 |
+
|
842 |
+
CHUNK_START = 0
|
843 |
+
# Each chunk must have a corresponding 1st frame
|
844 |
+
for idx in range(len(image_latents_list)):
|
845 |
+
latents_chunk = latents[:, :,
|
846 |
+
CHUNK_START:CHUNK_START + chunk_size]
|
847 |
+
|
848 |
+
# expand the latents if we are doing classifier free guidance
|
849 |
+
latent_model_input = torch.cat(
|
850 |
+
[latents_chunk] * 2) if self.do_classifier_free_guidance else latents_chunk
|
851 |
+
latent_model_input = self.scheduler.scale_model_input(
|
852 |
+
latent_model_input, t)
|
853 |
+
|
854 |
+
# predict the noise residual
|
855 |
+
noise_pred = self.unet(
|
856 |
+
latent_model_input,
|
857 |
+
t,
|
858 |
+
encoder_hidden_states=prompt_embeds,
|
859 |
+
fps=fps_tensor,
|
860 |
+
image_latents=image_latents_list[idx],
|
861 |
+
image_embeddings=image_embeddings_list[idx],
|
862 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
863 |
+
return_dict=False,
|
864 |
+
)[0]
|
865 |
+
|
866 |
+
# perform guidance
|
867 |
+
if self.do_classifier_free_guidance:
|
868 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(
|
869 |
+
2)
|
870 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
871 |
+
(noise_pred_text - noise_pred_uncond)
|
872 |
+
|
873 |
+
# reshape latents_chunk
|
874 |
+
batch_size, channel, frames, width, height = latents_chunk.shape
|
875 |
+
latents_chunk = latents_chunk.permute(0, 2, 1, 3, 4).reshape(
|
876 |
+
batch_size * frames, channel, width, height)
|
877 |
+
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(
|
878 |
+
batch_size * frames, channel, width, height)
|
879 |
+
|
880 |
+
# compute the previous noisy sample x_t -> x_t-1
|
881 |
+
latents_chunk = self.scheduler.step(
|
882 |
+
noise_pred, t, latents_chunk, **extra_step_kwargs).prev_sample
|
883 |
+
|
884 |
+
# reshape latents back
|
885 |
+
latents_chunk = latents_chunk[None, :].reshape(
|
886 |
+
batch_size, frames, channel, width, height).permute(0, 2, 1, 3, 4)
|
887 |
+
|
888 |
+
# Make sure random_offset is set correctly.
|
889 |
+
if CHUNK_START == 0:
|
890 |
+
random_offset = 0
|
891 |
+
else:
|
892 |
+
if overlap_size != 0:
|
893 |
+
random_offset = random.randint(0, overlap_size - 1)
|
894 |
+
else:
|
895 |
+
random_offset = 0
|
896 |
+
|
897 |
+
# Apply Randomized Blending.
|
898 |
+
latents_denoised[:, :, CHUNK_START + random_offset:CHUNK_START +
|
899 |
+
chunk_size] = latents_chunk[:, :, random_offset:]
|
900 |
+
CHUNK_START += chunk_size - overlap_size
|
901 |
+
|
902 |
+
latents = latents_denoised
|
903 |
+
|
904 |
+
if CHUNK_START + overlap_size > latents_denoised.shape[2]:
|
905 |
+
raise NotImplementedError(f"Video of size={latents_denoised.shape[2]} is not dividable into chunks "
|
906 |
+
f"with size={chunk_size} and overlap={overlap_size}")
|
907 |
+
|
908 |
+
# call the callback, if provided
|
909 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
910 |
+
progress_bar.update()
|
911 |
+
|
912 |
+
# 8. Post processing
|
913 |
+
if output_type == "latent":
|
914 |
+
video = latents
|
915 |
+
else:
|
916 |
+
video_tensor = self.decode_latents(
|
917 |
+
latents, decode_chunk_size=decode_chunk_size)
|
918 |
+
video = self.video_processor.postprocess_video(
|
919 |
+
video=video_tensor, output_type=output_type)
|
920 |
+
|
921 |
+
# 9. Offload all models
|
922 |
+
self.maybe_free_model_hooks()
|
923 |
+
|
924 |
+
if not return_dict:
|
925 |
+
return (video,)
|
926 |
+
|
927 |
+
return I2VGenXLPipelineOutput(frames=video)
|
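The chunk loop in the denoising step above is what realizes randomized blending: every chunk after the first is written back starting at a random offset inside the overlap region, so the seam between neighbouring chunks lands on a different frame at every denoising step and is averaged away over the trajectory. A small sketch of how the written frame ranges move (`overlap_size=10` is an assumed example value):

import random

# Which latent frames each chunk overwrites at one denoising step
# (mirrors the CHUNK_START / random_offset logic in the loop above).
chunk_size, overlap_size, num_chunks = 38, 10, 3
chunk_start = 0
for idx in range(num_chunks):
    offset = 0 if (chunk_start == 0 or overlap_size == 0) else random.randint(0, overlap_size - 1)
    print(f"chunk {idx}: writes frames {chunk_start + offset} .. {chunk_start + chunk_size - 1}")
    chunk_start += chunk_size - overlap_size
# Re-running this at the next denoising step draws new offsets, so the seams shift every step.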
928 |
+
|
929 |
+
|
930 |
+
# The following utilities are taken and adapted from
|
931 |
+
# https://github.com/ali-vilab/i2vgen-xl/blob/main/utils/transforms.py.
|
932 |
+
|
933 |
+
|
934 |
+
def _convert_pt_to_pil(image: Union[torch.Tensor, List[torch.Tensor]]):
|
935 |
+
if isinstance(image, list) and isinstance(image[0], torch.Tensor):
|
936 |
+
image = torch.cat(image, 0)
|
937 |
+
|
938 |
+
if isinstance(image, torch.Tensor):
|
939 |
+
if image.ndim == 3:
|
940 |
+
image = image.unsqueeze(0)
|
941 |
+
|
942 |
+
image_numpy = VaeImageProcessor.pt_to_numpy(image)
|
943 |
+
image_pil = VaeImageProcessor.numpy_to_pil(image_numpy)
|
944 |
+
image = image_pil
|
945 |
+
|
946 |
+
return image
|
947 |
+
|
948 |
+
|
949 |
+
def _resize_bilinear(
|
950 |
+
image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int]
|
951 |
+
):
|
952 |
+
# First convert the images to PIL in case they are float tensors (only relevant for tests now).
|
953 |
+
image = _convert_pt_to_pil(image)
|
954 |
+
|
955 |
+
if isinstance(image, list):
|
956 |
+
image = [u.resize(resolution, PIL.Image.BILINEAR) for u in image]
|
957 |
+
else:
|
958 |
+
image = image.resize(resolution, PIL.Image.BILINEAR)
|
959 |
+
return image
|
960 |
+
|
961 |
+
|
962 |
+
def _center_crop_wide(
|
963 |
+
image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int]
|
964 |
+
):
|
965 |
+
# First convert the images to PIL in case they are float tensors (only relevant for tests now).
|
966 |
+
image = _convert_pt_to_pil(image)
|
967 |
+
|
968 |
+
if isinstance(image, list):
|
969 |
+
scale = min(image[0].size[0] / resolution[0],
|
970 |
+
image[0].size[1] / resolution[1])
|
971 |
+
image = [u.resize((round(u.width // scale), round(u.height //
|
972 |
+
scale)), resample=PIL.Image.BOX) for u in image]
|
973 |
+
|
974 |
+
# center crop
|
975 |
+
x1 = (image[0].width - resolution[0]) // 2
|
976 |
+
y1 = (image[0].height - resolution[1]) // 2
|
977 |
+
image = [u.crop((x1, y1, x1 + resolution[0], y1 + resolution[1]))
|
978 |
+
for u in image]
|
979 |
+
return image
|
980 |
+
else:
|
981 |
+
scale = min(image.size[0] / resolution[0],
|
982 |
+
image.size[1] / resolution[1])
|
983 |
+
image = image.resize((round(image.width // scale),
|
984 |
+
round(image.height // scale)), resample=PIL.Image.BOX)
|
985 |
+
x1 = (image.width - resolution[0]) // 2
|
986 |
+
y1 = (image.height - resolution[1]) // 2
|
987 |
+
image = image.crop((x1, y1, x1 + resolution[0], y1 + resolution[1]))
|
988 |
+
return image
|
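
The randomized blending above stitches overlapping latent chunks by copying each freshly denoised chunk from a randomly chosen offset inside the overlap region, so chunk seams land at a different frame on every denoising step. A minimal sketch of that index arithmetic on a toy 1-D "latent" (not part of the commit; the names chunk_size, overlap_size and CHUNK_START mirror the loop above, everything else is illustrative):

import random

frames = list(range(16))        # stand-in for latents_denoised along the frame axis
chunk_size, overlap_size = 8, 4

CHUNK_START = 0
while CHUNK_START + chunk_size <= len(frames):
    # "denoised" version of the current chunk (toy: just add 100)
    latents_chunk = [f + 100 for f in frames[CHUNK_START:CHUNK_START + chunk_size]]
    # first chunk keeps its left edge; later chunks start at a random offset inside the overlap
    offset = 0 if (CHUNK_START == 0 or overlap_size == 0) else random.randint(0, overlap_size - 1)
    frames[CHUNK_START + offset:CHUNK_START + chunk_size] = latents_chunk[offset:]
    CHUNK_START += chunk_size - overlap_size
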
i2v_enhance/thirdparty/VFI/Trainer.py
ADDED
@@ -0,0 +1,168 @@
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/Trainer.py
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from i2v_enhance.thirdparty.VFI.model.loss import *
from i2v_enhance.thirdparty.VFI.config import *


class Model:
    def __init__(self, local_rank):
        backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
        backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']
        self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
        self.name = MODEL_CONFIG['LOGNAME']
        self.device()

        # train
        self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4)
        self.lap = LapLoss()
        if local_rank != -1:
            self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        self.net.train()

    def eval(self):
        self.net.eval()

    def device(self):
        self.net.to(torch.device("cuda"))

    def unload(self):
        self.net.to(torch.device("cpu"))

    def load_model(self, name=None, rank=0):
        def convert(param):
            return {
                k.replace("module.", ""): v
                for k, v in param.items()
                if "module." in k and 'attn_mask' not in k and 'HW' not in k
            }
        if rank <= 0:
            if name is None:
                name = self.name
            # self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')))
            self.net.load_state_dict(convert(torch.load(f'{name}')))

    def save_model(self, rank=0):
        if rank == 0:
            torch.save(self.net.state_dict(), f'ckpt/{self.name}.pkl')

    @torch.no_grad()
    def hr_inference(self, img0, img1, TTA=False, down_scale=1.0, timestep=0.5, fast_TTA=False):
        '''
        Infer with down_scale flow
        Noting: return BxCxHxW
        '''
        def infer(imgs):
            img0, img1 = imgs[:, :3], imgs[:, 3:6]
            imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)

            flow, mask = self.net.calculate_flow(imgs_down, timestep)

            flow = F.interpolate(flow, scale_factor=1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
            mask = F.interpolate(mask, scale_factor=1/down_scale, mode="bilinear", align_corners=False)

            af, _ = self.net.feature_bone(img0, img1)
            pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
            return pred

        imgs = torch.cat((img0, img1), 1)
        if fast_TTA:
            imgs_ = imgs.flip(2).flip(3)
            input = torch.cat((imgs, imgs_), 0)
            preds = infer(input)
            return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.

        if TTA == False:
            return infer(imgs)
        else:
            return (infer(imgs) + infer(imgs.flip(2).flip(3)).flip(2).flip(3)) / 2

    @torch.no_grad()
    def inference(self, img0, img1, TTA=False, timestep=0.5, fast_TTA=False):
        imgs = torch.cat((img0, img1), 1)
        '''
        Noting: return BxCxHxW
        '''
        if fast_TTA:
            imgs_ = imgs.flip(2).flip(3)
            input = torch.cat((imgs, imgs_), 0)
            _, _, _, preds = self.net(input, timestep=timestep)
            return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.

        _, _, _, pred = self.net(imgs, timestep=timestep)
        if TTA == False:
            return pred
        else:
            _, _, _, pred2 = self.net(imgs.flip(2).flip(3), timestep=timestep)
            return (pred + pred2.flip(2).flip(3)) / 2

    @torch.no_grad()
    def multi_inference(self, img0, img1, TTA=False, down_scale=1.0, time_list=[], fast_TTA=False):
        '''
        Run backbone once, get multi frames at different timesteps
        Noting: return a list of [CxHxW]
        '''
        assert len(time_list) > 0, 'Time_list should not be empty!'
        def infer(imgs):
            img0, img1 = imgs[:, :3], imgs[:, 3:6]
            af, mf = self.net.feature_bone(img0, img1)
            imgs_down = None
            if down_scale != 1.0:
                imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
                afd, mfd = self.net.feature_bone(imgs_down[:, :3], imgs_down[:, 3:6])

            pred_list = []
            for timestep in time_list:
                if imgs_down is None:
                    flow, mask = self.net.calculate_flow(imgs, timestep, af, mf)
                else:
                    flow, mask = self.net.calculate_flow(imgs_down, timestep, afd, mfd)
                    flow = F.interpolate(flow, scale_factor=1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
                    mask = F.interpolate(mask, scale_factor=1/down_scale, mode="bilinear", align_corners=False)

                pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
                pred_list.append(pred)

            return pred_list

        imgs = torch.cat((img0, img1), 1)
        if fast_TTA:
            imgs_ = imgs.flip(2).flip(3)
            input = torch.cat((imgs, imgs_), 0)
            preds_lst = infer(input)
            return [(preds_lst[i][0] + preds_lst[i][1].flip(1).flip(2))/2 for i in range(len(time_list))]

        preds = infer(imgs)
        if TTA is False:
            return [preds[i][0] for i in range(len(time_list))]
        else:
            flip_pred = infer(imgs.flip(2).flip(3))
            return [(preds[i][0] + flip_pred[i][0].flip(1).flip(2))/2 for i in range(len(time_list))]

    def update(self, imgs, gt, learning_rate=0, training=True):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()

        if training:
            flow, mask, merged, pred = self.net(imgs)
            loss_l1 = (self.lap(pred, gt)).mean()

            for merge in merged:
                loss_l1 += (self.lap(merge, gt)).mean() * 0.5

            self.optimG.zero_grad()
            loss_l1.backward()
            self.optimG.step()
            return pred, loss_l1
        else:
            with torch.no_grad():
                flow, mask, merged, pred = self.net(imgs)
                return pred, 0
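
A minimal sketch of how this wrapper is typically driven for frame interpolation (not part of the commit; it needs a CUDA device, and the checkpoint path and the 1x3xHxW inputs in [0, 1] are assumptions). multi_inference runs the feature backbone once and decodes several intermediate timesteps:

import torch
from i2v_enhance.thirdparty.VFI.Trainer import Model

vfi = Model(local_rank=-1)                                   # -1: no DistributedDataParallel
vfi.load_model("i2v_enhance/thirdparty/VFI/ckpt/ours.pkl")   # assumed checkpoint location
vfi.eval()

img0 = torch.rand(1, 3, 512, 512, device="cuda")  # first frame, values in [0, 1]
img1 = torch.rand(1, 3, 512, 512, device="cuda")  # second frame
# three intermediate frames at t = 0.25, 0.5, 0.75; each entry is a CxHxW tensor
mids = vfi.multi_inference(img0, img1, TTA=False, down_scale=1.0,
                           time_list=[0.25, 0.5, 0.75], fast_TTA=False)
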
i2v_enhance/thirdparty/VFI/ckpt/Put ours.pkl files here.txt
ADDED
@@ -0,0 +1 @@
Here is the link to all EMA-VFI models: https://drive.google.com/drive/folders/16jUa3HkQ85Z5lb5gce1yoaWkP-rdCd0o
i2v_enhance/thirdparty/VFI/ckpt/__init__.py
ADDED
File without changes
i2v_enhance/thirdparty/VFI/config.py
ADDED
@@ -0,0 +1,49 @@
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/config.py
from functools import partial
import torch.nn as nn

from i2v_enhance.thirdparty.VFI.model import feature_extractor
from i2v_enhance.thirdparty.VFI.model import flow_estimation

'''==========Model config=========='''
def init_model_config(F=32, W=7, depth=[2, 2, 2, 4, 4]):
    '''This function should not be modified'''
    return {
        'embed_dims': [F, 2*F, 4*F, 8*F, 16*F],
        'motion_dims': [0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]],
        'num_heads': [8*F//32, 16*F//32],
        'mlp_ratios': [4, 4],
        'qkv_bias': True,
        'norm_layer': partial(nn.LayerNorm, eps=1e-6),
        'depths': depth,
        'window_sizes': [W, W]
    }, {
        'embed_dims': [F, 2*F, 4*F, 8*F, 16*F],
        'motion_dims': [0, 0, 0, 8*F//depth[-2], 16*F//depth[-1]],
        'depths': depth,
        'num_heads': [8*F//32, 16*F//32],
        'window_sizes': [W, W],
        'scales': [4, 8, 16],
        'hidden_dims': [4*F, 4*F],
        'c': F
    }

MODEL_CONFIG = {
    'LOGNAME': 'ours',
    'MODEL_TYPE': (feature_extractor, flow_estimation),
    'MODEL_ARCH': init_model_config(
        F = 32,
        W = 7,
        depth = [2, 2, 2, 4, 4]
    )
}

# MODEL_CONFIG = {
#     'LOGNAME': 'ours_small',
#     'MODEL_TYPE': (feature_extractor, flow_estimation),
#     'MODEL_ARCH': init_model_config(
#         F = 16,
#         W = 7,
#         depth = [2, 2, 2, 2, 2]
#     )
# }
i2v_enhance/thirdparty/VFI/dataset.py
ADDED
@@ -0,0 +1,93 @@
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/dataset.py
import cv2
import os
import torch
import numpy as np
import random
from torch.utils.data import Dataset
from config import *

cv2.setNumThreads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class VimeoDataset(Dataset):
    def __init__(self, dataset_name, path, batch_size=32, model="RIFE"):
        self.batch_size = batch_size
        self.dataset_name = dataset_name
        self.model = model
        self.h = 256
        self.w = 448
        self.data_root = path
        self.image_root = os.path.join(self.data_root, 'sequences')
        train_fn = os.path.join(self.data_root, 'tri_trainlist.txt')
        test_fn = os.path.join(self.data_root, 'tri_testlist.txt')
        with open(train_fn, 'r') as f:
            self.trainlist = f.read().splitlines()
        with open(test_fn, 'r') as f:
            self.testlist = f.read().splitlines()
        self.load_data()

    def __len__(self):
        return len(self.meta_data)

    def load_data(self):
        if self.dataset_name != 'test':
            self.meta_data = self.trainlist
        else:
            self.meta_data = self.testlist

    def aug(self, img0, gt, img1, h, w):
        ih, iw, _ = img0.shape
        x = np.random.randint(0, ih - h + 1)
        y = np.random.randint(0, iw - w + 1)
        img0 = img0[x:x+h, y:y+w, :]
        img1 = img1[x:x+h, y:y+w, :]
        gt = gt[x:x+h, y:y+w, :]
        return img0, gt, img1

    def getimg(self, index):
        imgpath = os.path.join(self.image_root, self.meta_data[index])
        imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png']

        img0 = cv2.imread(imgpaths[0])
        gt = cv2.imread(imgpaths[1])
        img1 = cv2.imread(imgpaths[2])
        return img0, gt, img1

    def __getitem__(self, index):
        img0, gt, img1 = self.getimg(index)

        if 'train' in self.dataset_name:
            img0, gt, img1 = self.aug(img0, gt, img1, 256, 256)
            if random.uniform(0, 1) < 0.5:
                img0 = img0[:, :, ::-1]
                img1 = img1[:, :, ::-1]
                gt = gt[:, :, ::-1]
            if random.uniform(0, 1) < 0.5:
                img1, img0 = img0, img1
            if random.uniform(0, 1) < 0.5:
                img0 = img0[::-1]
                img1 = img1[::-1]
                gt = gt[::-1]
            if random.uniform(0, 1) < 0.5:
                img0 = img0[:, ::-1]
                img1 = img1[:, ::-1]
                gt = gt[:, ::-1]

            p = random.uniform(0, 1)
            if p < 0.25:
                img0 = cv2.rotate(img0, cv2.ROTATE_90_CLOCKWISE)
                gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE)
                img1 = cv2.rotate(img1, cv2.ROTATE_90_CLOCKWISE)
            elif p < 0.5:
                img0 = cv2.rotate(img0, cv2.ROTATE_180)
                gt = cv2.rotate(gt, cv2.ROTATE_180)
                img1 = cv2.rotate(img1, cv2.ROTATE_180)
            elif p < 0.75:
                img0 = cv2.rotate(img0, cv2.ROTATE_90_COUNTERCLOCKWISE)
                gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE)
                img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE)

        img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
        img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
        gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
        return torch.cat((img0, img1, gt), 0)
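
For reference, each sample is a single 9-channel uint8 tensor with img0, img1 and the ground-truth middle frame stacked along the channel axis. A minimal sketch of unpacking it (not part of the commit; the Vimeo90K path is a placeholder):

from torch.utils.data import DataLoader

dataset = VimeoDataset('train', '/path/to/vimeo_triplet')   # placeholder path
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))                 # B x 9 x 256 x 256, uint8
imgs, gt = batch[:, 0:6] / 255., batch[:, 6:] / 255.
img0, img1 = imgs[:, :3], imgs[:, 3:6]     # the two input frames; gt is the middle frame
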
i2v_enhance/thirdparty/VFI/model/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .feature_extractor import feature_extractor
from .flow_estimation import MultiScaleFlow as flow_estimation


__all__ = ['feature_extractor', 'flow_estimation']
i2v_enhance/thirdparty/VFI/model/feature_extractor.py
ADDED
@@ -0,0 +1,516 @@
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/feature_extractor.py
import torch
import torch.nn as nn
import math
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0]*window_size[1], C)
    )
    return windows


def window_reverse(windows, window_size, H, W):
    nwB, N, C = windows.shape
    windows = windows.view(-1, window_size[0], window_size[1], C)
    B = int(nwB / (H * W / window_size[0] / window_size[1]))
    x = windows.view(
        B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


def pad_if_needed(x, size, window_size):
    n, h, w, c = size
    pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
    pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
    if pad_h > 0 or pad_w > 0:  # center-pad the feature on H and W axes
        img_mask = torch.zeros((1, h+pad_h, w+pad_w, 1))  # 1 H W 1
        h_slices = (
            slice(0, pad_h//2),
            slice(pad_h//2, h+pad_h//2),
            slice(h+pad_h//2, None),
        )
        w_slices = (
            slice(0, pad_w//2),
            slice(pad_w//2, w+pad_w//2),
            slice(w+pad_w//2, None),
        )
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, window_size
        )  # nW, window_size*window_size, 1
        mask_windows = mask_windows.squeeze(-1)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(
            attn_mask != 0, float(-100.0)
        ).masked_fill(attn_mask == 0, float(0.0))
        return nn.functional.pad(
            x,
            (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
        ), attn_mask
    return x, None


def depad_if_needed(x, size, window_size):
    n, h, w, c = size
    pad_h = math.ceil(h / window_size[0]) * window_size[0] - h
    pad_w = math.ceil(w / window_size[1]) * window_size[1] - w
    if pad_h > 0 or pad_w > 0:  # remove the center-padding on feature
        return x[:, pad_h // 2 : pad_h // 2 + h, pad_w // 2 : pad_w // 2 + w, :].contiguous()
    return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.relu = nn.ReLU(inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class InterFrameAttention(nn.Module):
    def __init__(self, dim, motion_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.motion_dim = motion_dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.cor_embed = nn.Linear(2, motion_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.motion_proj = nn.Linear(motion_dim, motion_dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x1, x2, cor, H, W, mask=None):
        B, N, C = x1.shape
        B, N, C_c = cor.shape
        q = self.q(x1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = self.kv(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        cor_embed_ = self.cor_embed(cor)
        cor_embed = cor_embed_.reshape(B, N, self.num_heads, self.motion_dim // self.num_heads).permute(0, 2, 1, 3)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            nW = mask.shape[0]  # mask: nW, N, N
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = attn.softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        c_reverse = (attn @ cor_embed).transpose(1, 2).reshape(B, N, -1)
        motion = self.motion_proj(c_reverse-cor_embed_)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, motion


class MotionFormerBlock(nn.Module):
    def __init__(self, dim, motion_dim, num_heads, window_size=0, shift_size=0, mlp_ratio=4., bidirectional=True, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,):
        super().__init__()
        self.window_size = window_size
        if not isinstance(self.window_size, (tuple, list)):
            self.window_size = to_2tuple(window_size)
        self.shift_size = shift_size
        if not isinstance(self.shift_size, (tuple, list)):
            self.shift_size = to_2tuple(shift_size)
        self.bidirectional = bidirectional
        self.norm1 = norm_layer(dim)
        self.attn = InterFrameAttention(
            dim,
            motion_dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, cor, H, W, B):
        x = x.view(2*B, H, W, -1)
        x_pad, mask = pad_if_needed(x, x.size(), self.window_size)
        cor_pad, _ = pad_if_needed(cor, cor.size(), self.window_size)

        if self.shift_size[0] or self.shift_size[1]:
            _, H_p, W_p, C = x_pad.shape
            x_pad = torch.roll(x_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
            cor_pad = torch.roll(cor_pad, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))

            if hasattr(self, 'HW') and self.HW.item() == H_p * W_p:
                shift_mask = self.attn_mask
            else:
                shift_mask = torch.zeros((1, H_p, W_p, 1))  # 1 H W 1
                h_slices = (slice(0, -self.window_size[0]),
                            slice(-self.window_size[0], -self.shift_size[0]),
                            slice(-self.shift_size[0], None))
                w_slices = (slice(0, -self.window_size[1]),
                            slice(-self.window_size[1], -self.shift_size[1]),
                            slice(-self.shift_size[1], None))
                cnt = 0
                for h in h_slices:
                    for w in w_slices:
                        shift_mask[:, h, w, :] = cnt
                        cnt += 1

                mask_windows = window_partition(shift_mask, self.window_size).squeeze(-1)
                shift_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                shift_mask = shift_mask.masked_fill(shift_mask != 0,
                                                    float(-100.0)).masked_fill(shift_mask == 0,
                                                                               float(0.0))

                if mask is not None:
                    shift_mask = shift_mask.masked_fill(mask != 0,
                                                        float(-100.0))
                self.register_buffer("attn_mask", shift_mask)
                self.register_buffer("HW", torch.Tensor([H_p*W_p]))
        else:
            shift_mask = mask

        if shift_mask is not None:
            shift_mask = shift_mask.to(x_pad.device)

        _, Hw, Ww, C = x_pad.shape
        x_win = window_partition(x_pad, self.window_size)
        cor_win = window_partition(cor_pad, self.window_size)

        nwB = x_win.shape[0]
        x_norm = self.norm1(x_win)

        x_reverse = torch.cat([x_norm[nwB//2:], x_norm[:nwB//2]])
        x_appearence, x_motion = self.attn(x_norm, x_reverse, cor_win, H, W, shift_mask)
        x_norm = x_norm + self.drop_path(x_appearence)

        x_back = x_norm
        x_back_win = window_reverse(x_back, self.window_size, Hw, Ww)
        x_motion = window_reverse(x_motion, self.window_size, Hw, Ww)

        if self.shift_size[0] or self.shift_size[1]:
            x_back_win = torch.roll(x_back_win, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))
            x_motion = torch.roll(x_motion, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))

        x = depad_if_needed(x_back_win, x.size(), self.window_size).view(2*B, H * W, -1)
        x_motion = depad_if_needed(x_motion, cor.size(), self.window_size).view(2*B, H * W, -1)

        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x, x_motion


class ConvBlock(nn.Module):
    def __init__(self, in_dim, out_dim, depths=2, act_layer=nn.PReLU):
        super().__init__()
        layers = []
        for i in range(depths):
            if i == 0:
                layers.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1))
            else:
                layers.append(nn.Conv2d(out_dim, out_dim, 3, 1, 1))
            layers.extend([
                act_layer(out_dim),
            ])
        self.conv = nn.Sequential(*layers)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv(x)
        return x


class OverlapPatchEmbed(nn.Module):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)

        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class CrossScalePatchEmbed(nn.Module):
    def __init__(self, in_dims=[16, 32, 64], embed_dim=768):
        super().__init__()
        base_dim = in_dims[0]

        layers = []
        for i in range(len(in_dims)):
            for j in range(2 ** i):
                layers.append(nn.Conv2d(in_dims[-1-i], base_dim, 3, 2**(i+1), 1+j, 1+j))
        self.layers = nn.ModuleList(layers)
        self.proj = nn.Conv2d(base_dim * len(layers), embed_dim, 1, 1)
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, xs):
        ys = []
        k = 0
        for i in range(len(xs)):
            for _ in range(2 ** i):
                ys.append(self.layers[k](xs[-1-i]))
                k += 1
        x = self.proj(torch.cat(ys, 1))
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class MotionFormer(nn.Module):
    def __init__(self, in_chans=3, embed_dims=[32, 64, 128, 256, 512], motion_dims=64, num_heads=[8, 16],
                 mlp_ratios=[4, 4], qkv_bias=True, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[2, 2, 2, 6, 2], window_sizes=[11, 11], **kwarg):
        super().__init__()
        self.depths = depths
        self.num_stages = len(embed_dims)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        self.conv_stages = self.num_stages - len(num_heads)

        for i in range(self.num_stages):
            if i == 0:
                block = ConvBlock(in_chans, embed_dims[i], depths[i])
            else:
                if i < self.conv_stages:
                    patch_embed = nn.Sequential(
                        nn.Conv2d(embed_dims[i-1], embed_dims[i], 3, 2, 1),
                        nn.PReLU(embed_dims[i])
                    )
                    block = ConvBlock(embed_dims[i], embed_dims[i], depths[i])
                else:
                    if i == self.conv_stages:
                        patch_embed = CrossScalePatchEmbed(embed_dims[:i],
                                                           embed_dim=embed_dims[i])
                    else:
                        patch_embed = OverlapPatchEmbed(patch_size=3,
                                                        stride=2,
                                                        in_chans=embed_dims[i - 1],
                                                        embed_dim=embed_dims[i])

                    block = nn.ModuleList([MotionFormerBlock(
                        dim=embed_dims[i], motion_dim=motion_dims[i], num_heads=num_heads[i-self.conv_stages], window_size=window_sizes[i-self.conv_stages],
                        shift_size=0 if (j % 2) == 0 else window_sizes[i-self.conv_stages] // 2,
                        mlp_ratio=mlp_ratios[i-self.conv_stages], qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer)
                        for j in range(depths[i])])

                    norm = norm_layer(embed_dims[i])
                    setattr(self, f"norm{i + 1}", norm)
                setattr(self, f"patch_embed{i + 1}", patch_embed)
            cur += depths[i]

            setattr(self, f"block{i + 1}", block)

        self.cor = {}

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def get_cor(self, shape, device):
        k = (str(shape), str(device))
        if k not in self.cor:
            tenHorizontal = torch.linspace(-1.0, 1.0, shape[2], device=device).view(
                1, 1, 1, shape[2]).expand(shape[0], -1, shape[1], -1).permute(0, 2, 3, 1)
            tenVertical = torch.linspace(-1.0, 1.0, shape[1], device=device).view(
                1, 1, shape[1], 1).expand(shape[0], -1, -1, shape[2]).permute(0, 2, 3, 1)
            self.cor[k] = torch.cat([tenHorizontal, tenVertical], -1).to(device)
        return self.cor[k]

    def forward(self, x1, x2):
        B = x1.shape[0]
        x = torch.cat([x1, x2], 0)
        motion_features = []
        appearence_features = []
        xs = []
        for i in range(self.num_stages):
            motion_features.append([])
            patch_embed = getattr(self, f"patch_embed{i + 1}", None)
            block = getattr(self, f"block{i + 1}", None)
            norm = getattr(self, f"norm{i + 1}", None)
            if i < self.conv_stages:
                if i > 0:
                    x = patch_embed(x)
                x = block(x)
                xs.append(x)
            else:
                if i == self.conv_stages:
                    x, H, W = patch_embed(xs)
                else:
                    x, H, W = patch_embed(x)
                cor = self.get_cor((x.shape[0], H, W), x.device)
                for blk in block:
                    x, x_motion = blk(x, cor, H, W, B)
                    motion_features[i].append(x_motion.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous())
                x = norm(x)
                x = x.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                motion_features[i] = torch.cat(motion_features[i], 1)
            appearence_features.append(x)
        return appearence_features, motion_features


class DWConv(nn.Module):
    def __init__(self, dim):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.dwconv(x)
        x = x.reshape(B, C, -1).transpose(1, 2)

        return x


def feature_extractor(**kargs):
    model = MotionFormer(**kargs)
    return model
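
A minimal sketch of what the backbone returns (not part of the commit; the CPU-sized random inputs are purely illustrative). Both frames are concatenated along the batch axis, so every returned feature map has batch size 2B, and motion features only exist for the two transformer stages:

import torch
from i2v_enhance.thirdparty.VFI.config import MODEL_CONFIG

backbone_cls, _ = MODEL_CONFIG['MODEL_TYPE']
backbone_cfg, _ = MODEL_CONFIG['MODEL_ARCH']
net = backbone_cls(**backbone_cfg).eval()

x1 = torch.rand(1, 3, 256, 256)   # frame 0
x2 = torch.rand(1, 3, 256, 256)   # frame 1
with torch.no_grad():
    af, mf = net(x1, x2)
# af: 5 appearance maps at 1/1, 1/2, 1/4, 1/8, 1/16 resolution, each with batch 2*B
# mf: empty lists for the 3 conv stages, concatenated motion tensors for the 2 MotionFormer stages
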
i2v_enhance/thirdparty/VFI/model/flow_estimation.py
ADDED
@@ -0,0 +1,141 @@
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/flow_estimation
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp
from .refine import *

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


class Head(nn.Module):
    def __init__(self, in_planes, scale, c, in_else=17):
        super(Head, self).__init__()
        self.upsample = nn.Sequential(nn.PixelShuffle(2), nn.PixelShuffle(2))
        self.scale = scale
        self.conv = nn.Sequential(
            conv(in_planes*2 // (4*4) + in_else, c),
            conv(c, c),
            conv(c, 5),
        )

    def forward(self, motion_feature, x, flow):  # /16 /8 /4
        motion_feature = self.upsample(motion_feature)  # /4 /2 /1
        if self.scale != 4:
            x = F.interpolate(x, scale_factor=4. / self.scale, mode="bilinear", align_corners=False)
        if flow != None:
            if self.scale != 4:
                flow = F.interpolate(flow, scale_factor=4. / self.scale, mode="bilinear", align_corners=False) * 4. / self.scale
            x = torch.cat((x, flow), 1)
        x = self.conv(torch.cat([motion_feature, x], 1))
        if self.scale != 4:
            x = F.interpolate(x, scale_factor=self.scale // 4, mode="bilinear", align_corners=False)
            flow = x[:, :4] * (self.scale // 4)
        else:
            flow = x[:, :4]
        mask = x[:, 4:5]
        return flow, mask


class MultiScaleFlow(nn.Module):
    def __init__(self, backbone, **kargs):
        super(MultiScaleFlow, self).__init__()
        self.flow_num_stage = len(kargs['hidden_dims'])
        self.feature_bone = backbone
        self.block = nn.ModuleList([Head(kargs['motion_dims'][-1-i] * kargs['depths'][-1-i] + kargs['embed_dims'][-1-i],
                                         kargs['scales'][-1-i],
                                         kargs['hidden_dims'][-1-i],
                                         6 if i == 0 else 17)
                                    for i in range(self.flow_num_stage)])
        self.unet = Unet(kargs['c'] * 2)

    def warp_features(self, xs, flow):
        y0 = []
        y1 = []
        B = xs[0].size(0) // 2
        for x in xs:
            y0.append(warp(x[:B], flow[:, 0:2]))
            y1.append(warp(x[B:], flow[:, 2:4]))
            flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
        return y0, y1

    def calculate_flow(self, imgs, timestep, af=None, mf=None):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        B = img0.size(0)
        flow, mask = None, None
        # appearence_features & motion_features
        if (af is None) or (mf is None):
            af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda()
            if flow != None:
                warped_img0 = warp(img0, flow[:, :2])
                warped_img1 = warp(img1, flow[:, 2:4])
                flow_, mask_ = self.block[i](
                    torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:], af[-1-i][:B], af[-1-i][B:]], 1),
                    torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
                    flow
                )
                flow = flow + flow_
                mask = mask + mask_
            else:
                flow, mask = self.block[i](
                    torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:], af[-1-i][:B], af[-1-i][B:]], 1),
                    torch.cat((img0, img1), 1),
                    None
                )

        return flow, mask

    def coraseWarp_and_Refine(self, imgs, af, flow, mask):
        img0, img1 = imgs[:, :3], imgs[:, 3:6]
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])
        c0, c1 = self.warp_features(af, flow)
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1
        mask_ = torch.sigmoid(mask)
        merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
        pred = torch.clamp(merged + res, 0, 1)
        return pred


    # Actually consist of 'calculate_flow' and 'coraseWarp_and_Refine'
    def forward(self, x, timestep=0.5):
        img0, img1 = x[:, :3], x[:, 3:6]
        B = x.size(0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        # appearence_features & motion_features
        af, mf = self.feature_bone(img0, img1)
        for i in range(self.flow_num_stage):
            t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda()
            if flow != None:
                flow_d, mask_d = self.block[i](torch.cat([t*mf[-1-i][:B], (1-timestep)*mf[-1-i][B:], af[-1-i][:B], af[-1-i][B:]], 1),
                                               torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow)
                flow = flow + flow_d
                mask = mask + mask_d
            else:
                flow, mask = self.block[i](torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:], af[-1-i][:B], af[-1-i][B:]], 1),
                                           torch.cat((img0, img1), 1), None)
            mask_list.append(torch.sigmoid(mask))
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))

        c0, c1 = self.warp_features(af, flow)
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, :3] * 2 - 1
        pred = torch.clamp(merged[-1] + res, 0, 1)
        return flow_list, mask_list, merged, pred
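
A minimal sketch of assembling the full interpolation network from the config and running one forward pass (not part of the commit; it needs a CUDA device because the timestep tensor is created with .cuda(), and the 256x256 input size is an assumption):

import torch
from i2v_enhance.thirdparty.VFI.config import MODEL_CONFIG

backbone_cls, flow_cls = MODEL_CONFIG['MODEL_TYPE']
backbone_cfg, flow_cfg = MODEL_CONFIG['MODEL_ARCH']
net = flow_cls(backbone_cls(**backbone_cfg), **flow_cfg).cuda().eval()

x = torch.rand(1, 6, 256, 256, device="cuda")   # img0 and img1 stacked along channels
with torch.no_grad():
    flow_list, mask_list, merged, pred = net(x, timestep=0.5)
# pred: 1x3x256x256 estimate of the frame halfway between img0 and img1;
# flow_list/mask_list/merged hold the coarse-to-fine intermediate results
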
i2v_enhance/thirdparty/VFI/model/loss.py
ADDED
@@ -0,0 +1,95 @@
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/loss.py
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def gauss_kernel(channels=3):
    kernel = torch.tensor([[1., 4., 6., 4., 1],
                           [4., 16., 24., 16., 4.],
                           [6., 24., 36., 24., 6.],
                           [4., 16., 24., 16., 4.],
                           [1., 4., 6., 4., 1.]])
    kernel /= 256.
    kernel = kernel.repeat(channels, 1, 1, 1)
    kernel = kernel.to(device)
    return kernel

def downsample(x):
    return x[:, :, ::2, ::2]

def upsample(x):
    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3])
    cc = cc.permute(0, 1, 3, 2)
    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2).to(device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2)
    x_up = cc.permute(0, 1, 3, 2)
    return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1]))

def conv_gauss(img, kernel):
    img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
    out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
    return out

def laplacian_pyramid(img, kernel, max_levels=3):
    current = img
    pyr = []
    for level in range(max_levels):
        filtered = conv_gauss(current, kernel)
        down = downsample(filtered)
        up = upsample(down)
        diff = current-up
        pyr.append(diff)
        current = down
    return pyr

class LapLoss(torch.nn.Module):
    def __init__(self, max_levels=5, channels=3):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        self.gauss_kernel = gauss_kernel(channels=channels)

    def forward(self, input, target):
        pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
        pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
        return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))

class Ternary(nn.Module):
    def __init__(self, device):
        super(Ternary, self).__init__()
        patch_size = 7
        out_channels = patch_size * patch_size
        self.w = np.eye(out_channels).reshape(
            (patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        self.w = torch.tensor(self.w).float().to(device)

    def transform(self, img):
        patches = F.conv2d(img, self.w, padding=3, bias=None)
        transf = patches - img
        transf_norm = transf / torch.sqrt(0.81 + transf**2)
        return transf_norm

    def rgb2gray(self, rgb):
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def hamming(self, t1, t2):
        dist = (t1 - t2) ** 2
        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
        return dist_norm

    def valid_mask(self, t, padding):
        n, _, h, w = t.size()
        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, img0, img1):
        img0 = self.transform(self.rgb2gray(img0))
        img1 = self.transform(self.rgb2gray(img1))
        return self.hamming(img0, img1) * self.valid_mask(img0, 1)
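
LapLoss above is an L1 loss summed over the levels of a Laplacian pyramid. A minimal sketch (not part of the commit; the tensors must live on the same module-level device as the Gaussian kernel):

import torch

criterion = LapLoss(max_levels=5, channels=3)
pred = torch.rand(2, 3, 128, 128, device=device)
gt = torch.rand(2, 3, 128, 128, device=device)
loss = criterion(pred, gt)   # scalar: sum of per-level L1 distances
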
i2v_enhance/thirdparty/VFI/model/refine.py
ADDED
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import math
from timm.models.layers import trunc_normal_

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )

def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True),
        nn.PReLU(out_planes)
    )

class Conv2(nn.Module):
    def __init__(self, in_planes, out_planes, stride=2):
        super(Conv2, self).__init__()
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class Unet(nn.Module):
    def __init__(self, c, out=3):
        super(Unet, self).__init__()
        self.down0 = Conv2(17+c, 2*c)
        self.down1 = Conv2(4*c, 4*c)
        self.down2 = Conv2(8*c, 8*c)
        self.down3 = Conv2(16*c, 16*c)
        self.up0 = deconv(32*c, 8*c)
        self.up1 = deconv(16*c, 4*c)
        self.up2 = deconv(8*c, 2*c)
        self.up3 = deconv(4*c, c)
        self.conv = nn.Conv2d(c, out, 3, 1, 1)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
        s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow, c0[0], c1[0]), 1))
        s1 = self.down1(torch.cat((s0, c0[1], c1[1]), 1))
        s2 = self.down2(torch.cat((s1, c0[2], c1[2]), 1))
        s3 = self.down3(torch.cat((s2, c0[3], c1[3]), 1))
        x = self.up0(torch.cat((s3, c0[4], c1[4]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return torch.sigmoid(x)
i2v_enhance/thirdparty/VFI/model/warplayer.py
ADDED
@@ -0,0 +1,21 @@
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/model/warplayer.py
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backwarp_tenGrid = {}

def warp(tenInput, tenFlow):
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat(
            [tenHorizontal, tenVertical], 1).to(device)

    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
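
warp backward-warps an image with a dense flow field by adding the normalized flow to a cached identity grid and bilinearly resampling. A minimal sanity check (not part of the commit; tensors are placed on the same module-level device as the cached grid):

import torch

img = torch.rand(1, 3, 64, 64, device=device)
zero_flow = torch.zeros(1, 2, 64, 64, device=device)   # (dx, dy) displacement per pixel
warped = warp(img, zero_flow)
# with zero flow the sampling grid is the identity, so warped matches img up to float error
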
i2v_enhance/thirdparty/VFI/train.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/train.py
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import math
|
5 |
+
import time
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
import numpy as np
|
9 |
+
import random
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
from Trainer import Model
|
13 |
+
from dataset import VimeoDataset
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from torch.utils.tensorboard import SummaryWriter
|
16 |
+
from torch.utils.data.distributed import DistributedSampler
|
17 |
+
from config import *
|
18 |
+
|
19 |
+
device = torch.device("cuda")
|
20 |
+
exp = os.path.abspath('.').split('/')[-1]
|
21 |
+
|
22 |
+
def get_learning_rate(step):
|
23 |
+
if step < 2000:
|
24 |
+
mul = step / 2000
|
25 |
+
return 2e-4 * mul
|
26 |
+
else:
|
27 |
+
mul = np.cos((step - 2000) / (300 * args.step_per_epoch - 2000) * math.pi) * 0.5 + 0.5
|
28 |
+
return (2e-4 - 2e-5) * mul + 2e-5
|
29 |
+
|
30 |
+
def train(model, local_rank, batch_size, data_path):
|
31 |
+
if local_rank == 0:
|
32 |
+
writer = SummaryWriter('log/train_EMAVFI')
|
33 |
+
step = 0
|
34 |
+
nr_eval = 0
|
35 |
+
best = 0
|
36 |
+
dataset = VimeoDataset('train', data_path)
|
37 |
+
sampler = DistributedSampler(dataset)
|
38 |
+
train_data = DataLoader(dataset, batch_size=batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
|
39 |
+
args.step_per_epoch = train_data.__len__()
|
40 |
+
dataset_val = VimeoDataset('test', data_path)
|
41 |
+
val_data = DataLoader(dataset_val, batch_size=batch_size, pin_memory=True, num_workers=8)
|
42 |
+
print('training...')
|
43 |
+
time_stamp = time.time()
|
44 |
+
for epoch in range(300):
|
45 |
+
sampler.set_epoch(epoch)
|
46 |
+
for i, imgs in enumerate(train_data):
|
47 |
+
data_time_interval = time.time() - time_stamp
|
48 |
+
time_stamp = time.time()
|
49 |
+
imgs = imgs.to(device, non_blocking=True) / 255.
|
50 |
+
imgs, gt = imgs[:, 0:6], imgs[:, 6:]
|
51 |
+
learning_rate = get_learning_rate(step)
|
52 |
+
_, loss = model.update(imgs, gt, learning_rate, training=True)
|
53 |
+
train_time_interval = time.time() - time_stamp
|
54 |
+
time_stamp = time.time()
|
55 |
+
if step % 200 == 1 and local_rank == 0:
|
56 |
+
writer.add_scalar('learning_rate', learning_rate, step)
|
57 |
+
writer.add_scalar('loss', loss, step)
|
58 |
+
if local_rank == 0:
|
59 |
+
print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, loss))
|
60 |
+
step += 1
|
61 |
+
nr_eval += 1
|
62 |
+
if nr_eval % 3 == 0:
|
63 |
+
evaluate(model, val_data, nr_eval, local_rank)
|
64 |
+
model.save_model(local_rank)
|
65 |
+
|
66 |
+
dist.barrier()
|
67 |
+
|
68 |
+
def evaluate(model, val_data, nr_eval, local_rank):
|
69 |
+
if local_rank == 0:
|
70 |
+
writer_val = SummaryWriter('log/validate_EMAVFI')
|
71 |
+
|
72 |
+
psnr = []
|
73 |
+
for _, imgs in enumerate(val_data):
|
74 |
+
imgs = imgs.to(device, non_blocking=True) / 255.
|
75 |
+
imgs, gt = imgs[:, 0:6], imgs[:, 6:]
|
76 |
+
with torch.no_grad():
|
77 |
+
pred, _ = model.update(imgs, gt, training=False)
|
78 |
+
for j in range(gt.shape[0]):
|
79 |
+
psnr.append(-10 * math.log10(((gt[j] - pred[j]) * (gt[j] - pred[j])).mean().cpu().item()))
|
80 |
+
|
81 |
+
psnr = np.array(psnr).mean()
|
82 |
+
if local_rank == 0:
|
83 |
+
print(str(nr_eval), psnr)
|
84 |
+
writer_val.add_scalar('psnr', psnr, nr_eval)
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
parser = argparse.ArgumentParser()
|
88 |
+
parser.add_argument('--local_rank', default=0, type=int, help='local rank')
|
89 |
+
parser.add_argument('--world_size', default=4, type=int, help='world size')
|
90 |
+
parser.add_argument('--batch_size', default=8, type=int, help='batch size')
|
91 |
+
parser.add_argument('--data_path', type=str, help='data path of vimeo90k')
|
92 |
+
args = parser.parse_args()
|
93 |
+
torch.distributed.init_process_group(backend="nccl", world_size=args.world_size)
|
94 |
+
torch.cuda.set_device(args.local_rank)
|
95 |
+
if args.local_rank == 0 and not os.path.exists('log'):
|
96 |
+
os.mkdir('log')
|
97 |
+
seed = 1234
|
98 |
+
random.seed(seed)
|
99 |
+
np.random.seed(seed)
|
100 |
+
torch.manual_seed(seed)
|
101 |
+
torch.cuda.manual_seed_all(seed)
|
102 |
+
torch.backends.cudnn.benchmark = True
|
103 |
+
model = Model(args.local_rank)
|
104 |
+
train(model, args.local_rank, args.batch_size, args.data_path)
|
105 |
+
|
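The schedule in get_learning_rate above is a linear warmup over the first 2000 steps followed by a cosine decay from 2e-4 down to 2e-5 over the remaining 300 epochs. Below is a standalone sketch of that schedule with a hypothetical steps-per-epoch value (the script itself takes it from the dataloader length).

import math
import numpy as np

STEP_PER_EPOCH = 1000  # hypothetical; train.py sets this to len(train_data)

def lr_schedule(step, warmup=2000, epochs=300, lr_max=2e-4, lr_min=2e-5):
    if step < warmup:
        return lr_max * step / warmup                               # linear warmup
    total = epochs * STEP_PER_EPOCH - warmup
    mul = np.cos((step - warmup) / total * math.pi) * 0.5 + 0.5     # 1 -> 0 over training
    return (lr_max - lr_min) * mul + lr_min                         # cosine decay to lr_min

for s in (0, 1000, 2000, 150_000, 300 * STEP_PER_EPOCH):
    print(f"step {s:>7d}: lr = {lr_schedule(s):.2e}")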
lib/__init__.py
ADDED
File without changes
|
lib/farancia/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
1 |
+
from .libimage import IImage
|
2 |
+
|
3 |
+
from os.path import dirname, pardir, realpath
|
4 |
+
import os
|
lib/farancia/animation.py
ADDED
@@ -0,0 +1,43 @@
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
from matplotlib import animation
|
3 |
+
|
4 |
+
|
5 |
+
class Animation:
|
6 |
+
JS = 0
|
7 |
+
HTML = 1
|
8 |
+
ANIMATION_MODE = HTML
|
9 |
+
|
10 |
+
def __init__(self, frames, fps=30):
|
11 |
+
"""_summary_
|
12 |
+
|
13 |
+
Args:
|
14 |
+
frames (np.ndarray): Frame stack of shape (F, H, W, C) with uint8 values.
|
15 |
+
"""
|
16 |
+
self.frames = frames
|
17 |
+
self.fps = fps
|
18 |
+
self.anim_obj = None
|
19 |
+
self.anim_str = None
|
20 |
+
|
21 |
+
def render(self):
|
22 |
+
size = (self.frames.shape[2], self.frames.shape[1])
|
23 |
+
self.fig = plt.figure(figsize=size, dpi=1)
|
24 |
+
plt.axis('off')
|
25 |
+
img = plt.imshow(self.frames[0], cmap='gray', vmin=0, vmax=255)
|
26 |
+
self.fig.subplots_adjust(0, 0, 1, 1)
|
27 |
+
self.anim_obj = animation.FuncAnimation(
|
28 |
+
self.fig,
|
29 |
+
lambda i: img.set_data(self.frames[i, :, :, :]),
|
30 |
+
frames=self.frames.shape[0],
|
31 |
+
interval=1000 / self.fps
|
32 |
+
)
|
33 |
+
plt.close()
|
34 |
+
if Animation.ANIMATION_MODE == Animation.HTML:
|
35 |
+
self.anim_str = self.anim_obj.to_html5_video()
|
36 |
+
elif Animation.ANIMATION_MODE == Animation.JS:
|
37 |
+
self.anim_str = self.anim_obj.to_jshtml()
|
38 |
+
return self.anim_obj
|
39 |
+
|
40 |
+
def _repr_html_(self):
|
41 |
+
if self.anim_obj is None:
|
42 |
+
self.render()
|
43 |
+
return self.anim_str
|
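A short usage sketch for the Animation helper above, assuming a notebook-like environment and that matplotlib can find a video writer (to_html5_video needs ffmpeg); the frames here are random uint8 noise purely for illustration.

import numpy as np
from lib.farancia.animation import Animation

frames = np.random.randint(0, 255, size=(16, 64, 64, 3), dtype=np.uint8)  # (F, H, W, C)
anim = Animation(frames, fps=8)
anim.render()                 # builds the matplotlib FuncAnimation
html = anim.anim_str          # HTML5 <video> markup, since ANIMATION_MODE == HTML
print(len(html), "characters of HTML")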
lib/farancia/config.py
ADDED
@@ -0,0 +1 @@
|
1 |
+
IMG_THUMBSIZE = None
|
lib/farancia/libimage/__init__.py
ADDED
@@ -0,0 +1,45 @@
|
1 |
+
from .iimage import IImage
|
2 |
+
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
# ========= STATIC FUNCTIONS =============
|
8 |
+
def find_max_h(images):
|
9 |
+
return max([x.size[1] for x in images])
|
10 |
+
def find_max_w(images):
|
11 |
+
return max([x.size[0] for x in images])
|
12 |
+
def find_max_size(images):
|
13 |
+
return find_max_w(images), find_max_h(images)
|
14 |
+
|
15 |
+
|
16 |
+
def stack(images, axis = 0):
|
17 |
+
return IImage(np.concatenate([x.data for x in images], axis))
|
18 |
+
def tstack(images):
|
19 |
+
w,h = find_max_size(images)
|
20 |
+
images = [x.pad2wh(w,h) for x in images]
|
21 |
+
return IImage(np.concatenate([x.data for x in images], 0))
|
22 |
+
def hstack(images):
|
23 |
+
h = find_max_h(images)
|
24 |
+
images = [x.pad2wh(h = h) for x in images]
|
25 |
+
return IImage(np.concatenate([x.data for x in images], 2))
|
26 |
+
def vstack(images):
|
27 |
+
w = find_max_w(images)
|
28 |
+
images = [x.pad2wh(w = w) for x in images]
|
29 |
+
return IImage(np.concatenate([x.data for x in images], 1))
|
30 |
+
|
31 |
+
def grid(images, nrows = None, ncols = None):
|
32 |
+
combined = stack(images)
|
33 |
+
if nrows is not None:
|
34 |
+
ncols = math.ceil(combined.data.shape[0] / nrows)
|
35 |
+
elif ncols is not None:
|
36 |
+
nrows = math.ceil(combined.data.shape[0] / ncols)
|
37 |
+
else:
|
38 |
+
warnings.warn("No dimensions specified, creating a grid with 5 columns (default)")
|
39 |
+
ncols = 5
|
40 |
+
nrows = math.ceil(combined.data.shape[0] / ncols)
|
41 |
+
|
42 |
+
pad = nrows * ncols - combined.data.shape[0]
|
43 |
+
data = np.pad(combined.data, ((0,pad),(0,0),(0,0),(0,0)))
|
44 |
+
rows = [np.concatenate(x,1,dtype=np.uint8) for x in np.array_split(data, nrows)]
|
45 |
+
return IImage(np.concatenate(rows, 0, dtype = np.uint8)[None])
|
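These helpers operate on the IImage wrapper defined in iimage.py below; a small sketch of how they compose, with random data and hypothetical sizes (assumes the repository's image dependencies are installed).

import numpy as np
from lib.farancia import libimage
from lib.farancia.libimage import IImage

a = IImage(np.random.randint(0, 255, (1, 64, 64, 3), dtype=np.uint8))  # (B, H, W, C)
b = IImage(np.random.randint(0, 255, (1, 64, 64, 3), dtype=np.uint8))

row = libimage.hstack([a, b])                 # side by side   -> (1, 64, 128, 3)
col = libimage.vstack([a, b])                 # top to bottom  -> (1, 128, 64, 3)
batch = libimage.stack([a, b])                # along batch    -> (2, 64, 64, 3)
sheet = libimage.grid([a, b, a, b], ncols=2)  # 2x2 contact sheet in a single image
print(row.shape, col.shape, batch.shape, sheet.shape)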
lib/farancia/libimage/iimage.py
ADDED
@@ -0,0 +1,511 @@
|
1 |
+
import io
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import PIL.Image
|
5 |
+
import numpy as np
|
6 |
+
import imageio.v3 as iio
|
7 |
+
import warnings
|
8 |
+
from torchvision.utils import flow_to_image
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torchvision.transforms.functional as TF
|
12 |
+
from scipy.ndimage import binary_dilation, binary_erosion
|
13 |
+
import cv2
|
14 |
+
|
15 |
+
from ..animation import Animation
|
16 |
+
from .. import config
|
17 |
+
from .. import libimage
|
18 |
+
import re
|
19 |
+
|
20 |
+
|
21 |
+
def torch2np(x, vmin=-1, vmax=1):
|
22 |
+
if x.ndim != 4:
|
23 |
+
# raise Exception("Please only use (B,C,H,W) torch tensors!")
|
24 |
+
warnings.warn(
|
25 |
+
"Warning! Shape of the image was not provided in (B,C,H,W) format, the shape was inferred automatically!")
|
26 |
+
if x.ndim == 3:
|
27 |
+
x = x[None]
|
28 |
+
if x.ndim == 2:
|
29 |
+
x = x[None, None]
|
30 |
+
assert x.shape[1] == 3 or x.shape[1] == 1
|
31 |
+
x = x.detach().cpu().float()
|
32 |
+
if x.dtype == torch.uint8:
|
33 |
+
return x.numpy().astype(np.uint8)
|
34 |
+
elif vmin is not None and vmax is not None:
|
35 |
+
x = (255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin))
|
36 |
+
x = x.permute(0, 2, 3, 1).to(torch.uint8)
|
37 |
+
return x.numpy()
|
38 |
+
else:
|
39 |
+
raise NotImplementedError()
|
40 |
+
|
41 |
+
|
42 |
+
class IImage:
|
43 |
+
'''
|
44 |
+
Generic media storage. Can store both images and videos.
|
45 |
+
Stores data as a numpy array by default.
|
46 |
+
Can be viewed in a jupyter notebook.
|
47 |
+
'''
|
48 |
+
@staticmethod
|
49 |
+
def open(path):
|
50 |
+
|
51 |
+
iio_obj = iio.imopen(path, 'r')
|
52 |
+
data = iio_obj.read()
|
53 |
+
try:
|
54 |
+
# .properties() does not work for images but for gif files
|
55 |
+
if not iio_obj.properties().is_batch:
|
56 |
+
data = data[None]
|
57 |
+
except AttributeError as e:
|
58 |
+
# this one works for gif files
|
59 |
+
if not "duration" in iio_obj.metadata():
|
60 |
+
data = data[None]
|
61 |
+
if data.ndim == 3:
|
62 |
+
data = data[..., None]
|
63 |
+
image = IImage(data)
|
64 |
+
image.link = os.path.abspath(path)
|
65 |
+
return image
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def flow_field(flow):
|
69 |
+
flow_images = flow_to_image(flow)
|
70 |
+
return IImage(flow_images, vmin=0, vmax=255)
|
71 |
+
|
72 |
+
@staticmethod
|
73 |
+
def normalized(x, dims=[-1, -2]):
|
74 |
+
x = (x - x.amin(dims, True)) / \
|
75 |
+
(x.amax(dims, True) - x.amin(dims, True))
|
76 |
+
return IImage(x, 0)
|
77 |
+
|
78 |
+
def numpy(self): return self.data
|
79 |
+
|
80 |
+
def torch(self, vmin=-1, vmax=1):
|
81 |
+
if self.data.ndim == 3:
|
82 |
+
data = self.data.transpose(2, 0, 1) / 255.
|
83 |
+
else:
|
84 |
+
data = self.data.transpose(0, 3, 1, 2) / 255.
|
85 |
+
return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)
|
86 |
+
|
87 |
+
def cuda(self):
|
88 |
+
self.device = 'cuda'
|
89 |
+
return self
|
90 |
+
|
91 |
+
def cpu(self):
|
92 |
+
self.device = 'cpu'
|
93 |
+
return self
|
94 |
+
|
95 |
+
def pil(self):
|
96 |
+
ans = []
|
97 |
+
for x in self.data:
|
98 |
+
if x.shape[-1] == 1:
|
99 |
+
x = x[..., 0]
|
100 |
+
|
101 |
+
ans.append(PIL.Image.fromarray(x))
|
102 |
+
if len(ans) == 1:
|
103 |
+
return ans[0]
|
104 |
+
return ans
|
105 |
+
|
106 |
+
def is_iimage(self):
|
107 |
+
return True
|
108 |
+
|
109 |
+
@property
|
110 |
+
def shape(self): return self.data.shape
|
111 |
+
@property
|
112 |
+
def size(self): return (self.data.shape[-2], self.data.shape[-3])
|
113 |
+
|
114 |
+
def setFps(self, fps):
|
115 |
+
self.fps = fps
|
116 |
+
self.generate_display()
|
117 |
+
return self
|
118 |
+
|
119 |
+
def __init__(self, x, vmin=-1, vmax=1, fps=None):
|
120 |
+
|
121 |
+
if isinstance(x, PIL.Image.Image):
|
122 |
+
self.data = np.array(x)
|
123 |
+
if self.data.ndim == 2:
|
124 |
+
self.data = self.data[..., None] # (H,W,C)
|
125 |
+
self.data = self.data[None] # (B,H,W,C)
|
126 |
+
elif isinstance(x, IImage):
|
127 |
+
self.data = x.data.copy() # Simple Copy
|
128 |
+
elif isinstance(x, np.ndarray):
|
129 |
+
self.data = x.copy().astype(np.uint8)
|
130 |
+
if self.data.ndim == 2:
|
131 |
+
self.data = self.data[None, ..., None]
|
132 |
+
if self.data.ndim == 3:
|
133 |
+
warnings.warn(
|
134 |
+
"Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)")
|
135 |
+
self.data = self.data[None]
|
136 |
+
elif isinstance(x, torch.Tensor):
|
137 |
+
assert x.min() >= vmin and x.max(
|
138 |
+
) <= vmax, f"input data was [{x.min()},{x.max()}], but expected [{vmin},{vmax}]"
|
139 |
+
self.data = torch2np(x, vmin, vmax)
|
140 |
+
self.display_str = None
|
141 |
+
self.device = 'cpu'
|
142 |
+
self.fps = fps if fps is not None else (
|
143 |
+
1 if len(self.data) < 10 else 30)
|
144 |
+
self.link = None
|
145 |
+
|
146 |
+
def generate_display(self):
|
147 |
+
if config.IMG_THUMBSIZE is not None:
|
148 |
+
if self.size[1] < self.size[0]:
|
149 |
+
thumb = self.resize(
|
150 |
+
(self.size[1]*config.IMG_THUMBSIZE//self.size[0], config.IMG_THUMBSIZE))
|
151 |
+
else:
|
152 |
+
thumb = self.resize(
|
153 |
+
(config.IMG_THUMBSIZE, self.size[0]*config.IMG_THUMBSIZE//self.size[1]))
|
154 |
+
else:
|
155 |
+
thumb = self
|
156 |
+
if self.is_video():
|
157 |
+
self.anim = Animation(thumb.data, fps=self.fps)
|
158 |
+
self.anim.render()
|
159 |
+
self.display_str = self.anim.anim_str
|
160 |
+
else:
|
161 |
+
b = io.BytesIO()
|
162 |
+
data = thumb.data[0]
|
163 |
+
if data.shape[-1] == 1:
|
164 |
+
data = data[..., 0]
|
165 |
+
PIL.Image.fromarray(data).save(b, "PNG")
|
166 |
+
self.display_str = b.getvalue()
|
167 |
+
return self.display_str
|
168 |
+
|
169 |
+
def resize(self, size, *args, **kwargs):
|
170 |
+
if size is None:
|
171 |
+
return self
|
172 |
+
use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False)
|
173 |
+
|
174 |
+
# Backward compatibility
|
175 |
+
resample = kwargs.pop('filter', PIL.Image.BICUBIC)
|
176 |
+
resample = kwargs.pop('resample', resample)
|
177 |
+
|
178 |
+
if isinstance(size, int):
|
179 |
+
if use_small_edge_when_int:
|
180 |
+
h, w = self.data.shape[1:3]
|
181 |
+
aspect_ratio = h / w
|
182 |
+
size = (max(size, int(size * aspect_ratio)),
|
183 |
+
max(size, int(size / aspect_ratio)))
|
184 |
+
else:
|
185 |
+
h, w = self.data.shape[1:3]
|
186 |
+
aspect_ratio = h / w
|
187 |
+
size = (min(size, int(size * aspect_ratio)),
|
188 |
+
min(size, int(size / aspect_ratio)))
|
189 |
+
|
190 |
+
if self.size == size[::-1]:
|
191 |
+
return self
|
192 |
+
return libimage.stack([IImage(x.pil().resize(size[::-1], *args, resample=resample, **kwargs)) for x in self])
|
193 |
+
# return IImage(TF.resize(self.cpu().torch(0), size, *args, **kwargs), 0)
|
194 |
+
|
195 |
+
def pad(self, padding, *args, **kwargs):
|
196 |
+
return IImage(TF.pad(self.torch(0), padding=padding, *args, **kwargs), 0)
|
197 |
+
|
198 |
+
def padx(self, multiplier, *args, **kwargs):
|
199 |
+
size = np.array(self.size)
|
200 |
+
padding = np.concatenate(
|
201 |
+
[[0, 0], np.ceil(size / multiplier).astype(int) * multiplier - size])
|
202 |
+
return self.pad(list(padding), *args, **kwargs)
|
203 |
+
|
204 |
+
def pad2wh(self, w=0, h=0, **kwargs):
|
205 |
+
cw, ch = self.size
|
206 |
+
return self.pad([0, 0, max(0, w - cw), max(0, h-ch)], **kwargs)
|
207 |
+
|
208 |
+
def pad2square(self, *args, **kwargs):
|
209 |
+
if self.size[0] > self.size[1]:
|
210 |
+
dx = self.size[0] - self.size[1]
|
211 |
+
return self.pad([0, dx//2, 0, dx-dx//2], *args, **kwargs)
|
212 |
+
elif self.size[0] < self.size[1]:
|
213 |
+
dx = self.size[1] - self.size[0]
|
214 |
+
return self.pad([dx//2, 0, dx-dx//2, 0], *args, **kwargs)
|
215 |
+
return self
|
216 |
+
|
217 |
+
def crop2square(self, *args, **kwargs):
|
218 |
+
if self.size[0] > self.size[1]:
|
219 |
+
dx = self.size[0] - self.size[1]
|
220 |
+
return self.crop([dx//2, 0, self.size[1], self.size[1]], *args, **kwargs)
|
221 |
+
elif self.size[0] < self.size[1]:
|
222 |
+
dx = self.size[1] - self.size[0]
|
223 |
+
return self.crop([0, dx//2, self.size[0], self.size[0]], *args, **kwargs)
|
224 |
+
return self
|
225 |
+
|
226 |
+
def alpha(self):
|
227 |
+
return IImage(self.data[..., -1, None], fps=self.fps)
|
228 |
+
|
229 |
+
def rgb(self):
|
230 |
+
return IImage(self.pil().convert('RGB'), fps=self.fps)
|
231 |
+
|
232 |
+
def png(self):
|
233 |
+
return IImage(np.concatenate([self.data, 255 * np.ones_like(self.data)[..., :1]], -1))
|
234 |
+
|
235 |
+
def grid(self, nrows=None, ncols=None):
|
236 |
+
if nrows is not None:
|
237 |
+
ncols = math.ceil(self.data.shape[0] / nrows)
|
238 |
+
elif ncols is not None:
|
239 |
+
nrows = math.ceil(self.data.shape[0] / ncols)
|
240 |
+
else:
|
241 |
+
warnings.warn(
|
242 |
+
"No dimensions specified, creating a grid with 5 columns (default)")
|
243 |
+
ncols = 5
|
244 |
+
nrows = math.ceil(self.data.shape[0] / ncols)
|
245 |
+
|
246 |
+
pad = nrows * ncols - self.data.shape[0]
|
247 |
+
data = np.pad(self.data, ((0, pad), (0, 0), (0, 0), (0, 0)))
|
248 |
+
rows = [np.concatenate(x, 1, dtype=np.uint8)
|
249 |
+
for x in np.array_split(data, nrows)]
|
250 |
+
return IImage(np.concatenate(rows, 0, dtype=np.uint8)[None])
|
251 |
+
|
252 |
+
def hstack(self):
|
253 |
+
return IImage(np.concatenate(self.data, 1, dtype=np.uint8)[None])
|
254 |
+
|
255 |
+
def vstack(self):
|
256 |
+
return IImage(np.concatenate(self.data, 0, dtype=np.uint8)[None])
|
257 |
+
|
258 |
+
def vsplit(self, number_of_splits):
|
259 |
+
return IImage(np.concatenate(np.split(self.data, number_of_splits, 1)))
|
260 |
+
|
261 |
+
def hsplit(self, number_of_splits):
|
262 |
+
return IImage(np.concatenate(np.split(self.data, number_of_splits, 2)))
|
263 |
+
|
264 |
+
def heatmap(self, resize=None, cmap=cv2.COLORMAP_JET):
|
265 |
+
data = np.stack([cv2.cvtColor(cv2.applyColorMap(
|
266 |
+
x, cmap), cv2.COLOR_BGR2RGB) for x in self.data])
|
267 |
+
return IImage(data).resize(resize, use_small_edge_when_int=True)
|
268 |
+
|
269 |
+
def display(self):
|
270 |
+
try:
|
271 |
+
display(self)
|
272 |
+
except:
|
273 |
+
print("No display")
|
274 |
+
return self
|
275 |
+
|
276 |
+
def dilate(self, iterations=1, *args, **kwargs):
|
277 |
+
if iterations == 0:
|
278 |
+
return IImage(self.data)
|
279 |
+
return IImage((binary_dilation(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8))
|
280 |
+
|
281 |
+
def erode(self, iterations=1, *args, **kwargs):
|
282 |
+
return IImage((binary_erosion(self.data, iterations=iterations, *args, *kwargs)*255.).astype(np.uint8))
|
283 |
+
|
284 |
+
def hull(self):
|
285 |
+
convex_hulls = []
|
286 |
+
for frame in self.data:
|
287 |
+
contours, hierarchy = cv2.findContours(
|
288 |
+
frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
289 |
+
contours = [x.astype(np.int32) for x in contours]
|
290 |
+
mask_contours = [cv2.convexHull(np.concatenate(contours))]
|
291 |
+
canvas = np.zeros(self.data[0].shape, np.uint8)
|
292 |
+
convex_hull = cv2.drawContours(
|
293 |
+
canvas, mask_contours, -1, (255, 0, 0), -1)
|
294 |
+
convex_hulls.append(convex_hull)
|
295 |
+
return IImage(np.array(convex_hulls))
|
296 |
+
|
297 |
+
def is_video(self):
|
298 |
+
return self.data.shape[0] > 1
|
299 |
+
|
300 |
+
def __getitem__(self, idx):
|
301 |
+
return IImage(self.data[None, idx], fps=self.fps)
|
302 |
+
# if self.is_video(): return IImage(self.data[idx], fps = self.fps)
|
303 |
+
# return self
|
304 |
+
|
305 |
+
def _repr_png_(self):
|
306 |
+
if self.is_video():
|
307 |
+
return None
|
308 |
+
if self.display_str is None:
|
309 |
+
self.generate_display()
|
310 |
+
return self.display_str
|
311 |
+
|
312 |
+
def _repr_html_(self):
|
313 |
+
if not self.is_video():
|
314 |
+
return None
|
315 |
+
if self.display_str is None:
|
316 |
+
self.generate_display()
|
317 |
+
return self.display_str
|
318 |
+
|
319 |
+
def save(self, path):
|
320 |
+
_, ext = os.path.splitext(path)
|
321 |
+
if self.is_video():
|
322 |
+
# if ext in ['.jpg', '.png']:
|
323 |
+
if self.display_str is None:
|
324 |
+
self.generate_display()
|
325 |
+
if ext == ".apng":
|
326 |
+
self.anim.anim_obj.save(path, writer="pillow")
|
327 |
+
else:
|
328 |
+
self.anim.anim_obj.save(path)
|
329 |
+
else:
|
330 |
+
data = self.data if self.data.ndim == 3 else self.data[0]
|
331 |
+
if data.shape[-1] == 1:
|
332 |
+
data = data[:, :, 0]
|
333 |
+
PIL.Image.fromarray(data).save(path)
|
334 |
+
return self
|
335 |
+
|
336 |
+
def to_html(self, width='auto', root_path='/'):
|
337 |
+
if self.display_str is None:
|
338 |
+
self.generate_display()
|
339 |
+
# print (self.display_str)
|
340 |
+
html_tag = bytes2html(self.display_str, width=width)
|
341 |
+
if self.link is not None:
|
342 |
+
link = os.path.relpath(self.link, root_path)
|
343 |
+
return f'<a href="{link}" >{html_tag}</a>'
|
344 |
+
return html_tag
|
345 |
+
|
346 |
+
def write(self, text, center=(0, 25), font_scale=0.8, color=(255, 255, 255), thickness=2):
|
347 |
+
if not isinstance(text, list):
|
348 |
+
text = [text for _ in self.data]
|
349 |
+
data = np.stack([cv2.putText(x.copy(), t, center, cv2.FONT_HERSHEY_COMPLEX,
|
350 |
+
font_scale, color, thickness) for x, t in zip(self.data, text)])
|
351 |
+
return IImage(data)
|
352 |
+
|
353 |
+
def append_text(self, text, padding, font_scale=0.8, color=(255, 255, 255), thickness=2, scale_factor=0.9, center=(0, 0), fill=0):
|
354 |
+
|
355 |
+
assert np.count_nonzero(padding) == 1
|
356 |
+
axis_padding = np.nonzero(padding)[0][0]
|
357 |
+
scale_padding = padding[axis_padding]
|
358 |
+
|
359 |
+
y_0 = 0
|
360 |
+
x_0 = 0
|
361 |
+
if axis_padding == 0:
|
362 |
+
width = scale_padding
|
363 |
+
y_max = self.shape[1]
|
364 |
+
elif axis_padding == 1:
|
365 |
+
width = self.shape[2]
|
366 |
+
y_max = scale_padding
|
367 |
+
elif axis_padding == 2:
|
368 |
+
x_0 = self.shape[2]
|
369 |
+
width = scale_padding
|
370 |
+
y_max = self.shape[1]
|
371 |
+
elif axis_padding == 3:
|
372 |
+
width = self.shape[2]
|
373 |
+
y_0 = self.shape[1]
|
374 |
+
y_max = self.shape[1]+scale_padding
|
375 |
+
|
376 |
+
width -= center[0]
|
377 |
+
x_0 += center[0]
|
378 |
+
y_0 += center[1]
|
379 |
+
|
380 |
+
self = self.pad(padding, fill=fill)
|
381 |
+
|
382 |
+
def wrap_text(text, width, _font_scale):
|
383 |
+
allowed_separator = ' |-|_|/|\n'
|
384 |
+
words = re.split(allowed_separator, text)
|
385 |
+
# words = text.split()
|
386 |
+
lines = []
|
387 |
+
current_line = words[0]
|
388 |
+
sep_list = []
|
389 |
+
start_idx = 0
|
390 |
+
for start_word in words[:-1]:
|
391 |
+
pos = text.find(start_word, start_idx)
|
392 |
+
pos += len(start_word)
|
393 |
+
sep_list.append(text[pos])
|
394 |
+
start_idx = pos+1
|
395 |
+
|
396 |
+
for word, separator in zip(words[1:], sep_list):
|
397 |
+
if cv2.getTextSize(current_line + separator + word, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width:
|
398 |
+
current_line += separator + word
|
399 |
+
else:
|
400 |
+
if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width:
|
401 |
+
lines.append(current_line)
|
402 |
+
current_line = word
|
403 |
+
else:
|
404 |
+
return []
|
405 |
+
|
406 |
+
if cv2.getTextSize(current_line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0] <= width:
|
407 |
+
lines.append(current_line)
|
408 |
+
else:
|
409 |
+
return []
|
410 |
+
return lines
|
411 |
+
|
412 |
+
def wrap_text_and_scale(text, width, _font_scale, y_0, y_max):
|
413 |
+
height = y_max+1
|
414 |
+
while height > y_max:
|
415 |
+
text_lines = wrap_text(text, width, _font_scale)
|
416 |
+
if len(text) > 0 and len(text_lines) == 0:
|
417 |
+
|
418 |
+
height = y_max+1
|
419 |
+
else:
|
420 |
+
line_height = cv2.getTextSize(
|
421 |
+
text_lines[0], cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][1]
|
422 |
+
height = line_height * len(text_lines) + y_0
|
423 |
+
|
424 |
+
# scale font if out of frame
|
425 |
+
if height > y_max:
|
426 |
+
_font_scale = _font_scale * scale_factor
|
427 |
+
|
428 |
+
return text_lines, line_height, _font_scale
|
429 |
+
|
430 |
+
result = []
|
431 |
+
if not isinstance(text, list):
|
432 |
+
text = [text for _ in self.data]
|
433 |
+
else:
|
434 |
+
assert len(text) == len(self.data)
|
435 |
+
|
436 |
+
for x, t in zip(self.data, text):
|
437 |
+
x = x.copy()
|
438 |
+
text_lines, line_height, _font_scale = wrap_text_and_scale(
|
439 |
+
t, width, font_scale, y_0, y_max)
|
440 |
+
y = line_height
|
441 |
+
for line in text_lines:
|
442 |
+
x = cv2.putText(
|
443 |
+
x, line, (x_0, y_0+y), cv2.FONT_HERSHEY_COMPLEX, _font_scale, color, thickness)
|
444 |
+
y += line_height
|
445 |
+
result.append(x)
|
446 |
+
data = np.stack(result)
|
447 |
+
|
448 |
+
return IImage(data)
|
449 |
+
|
450 |
+
# ========== OPERATORS =============
|
451 |
+
|
452 |
+
def __or__(self, other):
|
453 |
+
# TODO: fix for variable sizes
|
454 |
+
return IImage(np.concatenate([self.data, other.data], 2))
|
455 |
+
|
456 |
+
def __truediv__(self, other):
|
457 |
+
# TODO: fix for variable sizes
|
458 |
+
return IImage(np.concatenate([self.data, other.data], 1))
|
459 |
+
|
460 |
+
def __and__(self, other):
|
461 |
+
return IImage(np.concatenate([self.data, other.data], 0))
|
462 |
+
|
463 |
+
def __add__(self, other):
|
464 |
+
return IImage(0.5 * self.data + 0.5 * other.data)
|
465 |
+
|
466 |
+
def __mul__(self, other):
|
467 |
+
if isinstance(other, IImage):
|
468 |
+
return IImage(self.data / 255. * other.data)
|
469 |
+
return IImage(self.data * other / 255.)
|
470 |
+
|
471 |
+
def __xor__(self, other):
|
472 |
+
return IImage(0.5 * self.data + 0.5 * other.data + 0.5 * self.data * (other.data.sum(-1, keepdims=True) == 0))
|
473 |
+
|
474 |
+
def __invert__(self):
|
475 |
+
return IImage(255 - self.data)
|
476 |
+
__rmul__ = __mul__
|
477 |
+
|
478 |
+
def bbox(self):
|
479 |
+
return [cv2.boundingRect(x) for x in self.data]
|
480 |
+
|
481 |
+
def fill_bbox(self, bbox_list, fill=255):
|
482 |
+
data = self.data.copy()
|
483 |
+
for bbox in bbox_list:
|
484 |
+
x, y, w, h = bbox
|
485 |
+
data[:, y:y+h, x:x+w, :] = fill
|
486 |
+
return IImage(data)
|
487 |
+
|
488 |
+
def crop(self, bbox):
|
489 |
+
assert len(bbox) in [2, 4]
|
490 |
+
if len(bbox) == 2:
|
491 |
+
x, y = 0, 0
|
492 |
+
w, h = bbox
|
493 |
+
elif len(bbox) == 4:
|
494 |
+
x, y, w, h = bbox
|
495 |
+
return IImage(self.data[:, y:y+h, x:x+w, :])
|
496 |
+
|
497 |
+
# def alpha(self):
|
498 |
+
# return BetterImage(self.img.split()[-1])
|
499 |
+
# def resize(self, size, *args, **kwargs):
|
500 |
+
# if size is None: return self
|
501 |
+
# return BetterImage(TF.resize(self.img, size, *args, **kwargs))
|
502 |
+
# def pad(self, *args):
|
503 |
+
# return BetterImage(TF.pad(self.img, *args))
|
504 |
+
# def padx(self, mult):
|
505 |
+
# size = np.array(self.img.size)
|
506 |
+
# padding = np.concatenate([[0,0],np.ceil(size / mult).astype(int) * mult - size])
|
507 |
+
# return self.pad(list(padding))
|
508 |
+
# def crop(self, *args):
|
509 |
+
# return BetterImage(self.img.crop(*args))
|
510 |
+
# def torch(self, min = -1., max = 1.):
|
511 |
+
# return (max - min) * TF.to_tensor(self.img)[None] + min
|
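IImage keeps its data as a uint8 (B, H, W, C) numpy array and converts on demand; a short round-trip sketch with random data and hypothetical sizes, assuming torch, torchvision, imageio, opencv and scipy are installed.

import numpy as np
from lib.farancia import IImage

img = IImage(np.random.randint(0, 255, (1, 120, 160, 3), dtype=np.uint8))
print(img.shape, img.size)         # (1, 120, 160, 3) and (160, 120) as (W, H)

small = img.resize(64)             # int size: fit inside 64x64, aspect ratio preserved
t = small.torch(vmin=-1, vmax=1)   # float tensor in [-1, 1], shape (1, 3, 48, 64)
back = IImage(t)                   # wraps the tensor again (values must lie in [-1, 1])
back.save("/tmp/iimage_demo.png")  # hypothetical output path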
lib/farancia/libimage/utils.py
ADDED
@@ -0,0 +1,8 @@
|
1 |
+
from IPython.display import Image as IpyImage
|
2 |
+
|
3 |
+
def bytes2html(data, width='auto'):
|
4 |
+
img_obj = IpyImage(data=data, format='JPG')
|
5 |
+
for bundle in img_obj._repr_mimebundle_():
|
6 |
+
for mimetype, b64value in bundle.items():
|
7 |
+
if mimetype.startswith('image/'):
|
8 |
+
return f'<img src="data:{mimetype};base64,{b64value}" style="width: {width}; max-width: 100%">'
|
models/cam/conditioning.py
ADDED
@@ -0,0 +1,150 @@
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from einops import rearrange
|
4 |
+
from diffusers.models.attention_processor import Attention
|
5 |
+
|
6 |
+
|
7 |
+
class CrossAttention(nn.Module):
|
8 |
+
"""
|
9 |
+
CrossAttention module implements per-pixel temporal attention to fuse the conditional attention module with the base module.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
input_channels (int): Number of input channels.
|
13 |
+
attention_head_dim (int): Dimension of attention head.
|
14 |
+
norm_num_groups (int): Number of groups for GroupNorm normalization (default is 32).
|
15 |
+
|
16 |
+
Attributes:
|
17 |
+
attention (Attention): Attention module for computing attention scores.
|
18 |
+
norm (torch.nn.GroupNorm): Group normalization layer.
|
19 |
+
proj_in (nn.Linear): Linear layer for projecting input data.
|
20 |
+
proj_out (nn.Linear): Linear layer for projecting output data.
|
21 |
+
dropout (nn.Dropout): Dropout layer for regularization.
|
22 |
+
|
23 |
+
Methods:
|
24 |
+
forward(hidden_state, encoder_hidden_states, num_frames, num_conditional_frames):
|
25 |
+
Forward pass of the CrossAttention module.
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, input_channels, attention_head_dim, norm_num_groups=32):
|
30 |
+
super().__init__()
|
31 |
+
self.attention = Attention(
|
32 |
+
query_dim=input_channels, cross_attention_dim=input_channels, heads=input_channels//attention_head_dim, dim_head=attention_head_dim, bias=False, upcast_attention=False)
|
33 |
+
self.norm = torch.nn.GroupNorm(
|
34 |
+
num_groups=norm_num_groups, num_channels=input_channels, eps=1e-6, affine=True)
|
35 |
+
self.proj_in = nn.Linear(input_channels, input_channels)
|
36 |
+
self.proj_out = nn.Linear(input_channels, input_channels)
|
37 |
+
self.dropout = nn.Dropout(p=0.25)
|
38 |
+
|
39 |
+
def forward(self, hidden_state, encoder_hidden_states, num_frames, num_conditional_frames):
|
40 |
+
"""
|
41 |
+
The input hidden state is normalized, then projected using a linear layer.
|
42 |
+
Multi-head cross attention is computed between the hidden state (latent of noisy video) and encoder hidden states (CLIP image encoder).
|
43 |
+
The output is projected using a linear layer.
|
44 |
+
We apply dropout to the newly generated frames (without the control frames).
|
45 |
+
|
46 |
+
Args:
|
47 |
+
hidden_state (torch.Tensor): Input hidden state tensor.
|
48 |
+
encoder_hidden_states (torch.Tensor): Encoder hidden states tensor.
|
49 |
+
num_frames (int): Number of frames.
|
50 |
+
num_conditional_frames (int): Number of conditional frames.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
output (torch.Tensor): Output tensor after processing with attention mechanism.
|
54 |
+
|
55 |
+
"""
|
56 |
+
h, w = hidden_state.shape[2], hidden_state.shape[3]
|
57 |
+
hidden_state_norm = rearrange(
|
58 |
+
hidden_state, "(B F) C H W -> B C F H W", F=num_frames)
|
59 |
+
hidden_state_norm = self.norm(hidden_state_norm)
|
60 |
+
hidden_state_norm = rearrange(
|
61 |
+
hidden_state_norm, "B C F H W -> (B H W) F C")
|
62 |
+
|
63 |
+
hidden_state_norm = self.proj_in(hidden_state_norm)
|
64 |
+
|
65 |
+
attn = self.attention(hidden_state_norm,
|
66 |
+
encoder_hidden_states=encoder_hidden_states,
|
67 |
+
attention_mask=None,
|
68 |
+
)
|
69 |
+
# proj_out
|
70 |
+
|
71 |
+
residual = self.proj_out(attn) # (B H W) F C
|
72 |
+
hidden_state = rearrange(
|
73 |
+
hidden_state, "(B F) ... -> B F ...", F=num_frames)
|
74 |
+
hidden_state = torch.cat([hidden_state[:, :num_conditional_frames], self.dropout(
|
75 |
+
hidden_state[:, num_conditional_frames:])], dim=1)
|
76 |
+
hidden_state = rearrange(hidden_state, "B F ... -> (B F) ... ")
|
77 |
+
|
78 |
+
residual = rearrange(
|
79 |
+
residual, "(B H W) F C -> (B F) C H W", H=h, W=w)
|
80 |
+
output = hidden_state + residual
|
81 |
+
return output
|
82 |
+
|
83 |
+
|
84 |
+
class ConditionalModel(nn.Module):
|
85 |
+
"""
|
86 |
+
ConditionalModel module performs the fusion of the conditional attention module with the base model.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
input_channels (int): Number of input channels.
|
90 |
+
conditional_model (str): Type of conditional model to use. Currently only "cross_attention" is implemented.
|
91 |
+
attention_head_dim (int): Dimension of attention head (default is 64).
|
92 |
+
|
93 |
+
Attributes:
|
94 |
+
temporal_transformer (CrossAttention): CrossAttention module for temporal transformation.
|
95 |
+
conditional_model (str): Type of conditional model used.
|
96 |
+
|
97 |
+
Methods:
|
98 |
+
forward(sample, conditioning, num_frames=None, num_conditional_frames=None):
|
99 |
+
Forward pass of the ConditionalModel module.
|
100 |
+
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, input_channels, conditional_model: str, attention_head_dim=64):
|
104 |
+
super().__init__()
|
105 |
+
|
106 |
+
if conditional_model == "cross_attention":
|
107 |
+
self.temporal_transformer = CrossAttention(
|
108 |
+
input_channels=input_channels, attention_head_dim=attention_head_dim)
|
109 |
+
else:
|
110 |
+
raise NotImplementedError(
|
111 |
+
f"mode {conditional_model} not implemented")
|
112 |
+
|
113 |
+
nn.init.zeros_(self.temporal_transformer.proj_out.weight)
|
114 |
+
nn.init.zeros_(self.temporal_transformer.proj_out.bias)
|
115 |
+
self.conditional_model = conditional_model
|
116 |
+
|
117 |
+
def forward(self, sample, conditioning, num_frames=None, num_conditional_frames=None):
|
118 |
+
"""
|
119 |
+
Forward pass of the ConditionalModel module.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
sample (torch.Tensor): Input sample tensor.
|
123 |
+
conditioning (torch.Tensor): Conditioning tensor containing the encoding of the conditional frames.
|
124 |
+
num_frames (int): Number of frames in the sample.
|
125 |
+
num_conditional_frames (int): Number of conditional frames.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
sample (torch.Tensor): Transformed sample tensor.
|
129 |
+
|
130 |
+
"""
|
131 |
+
sample = rearrange(sample, "(B F) ... -> B F ...", F=num_frames)
|
132 |
+
batch_size = sample.shape[0]
|
133 |
+
conditioning = rearrange(
|
134 |
+
conditioning, "(B F) ... -> B F ...", B=batch_size)
|
135 |
+
|
136 |
+
assert conditioning.ndim == 5
|
137 |
+
assert sample.ndim == 5
|
138 |
+
|
139 |
+
conditioning = rearrange(conditioning, "B F C H W -> (B H W) F C")
|
140 |
+
|
141 |
+
sample = rearrange(sample, "B F C H W -> (B F) C H W")
|
142 |
+
|
143 |
+
sample = self.temporal_transformer(
|
144 |
+
sample, encoder_hidden_states=conditioning, num_frames=num_frames, num_conditional_frames=num_conditional_frames)
|
145 |
+
|
146 |
+
return sample
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == "__main__":
|
150 |
+
model = CrossAttention(input_channels=320, attention_head_dim=32)
|
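Extending the __main__ stub above into a shape-level sketch of the fusion (tiny hypothetical sizes; requires diffusers and einops): the sample holds all F frames of the noisy latent, the conditioning holds the encodings of the first F_cond frames, and because proj_out is zero-initialized the module starts out as an identity on the sample.

import torch
from models.cam.conditioning import ConditionalModel

B, F, F_cond, C, H, W = 1, 4, 1, 320, 8, 8        # hypothetical sizes
model = ConditionalModel(input_channels=C, conditional_model="cross_attention")
model.eval()                                       # disable the dropout on newly generated frames

sample = torch.randn(B * F, C, H, W)               # noisy video latent, frames flattened into batch
conditioning = torch.randn(B * F_cond, C, H, W)    # encoded conditional (control) frames

out = model(sample, conditioning, num_frames=F, num_conditional_frames=F_cond)
print(out.shape)                                   # torch.Size([4, 320, 8, 8])
assert torch.allclose(out, sample)                 # zero-initialized proj_out => zero residual at init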
models/control/controlnet.py
ADDED
@@ -0,0 +1,581 @@
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
from models.svd.sgm.util import default
|
5 |
+
from models.svd.sgm.modules.video_attention import SpatialVideoTransformer
|
6 |
+
from models.svd.sgm.modules.diffusionmodules.openaimodel import *
|
7 |
+
from models.diffusion.video_model import VideoResBlock, VideoUNet
|
8 |
+
from einops import repeat, rearrange
|
9 |
+
from models.svd.sgm.modules.diffusionmodules.wrappers import OpenAIWrapper
|
10 |
+
|
11 |
+
|
12 |
+
class Merger(nn.Module):
|
13 |
+
"""
|
14 |
+
Merges the controlnet latents with the conditioning embedding (encoding of control frames).
|
15 |
+
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, merge_mode: str = "addition", input_channels=0, frame_expansion="last_frame") -> None:
|
19 |
+
super().__init__()
|
20 |
+
self.merge_mode = merge_mode
|
21 |
+
self.frame_expansion = frame_expansion
|
22 |
+
|
23 |
+
def forward(self, x, condition_signal, num_video_frames, num_video_frames_conditional):
|
24 |
+
x = rearrange(x, "(B F) C H W -> B F C H W", F=num_video_frames)
|
25 |
+
|
26 |
+
condition_signal = rearrange(
|
27 |
+
condition_signal, "(B F) C H W -> B F C H W", B=x.shape[0])
|
28 |
+
|
29 |
+
if x.shape[1] - condition_signal.shape[1] > 0:
|
30 |
+
if self.frame_expansion == "last_frame":
|
31 |
+
fillup_latent = repeat(
|
32 |
+
condition_signal[:, -1], "B C H W -> B F C H W", F=x.shape[1] - condition_signal.shape[1])
|
33 |
+
elif self.frame_expansion == "zero":
|
34 |
+
fillup_latent = torch.zeros(
|
35 |
+
(x.shape[0], num_video_frames-num_video_frames_conditional, *x.shape[2:]), device=x.device, dtype=x.dtype)
|
36 |
+
|
37 |
+
if self.frame_expansion != "none":
|
38 |
+
condition_signal = torch.cat(
|
39 |
+
[condition_signal, fillup_latent], dim=1)
|
40 |
+
|
41 |
+
if self.merge_mode == "addition":
|
42 |
+
out = x + condition_signal
|
43 |
+
else:
|
44 |
+
raise NotImplementedError(
|
45 |
+
f"Merging mode {self.merge_mode} not implemented.")
|
46 |
+
|
47 |
+
out = rearrange(out, "B F C H W -> (B F) C H W")
|
48 |
+
return out
|
49 |
+
|
50 |
+
|
51 |
+
class ControlNetConditioningEmbedding(nn.Module):
|
52 |
+
"""
|
53 |
+
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
54 |
+
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
55 |
+
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
56 |
+
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
57 |
+
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
58 |
+
model) to encode image-space conditions ... into feature maps ..."
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
conditioning_embedding_channels: int,
|
64 |
+
conditioning_channels: int = 3,
|
65 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
66 |
+
downsample: bool = True,
|
67 |
+
final_3d_conv: bool = False,
|
68 |
+
zero_init: bool = True,
|
69 |
+
use_controlnet_mask: bool = False,
|
70 |
+
use_normalization: bool = False,
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
|
74 |
+
self.final_3d_conv = final_3d_conv
|
75 |
+
self.conv_in = nn.Conv2d(
|
76 |
+
conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
77 |
+
if final_3d_conv:
|
78 |
+
print("USING 3D CONV in ControlNET")
|
79 |
+
|
80 |
+
self.blocks = nn.ModuleList([])
|
81 |
+
if use_normalization:
|
82 |
+
self.norms = nn.ModuleList([])
|
83 |
+
self.use_normalization = use_normalization
|
84 |
+
|
85 |
+
stride = 2 if downsample else 1
|
86 |
+
|
87 |
+
for i in range(len(block_out_channels) - 1):
|
88 |
+
channel_in = block_out_channels[i]
|
89 |
+
channel_out = block_out_channels[i + 1]
|
90 |
+
self.blocks.append(
|
91 |
+
nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
92 |
+
if use_normalization:
|
93 |
+
self.norms.append(nn.LayerNorm((channel_in)))
|
94 |
+
self.blocks.append(
|
95 |
+
nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=stride))
|
96 |
+
if use_normalization:
|
97 |
+
self.norms.append(nn.LayerNorm((channel_out)))
|
98 |
+
|
99 |
+
self.conv_out = zero_module(
|
100 |
+
nn.Conv2d(
|
101 |
+
block_out_channels[-1]+int(use_controlnet_mask), conditioning_embedding_channels, kernel_size=3, padding=1), reset=zero_init
|
102 |
+
)
|
103 |
+
|
104 |
+
def forward(self, conditioning):
|
105 |
+
embedding = self.conv_in(conditioning)
|
106 |
+
embedding = F.silu(embedding)
|
107 |
+
|
108 |
+
if self.use_normalization:
|
109 |
+
for block, norm in zip(self.blocks, self.norms):
|
110 |
+
embedding = block(embedding)
|
111 |
+
embedding = rearrange(embedding, " ... C W H -> ... W H C")
|
112 |
+
embedding = norm(embedding)
|
113 |
+
embedding = rearrange(embedding, "... W H C -> ... C W H")
|
114 |
+
embedding = F.silu(embedding)
|
115 |
+
else:
|
116 |
+
for block in self.blocks:
|
117 |
+
embedding = block(embedding)
|
118 |
+
embedding = F.silu(embedding)
|
119 |
+
|
120 |
+
embedding = self.conv_out(embedding)
|
121 |
+
return embedding
|
122 |
+
|
123 |
+
|
124 |
+
class ControlNet(nn.Module):
|
125 |
+
|
126 |
+
def __init__(
|
127 |
+
self,
|
128 |
+
in_channels: int,
|
129 |
+
model_channels: int,
|
130 |
+
out_channels: int,
|
131 |
+
num_res_blocks: int,
|
132 |
+
attention_resolutions: Union[List[int], int],
|
133 |
+
dropout: float = 0.0,
|
134 |
+
channel_mult: List[int] = (1, 2, 4, 8),
|
135 |
+
conv_resample: bool = True,
|
136 |
+
dims: int = 2,
|
137 |
+
num_classes: Optional[Union[int, str]] = None,
|
138 |
+
use_checkpoint: bool = False,
|
139 |
+
num_heads: int = -1,
|
140 |
+
num_head_channels: int = -1,
|
141 |
+
num_heads_upsample: int = -1,
|
142 |
+
use_scale_shift_norm: bool = False,
|
143 |
+
resblock_updown: bool = False,
|
144 |
+
transformer_depth: Union[List[int], int] = 1,
|
145 |
+
transformer_depth_middle: Optional[int] = None,
|
146 |
+
context_dim: Optional[int] = None,
|
147 |
+
time_downup: bool = False,
|
148 |
+
time_context_dim: Optional[int] = None,
|
149 |
+
extra_ff_mix_layer: bool = False,
|
150 |
+
use_spatial_context: bool = False,
|
151 |
+
merge_strategy: str = "fixed",
|
152 |
+
merge_factor: float = 0.5,
|
153 |
+
spatial_transformer_attn_type: str = "softmax",
|
154 |
+
video_kernel_size: Union[int, List[int]] = 3,
|
155 |
+
use_linear_in_transformer: bool = False,
|
156 |
+
adm_in_channels: Optional[int] = None,
|
157 |
+
disable_temporal_crossattention: bool = False,
|
158 |
+
max_ddpm_temb_period: int = 10000,
|
159 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (
|
160 |
+
16, 32, 96, 256),
|
161 |
+
condition_encoder: str = "",
|
162 |
+
use_controlnet_mask: bool = False,
|
163 |
+
downsample_controlnet_cond: bool = True,
|
164 |
+
use_image_encoder_normalization: bool = False,
|
165 |
+
zero_conv_mode: str = "Identity",
|
166 |
+
frame_expansion: str = "none",
|
167 |
+
merging_mode: str = "addition",
|
168 |
+
):
|
169 |
+
super().__init__()
|
170 |
+
assert zero_conv_mode == "Identity", "Zero convolution not implemented"
|
171 |
+
|
172 |
+
assert context_dim is not None
|
173 |
+
|
174 |
+
if num_heads_upsample == -1:
|
175 |
+
num_heads_upsample = num_heads
|
176 |
+
|
177 |
+
if num_heads == -1:
|
178 |
+
assert num_head_channels != -1
|
179 |
+
|
180 |
+
if num_head_channels == -1:
|
181 |
+
assert num_heads != -1
|
182 |
+
|
183 |
+
self.in_channels = in_channels
|
184 |
+
self.model_channels = model_channels
|
185 |
+
self.out_channels = out_channels
|
186 |
+
if isinstance(transformer_depth, int):
|
187 |
+
transformer_depth = len(channel_mult) * [transformer_depth]
|
188 |
+
transformer_depth_middle = default(
|
189 |
+
transformer_depth_middle, transformer_depth[-1]
|
190 |
+
)
|
191 |
+
|
192 |
+
self.num_res_blocks = num_res_blocks
|
193 |
+
self.attention_resolutions = attention_resolutions
|
194 |
+
self.dropout = dropout
|
195 |
+
self.channel_mult = channel_mult
|
196 |
+
self.conv_resample = conv_resample
|
197 |
+
self.num_classes = num_classes
|
198 |
+
self.use_checkpoint = use_checkpoint
|
199 |
+
self.num_heads = num_heads
|
200 |
+
self.num_head_channels = num_head_channels
|
201 |
+
self.num_heads_upsample = num_heads_upsample
|
202 |
+
self.dims = dims
|
203 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
204 |
+
self.resblock_updown = resblock_updown
|
205 |
+
self.transformer_depth = transformer_depth
|
206 |
+
self.transformer_depth_middle = transformer_depth_middle
|
207 |
+
self.context_dim = context_dim
|
208 |
+
self.time_downup = time_downup
|
209 |
+
self.time_context_dim = time_context_dim
|
210 |
+
self.extra_ff_mix_layer = extra_ff_mix_layer
|
211 |
+
self.use_spatial_context = use_spatial_context
|
212 |
+
self.merge_strategy = merge_strategy
|
213 |
+
self.merge_factor = merge_factor
|
214 |
+
self.spatial_transformer_attn_type = spatial_transformer_attn_type
|
215 |
+
self.video_kernel_size = video_kernel_size
|
216 |
+
self.use_linear_in_transformer = use_linear_in_transformer
|
217 |
+
self.adm_in_channels = adm_in_channels
|
218 |
+
self.disable_temporal_crossattention = disable_temporal_crossattention
|
219 |
+
self.max_ddpm_temb_period = max_ddpm_temb_period
|
220 |
+
|
221 |
+
time_embed_dim = model_channels * 4
|
222 |
+
self.time_embed = nn.Sequential(
|
223 |
+
linear(model_channels, time_embed_dim),
|
224 |
+
nn.SiLU(),
|
225 |
+
linear(time_embed_dim, time_embed_dim),
|
226 |
+
)
|
227 |
+
|
228 |
+
if self.num_classes is not None:
|
229 |
+
if isinstance(self.num_classes, int):
|
230 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
231 |
+
elif self.num_classes == "continuous":
|
232 |
+
print("setting up linear c_adm embedding layer")
|
233 |
+
self.label_emb = nn.Linear(1, time_embed_dim)
|
234 |
+
elif self.num_classes == "timestep":
|
235 |
+
self.label_emb = nn.Sequential(
|
236 |
+
Timestep(model_channels),
|
237 |
+
nn.Sequential(
|
238 |
+
linear(model_channels, time_embed_dim),
|
239 |
+
nn.SiLU(),
|
240 |
+
linear(time_embed_dim, time_embed_dim),
|
241 |
+
),
|
242 |
+
)
|
243 |
+
|
244 |
+
elif self.num_classes == "sequential":
|
245 |
+
assert adm_in_channels is not None
|
246 |
+
self.label_emb = nn.Sequential(
|
247 |
+
nn.Sequential(
|
248 |
+
linear(adm_in_channels, time_embed_dim),
|
249 |
+
nn.SiLU(),
|
250 |
+
linear(time_embed_dim, time_embed_dim),
|
251 |
+
)
|
252 |
+
)
|
253 |
+
else:
|
254 |
+
raise ValueError()
|
255 |
+
|
256 |
+
self.input_blocks = nn.ModuleList(
|
257 |
+
[
|
258 |
+
TimestepEmbedSequential(
|
259 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
260 |
+
)
|
261 |
+
]
|
262 |
+
)
|
263 |
+
self._feature_size = model_channels
|
264 |
+
input_block_chans = [model_channels]
|
265 |
+
ch = model_channels
|
266 |
+
ds = 1
|
267 |
+
|
268 |
+
def get_attention_layer(
|
269 |
+
ch,
|
270 |
+
num_heads,
|
271 |
+
dim_head,
|
272 |
+
depth=1,
|
273 |
+
context_dim=None,
|
274 |
+
use_checkpoint=False,
|
275 |
+
disabled_sa=False,
|
276 |
+
):
|
277 |
+
return SpatialVideoTransformer(
|
278 |
+
ch,
|
279 |
+
num_heads,
|
280 |
+
dim_head,
|
281 |
+
depth=depth,
|
282 |
+
context_dim=context_dim,
|
283 |
+
time_context_dim=time_context_dim,
|
284 |
+
dropout=dropout,
|
285 |
+
ff_in=extra_ff_mix_layer,
|
286 |
+
use_spatial_context=use_spatial_context,
|
287 |
+
merge_strategy=merge_strategy,
|
288 |
+
merge_factor=merge_factor,
|
289 |
+
checkpoint=use_checkpoint,
|
290 |
+
use_linear=use_linear_in_transformer,
|
291 |
+
attn_mode=spatial_transformer_attn_type,
|
292 |
+
disable_self_attn=disabled_sa,
|
293 |
+
disable_temporal_crossattention=disable_temporal_crossattention,
|
294 |
+
max_time_embed_period=max_ddpm_temb_period,
|
295 |
+
)
|
296 |
+
|
297 |
+
def get_resblock(
|
298 |
+
merge_factor,
|
299 |
+
merge_strategy,
|
300 |
+
video_kernel_size,
|
301 |
+
ch,
|
302 |
+
time_embed_dim,
|
303 |
+
dropout,
|
304 |
+
out_ch,
|
305 |
+
dims,
|
306 |
+
use_checkpoint,
|
307 |
+
use_scale_shift_norm,
|
308 |
+
down=False,
|
309 |
+
up=False,
|
310 |
+
):
|
311 |
+
return VideoResBlock(
|
312 |
+
merge_factor=merge_factor,
|
313 |
+
merge_strategy=merge_strategy,
|
314 |
+
video_kernel_size=video_kernel_size,
|
315 |
+
channels=ch,
|
316 |
+
emb_channels=time_embed_dim,
|
317 |
+
dropout=dropout,
|
318 |
+
out_channels=out_ch,
|
319 |
+
dims=dims,
|
320 |
+
use_checkpoint=use_checkpoint,
|
321 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
322 |
+
down=down,
|
323 |
+
up=up,
|
324 |
+
)
|
325 |
+
|
326 |
+
for level, mult in enumerate(channel_mult):
|
327 |
+
for _ in range(num_res_blocks):
|
328 |
+
layers = [
|
329 |
+
get_resblock(
|
330 |
+
merge_factor=merge_factor,
|
331 |
+
merge_strategy=merge_strategy,
|
332 |
+
video_kernel_size=video_kernel_size,
|
333 |
+
ch=ch,
|
334 |
+
time_embed_dim=time_embed_dim,
|
335 |
+
dropout=dropout,
|
336 |
+
out_ch=mult * model_channels,
|
337 |
+
dims=dims,
|
338 |
+
use_checkpoint=use_checkpoint,
|
339 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
340 |
+
)
|
341 |
+
]
|
342 |
+
ch = mult * model_channels
|
343 |
+
if ds in attention_resolutions:
|
344 |
+
if num_head_channels == -1:
|
345 |
+
dim_head = ch // num_heads
|
346 |
+
else:
|
347 |
+
num_heads = ch // num_head_channels
|
348 |
+
dim_head = num_head_channels
|
349 |
+
|
350 |
+
layers.append(
|
351 |
+
get_attention_layer(
|
352 |
+
ch,
|
353 |
+
num_heads,
|
354 |
+
dim_head,
|
355 |
+
depth=transformer_depth[level],
|
356 |
+
context_dim=context_dim,
|
357 |
+
use_checkpoint=use_checkpoint,
|
358 |
+
disabled_sa=False,
|
359 |
+
)
|
360 |
+
)
|
361 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
362 |
+
self._feature_size += ch
|
363 |
+
input_block_chans.append(ch)
|
364 |
+
if level != len(channel_mult) - 1:
|
365 |
+
ds *= 2
|
366 |
+
out_ch = ch
|
367 |
+
self.input_blocks.append(
|
368 |
+
TimestepEmbedSequential(
|
369 |
+
get_resblock(
|
370 |
+
merge_factor=merge_factor,
|
371 |
+
merge_strategy=merge_strategy,
|
372 |
+
video_kernel_size=video_kernel_size,
|
373 |
+
ch=ch,
|
374 |
+
time_embed_dim=time_embed_dim,
|
375 |
+
dropout=dropout,
|
376 |
+
out_ch=out_ch,
|
377 |
+
dims=dims,
|
378 |
+
use_checkpoint=use_checkpoint,
|
379 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
380 |
+
down=True,
|
381 |
+
)
|
382 |
+
if resblock_updown
|
383 |
+
else Downsample(
|
384 |
+
ch,
|
385 |
+
conv_resample,
|
386 |
+
dims=dims,
|
387 |
+
out_channels=out_ch,
|
388 |
+
third_down=time_downup,
|
389 |
+
)
|
390 |
+
)
|
391 |
+
)
|
392 |
+
ch = out_ch
|
393 |
+
input_block_chans.append(ch)
|
394 |
+
|
395 |
+
self._feature_size += ch
|
396 |
+
|
397 |
+
if num_head_channels == -1:
|
398 |
+
dim_head = ch // num_heads
|
399 |
+
else:
|
400 |
+
num_heads = ch // num_head_channels
|
401 |
+
dim_head = num_head_channels
|
402 |
+
|
403 |
+
self.middle_block = TimestepEmbedSequential(
|
404 |
+
get_resblock(
|
405 |
+
merge_factor=merge_factor,
|
406 |
+
merge_strategy=merge_strategy,
|
407 |
+
video_kernel_size=video_kernel_size,
|
408 |
+
ch=ch,
|
409 |
+
time_embed_dim=time_embed_dim,
|
410 |
+
out_ch=None,
|
411 |
+
dropout=dropout,
|
412 |
+
dims=dims,
|
413 |
+
use_checkpoint=use_checkpoint,
|
414 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
415 |
+
),
|
416 |
+
get_attention_layer(
|
417 |
+
ch,
|
418 |
+
num_heads,
|
419 |
+
dim_head,
|
420 |
+
depth=transformer_depth_middle,
|
421 |
+
context_dim=context_dim,
|
422 |
+
use_checkpoint=use_checkpoint,
|
423 |
+
),
|
424 |
+
get_resblock(
|
425 |
+
merge_factor=merge_factor,
|
426 |
+
merge_strategy=merge_strategy,
|
427 |
+
video_kernel_size=video_kernel_size,
|
428 |
+
ch=ch,
|
429 |
+
out_ch=None,
|
430 |
+
time_embed_dim=time_embed_dim,
|
431 |
+
dropout=dropout,
|
432 |
+
dims=dims,
|
433 |
+
use_checkpoint=use_checkpoint,
|
434 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
435 |
+
),
|
436 |
+
)
|
437 |
+
self._feature_size += ch
|
438 |
+
|
439 |
+
self.merger = Merger(
|
440 |
+
merge_mode=merging_mode, input_channels=model_channels, frame_expansion=frame_expansion)
|
441 |
+
|
442 |
+
conditioning_channels = 3 if downsample_controlnet_cond else 4
|
443 |
+
block_out_channels = (320, 640, 1280, 1280)
|
444 |
+
|
445 |
+
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
446 |
+
conditioning_embedding_channels=block_out_channels[0],
|
447 |
+
conditioning_channels=conditioning_channels,
|
448 |
+
block_out_channels=conditioning_embedding_out_channels,
|
449 |
+
downsample=downsample_controlnet_cond,
|
450 |
+
final_3d_conv=condition_encoder.endswith("3DConv"),
|
451 |
+
use_controlnet_mask=use_controlnet_mask,
|
452 |
+
use_normalization=use_image_encoder_normalization,
|
453 |
+
)
|
454 |
+
|
455 |
+
def forward(
|
456 |
+
self,
|
457 |
+
x: th.Tensor,
|
458 |
+
timesteps: th.Tensor,
|
459 |
+
controlnet_cond: th.Tensor,
|
460 |
+
context: Optional[th.Tensor] = None,
|
461 |
+
y: Optional[th.Tensor] = None,
|
462 |
+
time_context: Optional[th.Tensor] = None,
|
463 |
+
num_video_frames: Optional[int] = None,
|
464 |
+
num_video_frames_conditional: Optional[int] = None,
|
465 |
+
image_only_indicator: Optional[th.Tensor] = None,
|
466 |
+
):
|
467 |
+
assert (y is not None) == (
|
468 |
+
self.num_classes is not None
|
469 |
+
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
470 |
+
hs = []
|
471 |
+
t_emb = timestep_embedding(
|
472 |
+
timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
473 |
+
|
474 |
+
emb = self.time_embed(t_emb)
|
475 |
+
|
476 |
+
# TODO restrict y to [:self.num_frames] (conditional frames)
|
477 |
+
|
478 |
+
if self.num_classes is not None:
|
479 |
+
assert y.shape[0] == x.shape[0]
|
480 |
+
emb = emb + self.label_emb(y)
|
481 |
+
|
482 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
483 |
+
|
484 |
+
h = x
|
485 |
+
for idx, module in enumerate(self.input_blocks):
|
486 |
+
h = module(
|
487 |
+
h,
|
488 |
+
emb,
|
489 |
+
context=context,
|
490 |
+
image_only_indicator=image_only_indicator,
|
491 |
+
time_context=time_context,
|
492 |
+
num_video_frames=num_video_frames,
|
493 |
+
)
|
494 |
+
if idx == 0:
|
495 |
+
h = self.merger(h, controlnet_cond, num_video_frames=num_video_frames,
|
496 |
+
num_video_frames_conditional=num_video_frames_conditional)
|
497 |
+
|
498 |
+
hs.append(h)
|
499 |
+
h = self.middle_block(
|
500 |
+
h,
|
501 |
+
emb,
|
502 |
+
context=context,
|
503 |
+
image_only_indicator=image_only_indicator,
|
504 |
+
time_context=time_context,
|
505 |
+
num_video_frames=num_video_frames,
|
506 |
+
)
|
507 |
+
|
508 |
+
# 5. Control net blocks
|
509 |
+
|
510 |
+
down_block_res_samples = hs
|
511 |
+
|
512 |
+
mid_block_res_sample = h
|
513 |
+
|
514 |
+
return (down_block_res_samples, mid_block_res_sample)
|
515 |
+
|
516 |
+
@classmethod
|
517 |
+
def from_unet(cls,
|
518 |
+
model: OpenAIWrapper,
|
519 |
+
merging_mode: str = "addition",
|
520 |
+
zero_conv_mode: str = "Identity",
|
521 |
+
frame_expansion: str = "none",
|
522 |
+
downsample_controlnet_cond: bool = True,
|
523 |
+
use_image_encoder_normalization: bool = False,
|
524 |
+
use_controlnet_mask: bool = False,
|
525 |
+
condition_encoder: str = "",
|
526 |
+
conditioning_embedding_out_channels: List[int] = None,
|
527 |
+
|
528 |
+
):
|
529 |
+
|
530 |
+
unet: VideoUNet = model.diffusion_model
|
531 |
+
|
532 |
+
controlnet = cls(in_channels=unet.in_channels,
|
533 |
+
model_channels=unet.model_channels,
|
534 |
+
out_channels=unet.out_channels,
|
535 |
+
num_res_blocks=unet.num_res_blocks,
|
536 |
+
attention_resolutions=unet.attention_resolutions,
|
537 |
+
dropout=unet.dropout,
|
538 |
+
channel_mult=unet.channel_mult,
|
539 |
+
conv_resample=unet.conv_resample,
|
540 |
+
dims=unet.dims,
|
541 |
+
num_classes=unet.num_classes,
|
542 |
+
use_checkpoint=unet.use_checkpoint,
|
543 |
+
num_heads=unet.num_heads,
|
544 |
+
num_head_channels=unet.num_head_channels,
|
545 |
+
num_heads_upsample=unet.num_heads_upsample,
|
546 |
+
use_scale_shift_norm=unet.use_scale_shift_norm,
|
547 |
+
resblock_updown=unet.resblock_updown,
|
548 |
+
transformer_depth=unet.transformer_depth,
|
549 |
+
transformer_depth_middle=unet.transformer_depth_middle,
|
550 |
+
context_dim=unet.context_dim,
|
551 |
+
time_downup=unet.time_downup,
|
552 |
+
time_context_dim=unet.time_context_dim,
|
553 |
+
extra_ff_mix_layer=unet.extra_ff_mix_layer,
|
554 |
+
use_spatial_context=unet.use_spatial_context,
|
555 |
+
merge_strategy=unet.merge_strategy,
|
556 |
+
merge_factor=unet.merge_factor,
|
557 |
+
spatial_transformer_attn_type=unet.spatial_transformer_attn_type,
|
558 |
+
video_kernel_size=unet.video_kernel_size,
|
559 |
+
use_linear_in_transformer=unet.use_linear_in_transformer,
|
560 |
+
adm_in_channels=unet.adm_in_channels,
|
561 |
+
disable_temporal_crossattention=unet.disable_temporal_crossattention,
|
562 |
+
max_ddpm_temb_period=unet.max_ddpm_temb_period, # up to here unet params
|
563 |
+
merging_mode=merging_mode,
|
564 |
+
zero_conv_mode=zero_conv_mode,
|
565 |
+
frame_expansion=frame_expansion,
|
566 |
+
downsample_controlnet_cond=downsample_controlnet_cond,
|
567 |
+
use_image_encoder_normalization=use_image_encoder_normalization,
|
568 |
+
use_controlnet_mask=use_controlnet_mask,
|
569 |
+
condition_encoder=condition_encoder,
|
570 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
571 |
+
)
|
572 |
+
controlnet: ControlNet
|
573 |
+
|
574 |
+
return controlnet
|
575 |
+
|
576 |
+
|
577 |
+
def zero_module(module, reset=True):
|
578 |
+
if reset:
|
579 |
+
for p in module.parameters():
|
580 |
+
nn.init.zeros_(p)
|
581 |
+
return module
|
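Note (not part of the commit): the `zero_module` helper above is what lets a freshly attached control branch start out as a no-op, since its zero-initialised projections contribute nothing until training moves them. A minimal sanity check, using only torch and the same helper logic (the toy Conv2d layer and shapes are illustrative):

import torch
import torch.nn as nn

def zero_module(module, reset=True):
    # identical logic to controlnet.py: zero every parameter in place
    if reset:
        for p in module.parameters():
            nn.init.zeros_(p)
    return module

proj = zero_module(nn.Conv2d(320, 320, kernel_size=1))
x = torch.randn(1, 320, 8, 8)
print(proj(x).abs().max())  # tensor(0.) -> the control branch injects nothing at initialization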
models/diffusion/discretizer.py
ADDED
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+
+from models.svd.sgm.modules.diffusionmodules.discretizer import Discretization
+
+
+# Implementation of https://arxiv.org/abs/2404.14507
+class AlignYourSteps(Discretization):
+
+    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
+        self.sigma_min = sigma_min
+        self.sigma_max = sigma_max
+        self.rho = rho
+
+    def loglinear_interp(self, t_steps, num_steps):
+        """
+        Performs log-linear interpolation of a given array of decreasing numbers.
+        """
+        xs = np.linspace(0, 1, len(t_steps))
+        ys = np.log(t_steps[::-1])
+
+        new_xs = np.linspace(0, 1, num_steps)
+        new_ys = np.interp(new_xs, xs, ys)
+
+        interped_ys = np.exp(new_ys)[::-1].copy()
+        return interped_ys
+
+    def get_sigmas(self, n, device="cpu"):
+        sampling_schedule = [700.00, 54.5, 15.886, 7.977,
+                             4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]
+        sigmas = torch.from_numpy(self.loglinear_interp(
+            sampling_schedule, n)).to(device)
+        return sigmas
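Note (not part of the commit): `AlignYourSteps.get_sigmas` simply resamples the hard-coded 11-entry schedule above to the requested number of steps by interpolating in log-space. A self-contained sketch of that interpolation, with the same schedule and a hypothetical 30-step target:

import numpy as np

def loglinear_interp(t_steps, num_steps):
    # same computation as AlignYourSteps.loglinear_interp, without the class around it
    xs = np.linspace(0, 1, len(t_steps))
    ys = np.log(np.asarray(t_steps)[::-1])
    new_ys = np.interp(np.linspace(0, 1, num_steps), xs, ys)
    return np.exp(new_ys)[::-1].copy()

schedule = [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]
sigmas = loglinear_interp(schedule, 30)
print(len(sigmas), sigmas[0], sigmas[-1])  # 30 values; endpoints stay at ~700.0 and ~0.002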
models/diffusion/video_model.py
ADDED
@@ -0,0 +1,574 @@
+# Adapted from https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/diffusionmodules/video_model.py
+from functools import partial
+from typing import List, Optional, Union
+
+from einops import rearrange
+
+from models.svd.sgm.modules.diffusionmodules.openaimodel import *
+from models.svd.sgm.modules.video_attention import SpatialVideoTransformer
+from models.svd.sgm.util import default
+from models.svd.sgm.modules.diffusionmodules.util import AlphaBlender
+from functools import partial
+from models.cam.conditioning import ConditionalModel
+
+
+class VideoResBlock(ResBlock):
+    def __init__(
+        self,
+        channels: int,
+        emb_channels: int,
+        dropout: float,
+        video_kernel_size: Union[int, List[int]] = 3,
+        merge_strategy: str = "fixed",
+        merge_factor: float = 0.5,
+        out_channels: Optional[int] = None,
+        use_conv: bool = False,
+        use_scale_shift_norm: bool = False,
+        dims: int = 2,
+        use_checkpoint: bool = False,
+        up: bool = False,
+        down: bool = False,
+    ):
+        super().__init__(
+            channels,
+            emb_channels,
+            dropout,
+            out_channels=out_channels,
+            use_conv=use_conv,
+            use_scale_shift_norm=use_scale_shift_norm,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            up=up,
+            down=down,
+        )
+
+        self.time_stack = ResBlock(
+            default(out_channels, channels),
+            emb_channels,
+            dropout=dropout,
+            dims=3,
+            out_channels=default(out_channels, channels),
+            use_scale_shift_norm=False,
+            use_conv=False,
+            up=False,
+            down=False,
+            kernel_size=video_kernel_size,
+            use_checkpoint=use_checkpoint,
+            exchange_temb_dims=True,
+        )
+        self.time_mixer = AlphaBlender(
+            alpha=merge_factor,
+            merge_strategy=merge_strategy,
+            rearrange_pattern="b t -> b 1 t 1 1",
+        )
+
+    def forward(
+        self,
+        x: th.Tensor,
+        emb: th.Tensor,
+        num_video_frames: int,
+        image_only_indicator: Optional[th.Tensor] = None,
+    ) -> th.Tensor:
+        x = super().forward(x, emb)
+
+        x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+        x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+
+        x = self.time_stack(
+            x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
+        )
+        x = self.time_mixer(
+            x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
+        )
+        x = rearrange(x, "b c t h w -> (b t) c h w")
+        return x
+
+
+class VideoUNet(nn.Module):
+    '''
+    Adapted from the vanilla SVD model. We add "cross_attention_merger_input_blocks" and "cross_attention_merger_mid_block" to incorporate the CAM control features.
+
+    '''
+
+    def __init__(
+        self,
+        in_channels: int,
+        model_channels: int,
+        out_channels: int,
+        num_res_blocks: int,
+        num_conditional_frames: int,
+        attention_resolutions: Union[List[int], int],
+        dropout: float = 0.0,
+        channel_mult: List[int] = (1, 2, 4, 8),
+        conv_resample: bool = True,
+        dims: int = 2,
+        num_classes: Optional[Union[int, str]] = None,
+        use_checkpoint: bool = False,
+        num_heads: int = -1,
+        num_head_channels: int = -1,
+        num_heads_upsample: int = -1,
+        use_scale_shift_norm: bool = False,
+        resblock_updown: bool = False,
+        transformer_depth: Union[List[int], int] = 1,
+        transformer_depth_middle: Optional[int] = None,
+        context_dim: Optional[int] = None,
+        time_downup: bool = False,
+        time_context_dim: Optional[int] = None,
+        extra_ff_mix_layer: bool = False,
+        use_spatial_context: bool = False,
+        merge_strategy: str = "fixed",
+        merge_factor: float = 0.5,
+        spatial_transformer_attn_type: str = "softmax",
+        video_kernel_size: Union[int, List[int]] = 3,
+        use_linear_in_transformer: bool = False,
+        adm_in_channels: Optional[int] = None,
+        disable_temporal_crossattention: bool = False,
+        max_ddpm_temb_period: int = 10000,
+        merging_mode: str = "addition",
+        controlnet_mode: bool = False,
+        use_apm: bool = False,
+    ):
+        super().__init__()
+        assert context_dim is not None
+        self.controlnet_mode = controlnet_mode
+        if controlnet_mode:
+            assert merging_mode.startswith(
+                "attention"), "other merging modes not implemented"
+            AttentionCondModel = partial(
+                ConditionalModel, conditional_model=merging_mode.split("attention_")[1])
+            self.cross_attention_merger_input_blocks = nn.ModuleList([])
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert num_head_channels != -1
+
+        if num_head_channels == -1:
+            assert num_heads != -1
+
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        if isinstance(transformer_depth, int):
+            transformer_depth = len(channel_mult) * [transformer_depth]
+        transformer_depth_middle = default(
+            transformer_depth_middle, transformer_depth[-1]
+        )
+
+        self.num_res_blocks = num_res_blocks
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.dims = dims
+        self.use_scale_shift_norm = use_scale_shift_norm
+        self.resblock_updown = resblock_updown
+        self.transformer_depth = transformer_depth
+        self.transformer_depth_middle = transformer_depth_middle
+        self.context_dim = context_dim
+        self.time_downup = time_downup
+        self.time_context_dim = time_context_dim
+        self.extra_ff_mix_layer = extra_ff_mix_layer
+        self.use_spatial_context = use_spatial_context
+        self.merge_strategy = merge_strategy
+        self.merge_factor = merge_factor
+        self.spatial_transformer_attn_type = spatial_transformer_attn_type
+        self.video_kernel_size = video_kernel_size
+        self.use_linear_in_transformer = use_linear_in_transformer
+        self.adm_in_channels = adm_in_channels
+        self.disable_temporal_crossattention = disable_temporal_crossattention
+        self.max_ddpm_temb_period = max_ddpm_temb_period
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        if self.num_classes is not None:
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            elif self.num_classes == "timestep":
+                self.label_emb = nn.Sequential(
+                    Timestep(model_channels),
+                    nn.Sequential(
+                        linear(model_channels, time_embed_dim),
+                        nn.SiLU(),
+                        linear(time_embed_dim, time_embed_dim),
+                    ),
+                )
+
+            elif self.num_classes == "sequential":
+                assert adm_in_channels is not None
+                self.label_emb = nn.Sequential(
+                    nn.Sequential(
+                        linear(adm_in_channels, time_embed_dim),
+                        nn.SiLU(),
+                        linear(time_embed_dim, time_embed_dim),
+                    )
+                )
+            else:
+                raise ValueError()
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        if controlnet_mode and merging_mode.startswith("attention"):
+            self.cross_attention_merger_input_blocks.append(
+                AttentionCondModel(input_channels=ch))
+
+        def get_attention_layer(
+            ch,
+            num_heads,
+            dim_head,
+            depth=1,
+            context_dim=None,
+            use_checkpoint=False,
+            disabled_sa=False,
+            use_apm: bool = False,
+        ):
+            return SpatialVideoTransformer(
+                ch,
+                num_heads,
+                dim_head,
+                depth=depth,
+                context_dim=context_dim,
+                time_context_dim=time_context_dim,
+                dropout=dropout,
+                ff_in=extra_ff_mix_layer,
+                use_spatial_context=use_spatial_context,
+                merge_strategy=merge_strategy,
+                merge_factor=merge_factor,
+                checkpoint=use_checkpoint,
+                use_linear=use_linear_in_transformer,
+                attn_mode=spatial_transformer_attn_type,
+                disable_self_attn=disabled_sa,
+                disable_temporal_crossattention=disable_temporal_crossattention,
+                max_time_embed_period=max_ddpm_temb_period,
+                use_apm=use_apm,
+            )
+
+        def get_resblock(
+            merge_factor,
+            merge_strategy,
+            video_kernel_size,
+            ch,
+            time_embed_dim,
+            dropout,
+            out_ch,
+            dims,
+            use_checkpoint,
+            use_scale_shift_norm,
+            down=False,
+            up=False,
+        ):
+            return VideoResBlock(
+                merge_factor=merge_factor,
+                merge_strategy=merge_strategy,
+                video_kernel_size=video_kernel_size,
+                channels=ch,
+                emb_channels=time_embed_dim,
+                dropout=dropout,
+                out_channels=out_ch,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+                down=down,
+                up=up,
+            )
+
+        for level, mult in enumerate(channel_mult):
+            for _ in range(num_res_blocks):
+                layers = [
+                    get_resblock(
+                        merge_factor=merge_factor,
+                        merge_strategy=merge_strategy,
+                        video_kernel_size=video_kernel_size,
+                        ch=ch,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
+                        out_ch=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+
+                    layers.append(
+                        get_attention_layer(
+                            ch,
+                            num_heads,
+                            dim_head,
+                            depth=transformer_depth[level],
+                            context_dim=context_dim,
+                            use_checkpoint=use_checkpoint,
+                            disabled_sa=False,
+                            use_apm=use_apm,
+                        )
+                    )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                if controlnet_mode and merging_mode.startswith("attention"):
+                    self.cross_attention_merger_input_blocks.append(
+                        AttentionCondModel(input_channels=ch))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                ds *= 2
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
+                            out_ch=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch,
+                            conv_resample,
+                            dims=dims,
+                            out_channels=out_ch,
+                            third_down=time_downup,
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+
+                if controlnet_mode and merging_mode.startswith("attention"):
+                    self.cross_attention_merger_input_blocks.append(
+                        AttentionCondModel(input_channels=ch))
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+
+        self.middle_block = TimestepEmbedSequential(
+            get_resblock(
+                merge_factor=merge_factor,
+                merge_strategy=merge_strategy,
+                video_kernel_size=video_kernel_size,
+                ch=ch,
+                time_embed_dim=time_embed_dim,
+                out_ch=None,
+                dropout=dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            get_attention_layer(
+                ch,
+                num_heads,
+                dim_head,
+                depth=transformer_depth_middle,
+                context_dim=context_dim,
+                use_checkpoint=use_checkpoint,
+                use_apm=use_apm,
+            ),
+            get_resblock(
+                merge_factor=merge_factor,
+                merge_strategy=merge_strategy,
+                video_kernel_size=video_kernel_size,
+                ch=ch,
+                out_ch=None,
+                time_embed_dim=time_embed_dim,
+                dropout=dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self._feature_size += ch
+        if controlnet_mode and merging_mode.startswith("attention"):
+            self.cross_attention_merger_mid_block = AttentionCondModel(
+                input_channels=ch)
+
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(num_res_blocks + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    get_resblock(
+                        merge_factor=merge_factor,
+                        merge_strategy=merge_strategy,
+                        video_kernel_size=video_kernel_size,
+                        ch=ch + ich,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
+                        out_ch=model_channels * mult,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+
+                    layers.append(
+                        get_attention_layer(
+                            ch,
+                            num_heads,
+                            dim_head,
+                            depth=transformer_depth[level],
+                            context_dim=context_dim,
+                            use_checkpoint=use_checkpoint,
+                            disabled_sa=False,
+                            use_apm=use_apm,
+                        )
+                    )
+                if level and i == num_res_blocks:
+                    out_ch = ch
+                    ds //= 2
+                    layers.append(
+                        get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
+                            out_ch=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(
+                            ch,
+                            conv_resample,
+                            dims=dims,
+                            out_channels=out_ch,
+                            third_up=time_downup,
+                        )
+                    )
+
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, model_channels,
+                        out_channels, 3, padding=1)),
+        )
+
+    def forward(
+        self,
+        # [28,8,72,128], i.e. (B F) (2 C) H W = concat([z_t,<cond_frames>])
+        x: th.Tensor,
+        timesteps: th.Tensor,  # [28], i.e. (B F)
+        # [28, 1, 1024], i.e. (B F) 1 T, for cross attention from clip image encoder, <cond_frames_without_noise>
+        context: Optional[th.Tensor] = None,
+        # [28, 768], i.e. (B F) T ? concat([<fps_id>,<motion_bucket_id>,<cond_aug>]
+        y: Optional[th.Tensor] = None,
+        time_context: Optional[th.Tensor] = None,  # NONE
+        num_video_frames: Optional[int] = None,  # 14
+        num_conditional_frames: Optional[int] = None,  # 8
+        # zeros, [2,14], i.e. [B, F]
+        image_only_indicator: Optional[th.Tensor] = None,
+        hs_control_input: Optional[th.Tensor] = None,  # cam features
+        hs_control_mid: Optional[th.Tensor] = None,  # cam features
+    ):
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
+        hs = []
+        t_emb = timestep_embedding(
+            timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+        emb = self.time_embed(t_emb)
+
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+
+        h = x
+        for module in self.input_blocks:
+            h = module(
+                h,
+                emb,
+                context=context,
+                image_only_indicator=image_only_indicator,
+                time_context=time_context,
+                num_video_frames=num_video_frames,
+            )
+            hs.append(h)
+
+        # fusion of cam features with base features
+        if hs_control_input is not None:
+            new_hs = []
+
+            assert len(hs) == len(hs_control_input) and len(
+                hs) == len(self.cross_attention_merger_input_blocks)
+            for h_no_ctrl, h_ctrl, merger in zip(hs, hs_control_input, self.cross_attention_merger_input_blocks):
+                merged_h = merger(h_no_ctrl, h_ctrl, num_frames=num_video_frames,
+                                  num_conditional_frames=num_conditional_frames)
+                new_hs.append(merged_h)
+            hs = new_hs
+
+        h = self.middle_block(
+            h,
+            emb,
+            context=context,
+            image_only_indicator=image_only_indicator,
+            time_context=time_context,
+            num_video_frames=num_video_frames,
+        )
+
+        # fusion of cam features with base features
+        if hs_control_mid is not None:
+            h = self.cross_attention_merger_mid_block(
+                h, hs_control_mid, num_frames=num_video_frames, num_conditional_frames=num_conditional_frames)
+
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(
+                h,
+                emb,
+                context=context,
+                image_only_indicator=image_only_indicator,
+                time_context=time_context,
+                num_video_frames=num_video_frames,
+            )
+        h = h.type(x.dtype)
+        return self.out(h)
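Note (not part of the commit): `VideoResBlock.forward` relies on the convention that spatial layers see frames folded into the batch axis ("(b t) c h w"), while the temporal ResBlock sees an explicit time axis ("b c t h w"). A tiny einops round-trip with assumed toy shapes (2 videos of 14 frames) illustrates the reshaping:

import torch
from einops import rearrange

B, T, C, H, W = 2, 14, 4, 8, 8                        # assumed toy sizes
x = torch.randn(B * T, C, H, W)                       # layout used by the spatial (2D) layers
x_3d = rearrange(x, "(b t) c h w -> b c t h w", t=T)  # layout used by self.time_stack
x_back = rearrange(x_3d, "b c t h w -> (b t) c h w")
print(x_3d.shape, torch.equal(x, x_back))             # torch.Size([2, 4, 14, 8, 8]) True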
models/diffusion/wrappers.py
ADDED
@@ -0,0 +1,78 @@
+
+import torch
+from models.svd.sgm.modules.diffusionmodules.wrappers import OpenAIWrapper
+from einops import rearrange, repeat
+
+
+class StreamingWrapper(OpenAIWrapper):
+    """
+    Model wrapper for StreamingSVD, which holds the CAM model and the base model.
+
+    """
+
+    def __init__(self, diffusion_model, controlnet, num_frame_conditioning: int, compile_model: bool = False, pipeline_offloading: bool = False):
+        super().__init__(diffusion_model=diffusion_model,
+                         compile_model=compile_model)
+        self.controlnet = controlnet
+        self.num_frame_conditioning = num_frame_conditioning
+        self.pipeline_offloading = pipeline_offloading
+        if pipeline_offloading:
+            raise NotImplementedError(
+                "Pipeline offloading for StreamingI2V not implemented yet.")
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs):
+
+        batch_size = kwargs.pop("batch_size")
+
+        # We apply the controlnet model only to the control frames.
+        def reduce_to_cond_frames(input):
+            input = rearrange(input, "(B F) ... -> B F ...", B=batch_size)
+            input = input[:, :self.num_frame_conditioning]
+            return rearrange(input, "B F ... -> (B F) ...")
+
+        x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
+        x_ctrl = reduce_to_cond_frames(x)
+        t_ctrl = reduce_to_cond_frames(t)
+
+        context = c.get("crossattn", None)
+        # controlnet is not using APM so we remove potentially additional tokens
+        context_ctrl = context[:, :1]
+        context_ctrl = reduce_to_cond_frames(context_ctrl)
+        y = c.get("vector", None)
+        y_ctrl = reduce_to_cond_frames(y)
+        num_video_frames = kwargs.pop("num_video_frames")
+        image_only_indicator = kwargs.pop("image_only_indicator")
+        ctrl_img_enc_frames = repeat(
+            kwargs['ctrl_frames'], "B ... -> (2 B) ... ")
+        controlnet_cond = rearrange(
+            ctrl_img_enc_frames, "B F ... -> (B F) ...")
+
+        if self.diffusion_model.controlnet_mode:
+            hs_control_input, hs_control_mid = self.controlnet(x=x_ctrl,  # video latent
+                                                               timesteps=t_ctrl,  # timestep
+                                                               context=context_ctrl,  # clip image conditioning
+                                                               y=y_ctrl,  # conditionings, e.g. fps
+                                                               controlnet_cond=controlnet_cond,  # control frames
+                                                               num_video_frames=self.num_frame_conditioning,
+                                                               num_video_frames_conditional=self.num_frame_conditioning,
+                                                               image_only_indicator=image_only_indicator[:,
+                                                                                                         :self.num_frame_conditioning]
+                                                               )
+        else:
+            hs_control_input = None
+            hs_control_mid = None
+        kwargs["hs_control_input"] = hs_control_input
+        kwargs["hs_control_mid"] = hs_control_mid
+
+        out = self.diffusion_model(
+            x=x,
+            timesteps=t,
+            context=context,  # must be (B F) T C
+            y=y,  # must be (B F) 768
+            num_video_frames=num_video_frames,
+            num_conditional_frames=self.num_frame_conditioning,
+            image_only_indicator=image_only_indicator,
+            hs_control_input=hs_control_input,
+            hs_control_mid=hs_control_mid,
+        )
+        return out
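Note (not part of the commit): the `reduce_to_cond_frames` closure in `StreamingWrapper.forward` keeps only the first `num_frame_conditioning` frames of every batch element before they are handed to the control branch. A standalone sketch with assumed sizes (batch 2, 14 frames, 8 conditional frames):

import torch
from einops import rearrange

batch_size, num_frames, num_cond = 2, 14, 8            # assumed values
x = torch.randn(batch_size * num_frames, 4, 72, 128)   # (B F) C H W, as in the wrapper

def reduce_to_cond_frames(inp):
    inp = rearrange(inp, "(B F) ... -> B F ...", B=batch_size)
    return rearrange(inp[:, :num_cond], "B F ... -> (B F) ...")

print(reduce_to_cond_frames(x).shape)  # torch.Size([16, 4, 72, 128])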
models/svd/sgm/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from models.svd.sgm.models import AutoencodingEngine, DiffusionEngine
+from models.svd.sgm.util import get_configs_path, instantiate_from_config
+
+__version__ = "0.1.0"
models/svd/sgm/data/__init__.py
ADDED
@@ -0,0 +1 @@
+from .dataset import StableDataModuleFromConfig
models/svd/sgm/data/cifar10.py
ADDED
@@ -0,0 +1,67 @@
+import pytorch_lightning as pl
+import torchvision
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+
+class CIFAR10DataDictWrapper(Dataset):
+    def __init__(self, dset):
+        super().__init__()
+        self.dset = dset
+
+    def __getitem__(self, i):
+        x, y = self.dset[i]
+        return {"jpg": x, "cls": y}
+
+    def __len__(self):
+        return len(self.dset)
+
+
+class CIFAR10Loader(pl.LightningDataModule):
+    def __init__(self, batch_size, num_workers=0, shuffle=True):
+        super().__init__()
+
+        transform = transforms.Compose(
+            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
+        )
+
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.shuffle = shuffle
+        self.train_dataset = CIFAR10DataDictWrapper(
+            torchvision.datasets.CIFAR10(
+                root=".data/", train=True, download=True, transform=transform
+            )
+        )
+        self.test_dataset = CIFAR10DataDictWrapper(
+            torchvision.datasets.CIFAR10(
+                root=".data/", train=False, download=True, transform=transform
+            )
+        )
+
+    def prepare_data(self):
+        pass
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+        )
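Note (not part of the commit): the dict wrapper exists because the training code expects dict-style batches keyed by strings such as "jpg" and "cls" rather than (image, label) tuples. A hypothetical quick check, assuming the module above is importable (CIFAR-10 is downloaded into .data/ on first run):

module = CIFAR10Loader(batch_size=4, num_workers=0)
batch = next(iter(module.train_dataloader()))
print(batch["jpg"].shape, batch["cls"].shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])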
models/svd/sgm/data/dataset.py
ADDED
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import torchdata.datapipes.iter
+import webdataset as wds
+from omegaconf import DictConfig
+from pytorch_lightning import LightningDataModule
+
+try:
+    from sdata import create_dataset, create_dummy_dataset, create_loader
+except ImportError as e:
+    print("#" * 100)
+    print("Datasets not yet available")
+    print("to enable, we need to add stable-datasets as a submodule")
+    print("please use ``git submodule update --init --recursive``")
+    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
+    print("#" * 100)
+    exit(1)
+
+
+class StableDataModuleFromConfig(LightningDataModule):
+    def __init__(
+        self,
+        train: DictConfig,
+        validation: Optional[DictConfig] = None,
+        test: Optional[DictConfig] = None,
+        skip_val_loader: bool = False,
+        dummy: bool = False,
+    ):
+        super().__init__()
+        self.train_config = train
+        assert (
+            "datapipeline" in self.train_config and "loader" in self.train_config
+        ), "train config requires the fields `datapipeline` and `loader`"
+
+        self.val_config = validation
+        if not skip_val_loader:
+            if self.val_config is not None:
+                assert (
+                    "datapipeline" in self.val_config and "loader" in self.val_config
+                ), "validation config requires the fields `datapipeline` and `loader`"
+            else:
+                print(
+                    "Warning: No Validation datapipeline defined, using that one from training"
+                )
+                self.val_config = train
+
+        self.test_config = test
+        if self.test_config is not None:
+            assert (
+                "datapipeline" in self.test_config and "loader" in self.test_config
+            ), "test config requires the fields `datapipeline` and `loader`"
+
+        self.dummy = dummy
+        if self.dummy:
+            print("#" * 100)
+            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
+
+    def setup(self, stage: str) -> None:
+        print("Preparing datasets")
+        if self.dummy:
+            data_fn = create_dummy_dataset
+        else:
+            data_fn = create_dataset
+
+        self.train_datapipeline = data_fn(**self.train_config.datapipeline)
+        if self.val_config:
+            self.val_datapipeline = data_fn(**self.val_config.datapipeline)
+        if self.test_config:
+            self.test_datapipeline = data_fn(**self.test_config.datapipeline)
+
+    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
+        loader = create_loader(self.train_datapipeline, **self.train_config.loader)
+        return loader
+
+    def val_dataloader(self) -> wds.DataPipeline:
+        return create_loader(self.val_datapipeline, **self.val_config.loader)
+
+    def test_dataloader(self) -> wds.DataPipeline:
+        return create_loader(self.test_datapipeline, **self.test_config.loader)
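Note (not part of the commit): `StableDataModuleFromConfig` only checks that each split config carries a `datapipeline` and a `loader` section and then forwards their contents to `create_dataset`/`create_loader` from the stable-datasets submodule. A hypothetical minimal config sketch; the keys inside the two sections are illustrative and depend on that submodule, and it assumes the submodule is installed (the import above exits otherwise):

from omegaconf import OmegaConf

train_cfg = OmegaConf.create(
    {
        "datapipeline": {"urls": ["data/shards/{00000..00009}.tar"]},  # illustrative keys only
        "loader": {"batch_size": 4, "num_workers": 2},                 # illustrative keys only
    }
)
data_module = StableDataModuleFromConfig(train=train_cfg, dummy=True)  # dummy=True selects create_dummy_dataset in setup()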
models/svd/sgm/data/mnist.py
ADDED
@@ -0,0 +1,85 @@
+import pytorch_lightning as pl
+import torchvision
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+
+class MNISTDataDictWrapper(Dataset):
+    def __init__(self, dset):
+        super().__init__()
+        self.dset = dset
+
+    def __getitem__(self, i):
+        x, y = self.dset[i]
+        return {"jpg": x, "cls": y}
+
+    def __len__(self):
+        return len(self.dset)
+
+
+class MNISTLoader(pl.LightningDataModule):
+    def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
+        super().__init__()
+
+        transform = transforms.Compose(
+            [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
+        )
+
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
+        self.shuffle = shuffle
+        self.train_dataset = MNISTDataDictWrapper(
+            torchvision.datasets.MNIST(
+                root=".data/", train=True, download=True, transform=transform
+            )
+        )
+        self.test_dataset = MNISTDataDictWrapper(
+            torchvision.datasets.MNIST(
+                root=".data/", train=False, download=True, transform=transform
+            )
+        )
+
+    def prepare_data(self):
+        pass
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+            prefetch_factor=self.prefetch_factor,
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+            prefetch_factor=self.prefetch_factor,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=self.shuffle,
+            num_workers=self.num_workers,
+            prefetch_factor=self.prefetch_factor,
+        )
+
+
+if __name__ == "__main__":
+    dset = MNISTDataDictWrapper(
+        torchvision.datasets.MNIST(
+            root=".data/",
+            train=False,
+            download=True,
+            transform=transforms.Compose(
+                [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
+            ),
+        )
+    )
+    ex = dset[0]
models/svd/sgm/inference/api.py
ADDED
@@ -0,0 +1,385 @@
+import pathlib
+from dataclasses import asdict, dataclass
+from enum import Enum
+from typing import Optional
+
+from omegaconf import OmegaConf
+
+from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
+                                   do_sample)
+from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
+                                                   DPMPP2SAncestralSampler,
+                                                   EulerAncestralSampler,
+                                                   EulerEDMSampler,
+                                                   HeunEDMSampler,
+                                                   LinearMultistepSampler)
+from sgm.util import load_model_from_config
+
+
+class ModelArchitecture(str, Enum):
+    SD_2_1 = "stable-diffusion-v2-1"
+    SD_2_1_768 = "stable-diffusion-v2-1-768"
+    SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
+    SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
+    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
+    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
+
+
+class Sampler(str, Enum):
+    EULER_EDM = "EulerEDMSampler"
+    HEUN_EDM = "HeunEDMSampler"
+    EULER_ANCESTRAL = "EulerAncestralSampler"
+    DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
+    DPMPP2M = "DPMPP2MSampler"
+    LINEAR_MULTISTEP = "LinearMultistepSampler"
+
+
+class Discretization(str, Enum):
+    LEGACY_DDPM = "LegacyDDPMDiscretization"
+    EDM = "EDMDiscretization"
+
+
+class Guider(str, Enum):
+    VANILLA = "VanillaCFG"
+    IDENTITY = "IdentityGuider"
+
+
+class Thresholder(str, Enum):
+    NONE = "None"
+
+
+@dataclass
+class SamplingParams:
+    width: int = 1024
+    height: int = 1024
+    steps: int = 50
+    sampler: Sampler = Sampler.DPMPP2M
+    discretization: Discretization = Discretization.LEGACY_DDPM
+    guider: Guider = Guider.VANILLA
+    thresholder: Thresholder = Thresholder.NONE
+    scale: float = 6.0
+    aesthetic_score: float = 5.0
+    negative_aesthetic_score: float = 5.0
+    img2img_strength: float = 1.0
+    orig_width: int = 1024
+    orig_height: int = 1024
+    crop_coords_top: int = 0
+    crop_coords_left: int = 0
+    sigma_min: float = 0.0292
+    sigma_max: float = 14.6146
+    rho: float = 3.0
+    s_churn: float = 0.0
+    s_tmin: float = 0.0
+    s_tmax: float = 999.0
+    s_noise: float = 1.0
+    eta: float = 1.0
+    order: int = 4
+
+
+@dataclass
+class SamplingSpec:
+    width: int
+    height: int
+    channels: int
+    factor: int
+    is_legacy: bool
+    config: str
+    ckpt: str
+    is_guided: bool
+
+
+model_specs = {
+    ModelArchitecture.SD_2_1: SamplingSpec(
+        height=512,
+        width=512,
+        channels=4,
+        factor=8,
+        is_legacy=True,
+        config="sd_2_1.yaml",
+        ckpt="v2-1_512-ema-pruned.safetensors",
+        is_guided=True,
+    ),
+    ModelArchitecture.SD_2_1_768: SamplingSpec(
+        height=768,
+        width=768,
+        channels=4,
+        factor=8,
+        is_legacy=True,
+        config="sd_2_1_768.yaml",
+        ckpt="v2-1_768-ema-pruned.safetensors",
+        is_guided=True,
+    ),
+    ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
+        height=1024,
+        width=1024,
+        channels=4,
+        factor=8,
+        is_legacy=False,
+        config="sd_xl_base.yaml",
+        ckpt="sd_xl_base_0.9.safetensors",
+        is_guided=True,
+    ),
+    ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
+        height=1024,
+        width=1024,
+        channels=4,
+        factor=8,
+        is_legacy=True,
+        config="sd_xl_refiner.yaml",
+        ckpt="sd_xl_refiner_0.9.safetensors",
+        is_guided=True,
+    ),
+    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
+        height=1024,
+        width=1024,
+        channels=4,
+        factor=8,
+        is_legacy=False,
+        config="sd_xl_base.yaml",
+        ckpt="sd_xl_base_1.0.safetensors",
+        is_guided=True,
+    ),
+    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
+        height=1024,
+        width=1024,
+        channels=4,
+        factor=8,
+        is_legacy=True,
+        config="sd_xl_refiner.yaml",
+        ckpt="sd_xl_refiner_1.0.safetensors",
+        is_guided=True,
+    ),
+}
+
+
+class SamplingPipeline:
+    def __init__(
+        self,
+        model_id: ModelArchitecture,
+        model_path="checkpoints",
+        config_path="configs/inference",
+        device="cuda",
+        use_fp16=True,
+    ) -> None:
+        if model_id not in model_specs:
+            raise ValueError(f"Model {model_id} not supported")
+        self.model_id = model_id
+        self.specs = model_specs[self.model_id]
+        self.config = str(pathlib.Path(config_path, self.specs.config))
+        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
+        self.device = device
+        self.model = self._load_model(device=device, use_fp16=use_fp16)
+
+    def _load_model(self, device="cuda", use_fp16=True):
+        config = OmegaConf.load(self.config)
+        model = load_model_from_config(config, self.ckpt)
+        if model is None:
+            raise ValueError(f"Model {self.model_id} could not be loaded")
+        model.to(device)
+        if use_fp16:
+            model.conditioner.half()
+            model.model.half()
+        return model
+
+    def text_to_image(
+        self,
+        params: SamplingParams,
+        prompt: str,
+        negative_prompt: str = "",
+        samples: int = 1,
+        return_latents: bool = False,
+    ):
+        sampler = get_sampler_config(params)
+        value_dict = asdict(params)
+        value_dict["prompt"] = prompt
+        value_dict["negative_prompt"] = negative_prompt
+        value_dict["target_width"] = params.width
+        value_dict["target_height"] = params.height
+        return do_sample(
+            self.model,
+            sampler,
+            value_dict,
+            samples,
+            params.height,
+            params.width,
+            self.specs.channels,
+            self.specs.factor,
+            force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
+            return_latents=return_latents,
+            filter=None,
+        )
+
+    def image_to_image(
+        self,
+        params: SamplingParams,
+        image,
+        prompt: str,
+        negative_prompt: str = "",
+        samples: int = 1,
+        return_latents: bool = False,
+    ):
+        sampler = get_sampler_config(params)
+
+        if params.img2img_strength < 1.0:
+            sampler.discretization = Img2ImgDiscretizationWrapper(
+                sampler.discretization,
+                strength=params.img2img_strength,
+            )
+        height, width = image.shape[2], image.shape[3]
+        value_dict = asdict(params)
+        value_dict["prompt"] = prompt
+        value_dict["negative_prompt"] = negative_prompt
+        value_dict["target_width"] = width
+        value_dict["target_height"] = height
+        return do_img2img(
+            image,
+            self.model,
+            sampler,
+            value_dict,
+            samples,
+            force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
+            return_latents=return_latents,
+            filter=None,
+        )
+
+    def refiner(
+        self,
+        params: SamplingParams,
+        image,
+        prompt: str,
+        negative_prompt: Optional[str] = None,
+        samples: int = 1,
+        return_latents: bool = False,
+    ):
+        sampler = get_sampler_config(params)
+        value_dict = {
+            "orig_width": image.shape[3] * 8,
+            "orig_height": image.shape[2] * 8,
+            "target_width": image.shape[3] * 8,
+            "target_height": image.shape[2] * 8,
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "crop_coords_top": 0,
+            "crop_coords_left": 0,
+            "aesthetic_score": 6.0,
+            "negative_aesthetic_score": 2.5,
+        }
+
+        return do_img2img(
+            image,
+            self.model,
+            sampler,
+            value_dict,
+            samples,
+            skip_encode=True,
+            return_latents=return_latents,
+            filter=None,
+        )
+
+
+def get_guider_config(params: SamplingParams):
+    if params.guider == Guider.IDENTITY:
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+        }
+    elif params.guider == Guider.VANILLA:
+        scale = params.scale
+
+        thresholder = params.thresholder
+
+        if thresholder == Thresholder.NONE:
+            dyn_thresh_config = {
+                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+            }
+        else:
+            raise NotImplementedError
+
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
+            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
+        }
+    else:
+        raise NotImplementedError
+    return guider_config
+
+
+def get_discretization_config(params: SamplingParams):
+    if params.discretization == Discretization.LEGACY_DDPM:
+        discretization_config = {
+            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
+        }
+    elif params.discretization == Discretization.EDM:
+        discretization_config = {
+            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
+            "params": {
+                "sigma_min": params.sigma_min,
+                "sigma_max": params.sigma_max,
+                "rho": params.rho,
+            },
+        }
+    else:
+        raise ValueError(f"unknown discretization {params.discretization}")
+    return discretization_config
+
+
+def get_sampler_config(params: SamplingParams):
+    discretization_config = get_discretization_config(params)
+    guider_config = get_guider_config(params)
+    sampler = None
+    if params.sampler == Sampler.EULER_EDM:
+        return EulerEDMSampler(
+            num_steps=params.steps,
+            discretization_config=discretization_config,
+            guider_config=guider_config,
+            s_churn=params.s_churn,
+            s_tmin=params.s_tmin,
+            s_tmax=params.s_tmax,
+            s_noise=params.s_noise,
+            verbose=True,
+        )
+    if params.sampler == Sampler.HEUN_EDM:
+        return HeunEDMSampler(
+            num_steps=params.steps,
+            discretization_config=discretization_config,
+            guider_config=guider_config,
+            s_churn=params.s_churn,
+            s_tmin=params.s_tmin,
+            s_tmax=params.s_tmax,
+            s_noise=params.s_noise,
+            verbose=True,
+        )
+    if params.sampler == Sampler.EULER_ANCESTRAL:
+        return EulerAncestralSampler(
+            num_steps=params.steps,
+            discretization_config=discretization_config,
+            guider_config=guider_config,
+            eta=params.eta,
+            s_noise=params.s_noise,
+            verbose=True,
+        )
+    if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
+        return DPMPP2SAncestralSampler(
+            num_steps=params.steps,
+            discretization_config=discretization_config,
+            guider_config=guider_config,
+            eta=params.eta,
+            s_noise=params.s_noise,
+            verbose=True,
+        )
+    if params.sampler == Sampler.DPMPP2M:
+        return DPMPP2MSampler(
+            num_steps=params.steps,
+            discretization_config=discretization_config,
+            guider_config=guider_config,
+            verbose=True,
+        )
+    if params.sampler == Sampler.LINEAR_MULTISTEP:
+        return LinearMultistepSampler(
+            num_steps=params.steps,
+            discretization_config=discretization_config,
+            guider_config=guider_config,
+            order=params.order,
+            verbose=True,
+        )
+
+    raise ValueError(f"unknown sampler {params.sampler}!")
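Note (not part of the commit): a hypothetical usage sketch of the inference API above. It assumes the referenced config and checkpoint files exist locally under the default `configs/inference` and `checkpoints` paths and that the upstream `sgm` package is importable; the prompt string is arbitrary.

pipeline = SamplingPipeline(
    ModelArchitecture.SDXL_V1_BASE,
    model_path="checkpoints",
    config_path="configs/inference",
)
params = SamplingParams(steps=30, sampler=Sampler.EULER_EDM, scale=6.0)
images = pipeline.text_to_image(params, prompt="a red vintage car on a coastal road", samples=1)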
models/svd/sgm/inference/helpers.py
ADDED
@@ -0,0 +1,305 @@
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from einops import rearrange
|
8 |
+
from imwatermark import WatermarkEncoder
|
9 |
+
from omegaconf import ListConfig
|
10 |
+
from PIL import Image
|
11 |
+
from torch import autocast
|
12 |
+
|
13 |
+
from sgm.util import append_dims
|
14 |
+
|
15 |
+
|
16 |
+
class WatermarkEmbedder:
|
17 |
+
def __init__(self, watermark):
|
18 |
+
self.watermark = watermark
|
19 |
+
self.num_bits = len(WATERMARK_BITS)
|
20 |
+
self.encoder = WatermarkEncoder()
|
21 |
+
self.encoder.set_watermark("bits", self.watermark)
|
22 |
+
|
23 |
+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
24 |
+
"""
|
25 |
+
Adds a predefined watermark to the input image
|
26 |
+
|
27 |
+
Args:
|
28 |
+
image: ([N,] B, RGB, H, W) in range [0, 1]
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
same as input but watermarked
|
32 |
+
"""
|
33 |
+
squeeze = len(image.shape) == 4
|
34 |
+
if squeeze:
|
35 |
+
image = image[None, ...]
|
36 |
+
n = image.shape[0]
|
37 |
+
image_np = rearrange(
|
38 |
+
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
39 |
+
).numpy()[:, :, :, ::-1]
|
40 |
+
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
41 |
+
# watermarking libary expects input as cv2 BGR format
|
42 |
+
for k in range(image_np.shape[0]):
|
43 |
+
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
44 |
+
image = torch.from_numpy(
|
45 |
+
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
46 |
+
).to(image.device)
|
47 |
+
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
48 |
+
if squeeze:
|
49 |
+
image = image[0]
|
50 |
+
return image
|
51 |
+
|
52 |
+
|
53 |
+
# A fixed 48-bit message that was choosen at random
|
54 |
+
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
55 |
+
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
56 |
+
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
57 |
+
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
58 |
+
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
59 |
+
|
60 |
+
|
61 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
62 |
+
return list({x.input_key for x in conditioner.embedders})
|
63 |
+
|
64 |
+
|
65 |
+
def perform_save_locally(save_path, samples):
|
66 |
+
os.makedirs(os.path.join(save_path), exist_ok=True)
|
67 |
+
base_count = len(os.listdir(os.path.join(save_path)))
|
68 |
+
samples = embed_watermark(samples)
|
69 |
+
for sample in samples:
|
70 |
+
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
71 |
+
Image.fromarray(sample.astype(np.uint8)).save(
|
72 |
+
os.path.join(save_path, f"{base_count:09}.png")
|
73 |
+
)
|
74 |
+
base_count += 1
|
75 |
+
|
76 |
+
|
77 |
+
class Img2ImgDiscretizationWrapper:
|
78 |
+
"""
|
79 |
+
wraps a discretizer, and prunes the sigmas
|
80 |
+
params:
|
81 |
+
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, discretization, strength: float = 1.0):
|
85 |
+
self.discretization = discretization
|
86 |
+
self.strength = strength
|
87 |
+
assert 0.0 <= self.strength <= 1.0
|
88 |
+
|
89 |
+
def __call__(self, *args, **kwargs):
|
90 |
+
# sigmas start large and then decrease
|
91 |
+
sigmas = self.discretization(*args, **kwargs)
|
92 |
+
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
93 |
+
sigmas = torch.flip(sigmas, (0,))
|
94 |
+
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
95 |
+
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
96 |
+
sigmas = torch.flip(sigmas, (0,))
|
97 |
+
print(f"sigmas after pruning: ", sigmas)
|
98 |
+
return sigmas
|
99 |
+
|
100 |
+
|
101 |
+
def do_sample(
|
102 |
+
model,
|
103 |
+
sampler,
|
104 |
+
value_dict,
|
105 |
+
num_samples,
|
106 |
+
H,
|
107 |
+
W,
|
108 |
+
C,
|
109 |
+
F,
|
110 |
+
force_uc_zero_embeddings: Optional[List] = None,
|
111 |
+
batch2model_input: Optional[List] = None,
|
112 |
+
return_latents=False,
|
113 |
+
filter=None,
|
114 |
+
device="cuda",
|
115 |
+
):
|
116 |
+
if force_uc_zero_embeddings is None:
|
117 |
+
force_uc_zero_embeddings = []
|
118 |
+
if batch2model_input is None:
|
119 |
+
batch2model_input = []
|
120 |
+
|
121 |
+
with torch.no_grad():
|
122 |
+
with autocast(device) as precision_scope:
|
123 |
+
with model.ema_scope():
|
124 |
+
num_samples = [num_samples]
|
125 |
+
batch, batch_uc = get_batch(
|
126 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
127 |
+
value_dict,
|
128 |
+
num_samples,
|
129 |
+
)
|
130 |
+
for key in batch:
|
131 |
+
if isinstance(batch[key], torch.Tensor):
|
132 |
+
print(key, batch[key].shape)
|
133 |
+
elif isinstance(batch[key], list):
|
134 |
+
print(key, [len(l) for l in batch[key]])
|
135 |
+
else:
|
136 |
+
print(key, batch[key])
|
137 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
138 |
+
batch,
|
139 |
+
batch_uc=batch_uc,
|
140 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
141 |
+
)
|
142 |
+
|
143 |
+
for k in c:
|
144 |
+
if not k == "crossattn":
|
145 |
+
c[k], uc[k] = map(
|
146 |
+
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
|
147 |
+
)
|
148 |
+
|
149 |
+
additional_model_inputs = {}
|
150 |
+
for k in batch2model_input:
|
151 |
+
additional_model_inputs[k] = batch[k]
|
152 |
+
|
153 |
+
shape = (math.prod(num_samples), C, H // F, W // F)
|
154 |
+
randn = torch.randn(shape).to(device)
|
155 |
+
|
156 |
+
def denoiser(input, sigma, c):
|
157 |
+
return model.denoiser(
|
158 |
+
model.model, input, sigma, c, **additional_model_inputs
|
159 |
+
)
|
160 |
+
|
161 |
+
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
162 |
+
samples_x = model.decode_first_stage(samples_z)
|
163 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
164 |
+
|
165 |
+
if filter is not None:
|
166 |
+
samples = filter(samples)
|
167 |
+
|
168 |
+
if return_latents:
|
169 |
+
return samples, samples_z
|
170 |
+
return samples
|
171 |
+
|
172 |
+
|
173 |
+
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
174 |
+
# Hardcoded demo setups; might undergo some changes in the future
|
175 |
+
|
176 |
+
batch = {}
|
177 |
+
batch_uc = {}
|
178 |
+
|
179 |
+
for key in keys:
|
180 |
+
if key == "txt":
|
181 |
+
batch["txt"] = (
|
182 |
+
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
183 |
+
.reshape(N)
|
184 |
+
.tolist()
|
185 |
+
)
|
186 |
+
batch_uc["txt"] = (
|
187 |
+
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
188 |
+
.reshape(N)
|
189 |
+
.tolist()
|
190 |
+
)
|
191 |
+
elif key == "original_size_as_tuple":
|
192 |
+
batch["original_size_as_tuple"] = (
|
193 |
+
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
194 |
+
.to(device)
|
195 |
+
.repeat(*N, 1)
|
196 |
+
)
|
197 |
+
elif key == "crop_coords_top_left":
|
198 |
+
batch["crop_coords_top_left"] = (
|
199 |
+
torch.tensor(
|
200 |
+
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
201 |
+
)
|
202 |
+
.to(device)
|
203 |
+
.repeat(*N, 1)
|
204 |
+
)
|
205 |
+
elif key == "aesthetic_score":
|
206 |
+
batch["aesthetic_score"] = (
|
207 |
+
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
208 |
+
)
|
209 |
+
batch_uc["aesthetic_score"] = (
|
210 |
+
torch.tensor([value_dict["negative_aesthetic_score"]])
|
211 |
+
.to(device)
|
212 |
+
.repeat(*N, 1)
|
213 |
+
)
|
214 |
+
|
215 |
+
elif key == "target_size_as_tuple":
|
216 |
+
batch["target_size_as_tuple"] = (
|
217 |
+
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
218 |
+
.to(device)
|
219 |
+
.repeat(*N, 1)
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
batch[key] = value_dict[key]
|
223 |
+
|
224 |
+
for key in batch.keys():
|
225 |
+
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
226 |
+
batch_uc[key] = torch.clone(batch[key])
|
227 |
+
return batch, batch_uc
|
228 |
+
|
229 |
+
|
230 |
+
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
231 |
+
w, h = image.size
|
232 |
+
print(f"loaded input image of size ({w}, {h})")
|
233 |
+
width, height = map(
|
234 |
+
lambda x: x - x % 64, (w, h)
|
235 |
+
) # resize to integer multiple of 64
|
236 |
+
image = image.resize((width, height))
|
237 |
+
image_array = np.array(image.convert("RGB"))
|
238 |
+
image_array = image_array[None].transpose(0, 3, 1, 2)
|
239 |
+
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
|
240 |
+
return image_tensor.to(device)
|
241 |
+
|
242 |
+
|
243 |
+
def do_img2img(
|
244 |
+
img,
|
245 |
+
model,
|
246 |
+
sampler,
|
247 |
+
value_dict,
|
248 |
+
num_samples,
|
249 |
+
force_uc_zero_embeddings=[],
|
250 |
+
additional_kwargs={},
|
251 |
+
offset_noise_level: float = 0.0,
|
252 |
+
return_latents=False,
|
253 |
+
skip_encode=False,
|
254 |
+
filter=None,
|
255 |
+
device="cuda",
|
256 |
+
):
|
257 |
+
with torch.no_grad():
|
258 |
+
with autocast(device) as precision_scope:
|
259 |
+
with model.ema_scope():
|
260 |
+
batch, batch_uc = get_batch(
|
261 |
+
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
262 |
+
value_dict,
|
263 |
+
[num_samples],
|
264 |
+
)
|
265 |
+
c, uc = model.conditioner.get_unconditional_conditioning(
|
266 |
+
batch,
|
267 |
+
batch_uc=batch_uc,
|
268 |
+
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
269 |
+
)
|
270 |
+
|
271 |
+
for k in c:
|
272 |
+
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
|
273 |
+
|
274 |
+
for k in additional_kwargs:
|
275 |
+
c[k] = uc[k] = additional_kwargs[k]
|
276 |
+
if skip_encode:
|
277 |
+
z = img
|
278 |
+
else:
|
279 |
+
z = model.encode_first_stage(img)
|
280 |
+
noise = torch.randn_like(z)
|
281 |
+
sigmas = sampler.discretization(sampler.num_steps)
|
282 |
+
sigma = sigmas[0].to(z.device)
|
283 |
+
|
284 |
+
if offset_noise_level > 0.0:
|
285 |
+
noise = noise + offset_noise_level * append_dims(
|
286 |
+
torch.randn(z.shape[0], device=z.device), z.ndim
|
287 |
+
)
|
288 |
+
noised_z = z + noise * append_dims(sigma, z.ndim)
|
289 |
+
noised_z = noised_z / torch.sqrt(
|
290 |
+
1.0 + sigmas[0] ** 2.0
|
291 |
+
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
292 |
+
|
293 |
+
def denoiser(x, sigma, c):
|
294 |
+
return model.denoiser(model.model, x, sigma, c)
|
295 |
+
|
296 |
+
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
297 |
+
samples_x = model.decode_first_stage(samples_z)
|
298 |
+
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
299 |
+
|
300 |
+
if filter is not None:
|
301 |
+
samples = filter(samples)
|
302 |
+
|
303 |
+
if return_latents:
|
304 |
+
return samples, samples_z
|
305 |
+
return samples
|
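Before the next file, a minimal usage sketch for the `embed_watermark` helper defined above. The import path is an assumption (adjust it to wherever this helpers module lives in the repo), and the image sizes are illustrative; shapes follow the docstring `(B, RGB, H, W)` in `[0, 1]`.

import torch
# assumed import path for the helpers module shown above -- adjust if needed
from models.svd.sgm.inference.helpers import embed_watermark

# a fake batch of 2 RGB images in [0, 1], per the docstring (B, RGB, H, W)
images = torch.rand(2, 3, 256, 256)
watermarked = embed_watermark(images)

# the watermark is meant to be imperceptible: same shape, still clamped to [0, 1]
assert watermarked.shape == images.shape
assert 0.0 <= float(watermarked.min()) and float(watermarked.max()) <= 1.0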
models/svd/sgm/lr_scheduler.py
ADDED
@@ -0,0 +1,135 @@
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
warm_up_steps,
|
12 |
+
lr_min,
|
13 |
+
lr_max,
|
14 |
+
lr_start,
|
15 |
+
max_decay_steps,
|
16 |
+
verbosity_interval=0,
|
17 |
+
):
|
18 |
+
self.lr_warm_up_steps = warm_up_steps
|
19 |
+
self.lr_start = lr_start
|
20 |
+
self.lr_min = lr_min
|
21 |
+
self.lr_max = lr_max
|
22 |
+
self.lr_max_decay_steps = max_decay_steps
|
23 |
+
self.last_lr = 0.0
|
24 |
+
self.verbosity_interval = verbosity_interval
|
25 |
+
|
26 |
+
def schedule(self, n, **kwargs):
|
27 |
+
if self.verbosity_interval > 0:
|
28 |
+
if n % self.verbosity_interval == 0:
|
29 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
30 |
+
if n < self.lr_warm_up_steps:
|
31 |
+
lr = (
|
32 |
+
self.lr_max - self.lr_start
|
33 |
+
) / self.lr_warm_up_steps * n + self.lr_start
|
34 |
+
self.last_lr = lr
|
35 |
+
return lr
|
36 |
+
else:
|
37 |
+
t = (n - self.lr_warm_up_steps) / (
|
38 |
+
self.lr_max_decay_steps - self.lr_warm_up_steps
|
39 |
+
)
|
40 |
+
t = min(t, 1.0)
|
41 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
42 |
+
1 + np.cos(t * np.pi)
|
43 |
+
)
|
44 |
+
self.last_lr = lr
|
45 |
+
return lr
|
46 |
+
|
47 |
+
def __call__(self, n, **kwargs):
|
48 |
+
return self.schedule(n, **kwargs)
|
49 |
+
|
50 |
+
|
51 |
+
class LambdaWarmUpCosineScheduler2:
|
52 |
+
"""
|
53 |
+
supports repeated iterations, configurable via lists
|
54 |
+
note: use with a base_lr of 1.0.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
|
59 |
+
):
|
60 |
+
assert (
|
61 |
+
len(warm_up_steps)
|
62 |
+
== len(f_min)
|
63 |
+
== len(f_max)
|
64 |
+
== len(f_start)
|
65 |
+
== len(cycle_lengths)
|
66 |
+
)
|
67 |
+
self.lr_warm_up_steps = warm_up_steps
|
68 |
+
self.f_start = f_start
|
69 |
+
self.f_min = f_min
|
70 |
+
self.f_max = f_max
|
71 |
+
self.cycle_lengths = cycle_lengths
|
72 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
73 |
+
self.last_f = 0.0
|
74 |
+
self.verbosity_interval = verbosity_interval
|
75 |
+
|
76 |
+
def find_in_interval(self, n):
|
77 |
+
interval = 0
|
78 |
+
for cl in self.cum_cycles[1:]:
|
79 |
+
if n <= cl:
|
80 |
+
return interval
|
81 |
+
interval += 1
|
82 |
+
|
83 |
+
def schedule(self, n, **kwargs):
|
84 |
+
cycle = self.find_in_interval(n)
|
85 |
+
n = n - self.cum_cycles[cycle]
|
86 |
+
if self.verbosity_interval > 0:
|
87 |
+
if n % self.verbosity_interval == 0:
|
88 |
+
print(
|
89 |
+
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
90 |
+
f"current cycle {cycle}"
|
91 |
+
)
|
92 |
+
if n < self.lr_warm_up_steps[cycle]:
|
93 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
94 |
+
cycle
|
95 |
+
] * n + self.f_start[cycle]
|
96 |
+
self.last_f = f
|
97 |
+
return f
|
98 |
+
else:
|
99 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (
|
100 |
+
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
101 |
+
)
|
102 |
+
t = min(t, 1.0)
|
103 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
104 |
+
1 + np.cos(t * np.pi)
|
105 |
+
)
|
106 |
+
self.last_f = f
|
107 |
+
return f
|
108 |
+
|
109 |
+
def __call__(self, n, **kwargs):
|
110 |
+
return self.schedule(n, **kwargs)
|
111 |
+
|
112 |
+
|
113 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
114 |
+
def schedule(self, n, **kwargs):
|
115 |
+
cycle = self.find_in_interval(n)
|
116 |
+
n = n - self.cum_cycles[cycle]
|
117 |
+
if self.verbosity_interval > 0:
|
118 |
+
if n % self.verbosity_interval == 0:
|
119 |
+
print(
|
120 |
+
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
121 |
+
f"current cycle {cycle}"
|
122 |
+
)
|
123 |
+
|
124 |
+
if n < self.lr_warm_up_steps[cycle]:
|
125 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
126 |
+
cycle
|
127 |
+
] * n + self.f_start[cycle]
|
128 |
+
self.last_f = f
|
129 |
+
return f
|
130 |
+
else:
|
131 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
|
132 |
+
self.cycle_lengths[cycle] - n
|
133 |
+
) / (self.cycle_lengths[cycle])
|
134 |
+
self.last_f = f
|
135 |
+
return f
|
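A sketch of how these schedulers are meant to be wired up, following the "use with a base_lr of 1.0" note: the schedule itself returns the learning rate, and `torch.optim.lr_scheduler.LambdaLR` multiplies it by the base lr of 1.0 (the same pattern `DiffusionEngine.configure_optimizers` uses further down). The import path and all numbers are illustrative assumptions.

import torch
from models.svd.sgm.lr_scheduler import LambdaWarmUpCosineScheduler2  # assumed import path

model = torch.nn.Linear(4, 4)
# per the docstring, keep the optimizer's base lr at 1.0 and let the schedule
# return the actual learning rate
opt = torch.optim.AdamW(model.parameters(), lr=1.0)

# one cycle: 1_000 warm-up steps inside a 10_000-step cosine cycle,
# ramping from 1e-6 up to 1e-4 and decaying back down to 1e-6
sched_fn = LambdaWarmUpCosineScheduler2(
    warm_up_steps=[1_000],
    f_min=[1.0e-6],
    f_max=[1.0e-4],
    f_start=[1.0e-6],
    cycle_lengths=[10_000],
)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn.schedule)

for _ in range(10):
    opt.step()
    scheduler.step()
print(opt.param_groups[0]["lr"])  # roughly 1e-6 plus ten warm-up increments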
models/svd/sgm/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
from models.svd.sgm.models.autoencoder import AutoencodingEngine
|
2 |
+
from models.svd.sgm.models.diffusion import DiffusionEngine
|
models/svd/sgm/models/autoencoder.py
ADDED
@@ -0,0 +1,615 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import re
|
4 |
+
from abc import abstractmethod
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from einops import rearrange
|
12 |
+
from packaging import version
|
13 |
+
|
14 |
+
from models.svd.sgm.modules.autoencoding.regularizers import AbstractRegularizer
|
15 |
+
from models.svd.sgm.modules.ema import LitEma
|
16 |
+
from models.svd.sgm.util import (default, get_nested_attribute, get_obj_from_str,
|
17 |
+
instantiate_from_config)
|
18 |
+
|
19 |
+
logpy = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class AbstractAutoencoder(pl.LightningModule):
|
23 |
+
"""
|
24 |
+
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
25 |
+
unCLIP models, etc. Hence, it is fairly general, and specific features
|
26 |
+
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
ema_decay: Union[None, float] = None,
|
32 |
+
monitor: Union[None, str] = None,
|
33 |
+
input_key: str = "jpg",
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.input_key = input_key
|
38 |
+
self.use_ema = ema_decay is not None
|
39 |
+
if monitor is not None:
|
40 |
+
self.monitor = monitor
|
41 |
+
|
42 |
+
if self.use_ema:
|
43 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
44 |
+
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
45 |
+
|
46 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
47 |
+
self.automatic_optimization = False
|
48 |
+
|
49 |
+
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
50 |
+
if ckpt is None:
|
51 |
+
return
|
52 |
+
if isinstance(ckpt, str):
|
53 |
+
ckpt = {
|
54 |
+
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
55 |
+
"params": {"ckpt_path": ckpt},
|
56 |
+
}
|
57 |
+
engine = instantiate_from_config(ckpt)
|
58 |
+
engine(self)
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def get_input(self, batch) -> Any:
|
62 |
+
raise NotImplementedError()
|
63 |
+
|
64 |
+
def on_train_batch_end(self, *args, **kwargs):
|
65 |
+
# for EMA computation
|
66 |
+
if self.use_ema:
|
67 |
+
self.model_ema(self)
|
68 |
+
|
69 |
+
@contextmanager
|
70 |
+
def ema_scope(self, context=None):
|
71 |
+
if self.use_ema:
|
72 |
+
self.model_ema.store(self.parameters())
|
73 |
+
self.model_ema.copy_to(self)
|
74 |
+
if context is not None:
|
75 |
+
logpy.info(f"{context}: Switched to EMA weights")
|
76 |
+
try:
|
77 |
+
yield None
|
78 |
+
finally:
|
79 |
+
if self.use_ema:
|
80 |
+
self.model_ema.restore(self.parameters())
|
81 |
+
if context is not None:
|
82 |
+
logpy.info(f"{context}: Restored training weights")
|
83 |
+
|
84 |
+
@abstractmethod
|
85 |
+
def encode(self, *args, **kwargs) -> torch.Tensor:
|
86 |
+
raise NotImplementedError("encode()-method of abstract base class called")
|
87 |
+
|
88 |
+
@abstractmethod
|
89 |
+
def decode(self, *args, **kwargs) -> torch.Tensor:
|
90 |
+
raise NotImplementedError("decode()-method of abstract base class called")
|
91 |
+
|
92 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
93 |
+
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
94 |
+
return get_obj_from_str(cfg["target"])(
|
95 |
+
params, lr=lr, **cfg.get("params", dict())
|
96 |
+
)
|
97 |
+
|
98 |
+
def configure_optimizers(self) -> Any:
|
99 |
+
raise NotImplementedError()
|
100 |
+
|
101 |
+
|
102 |
+
class AutoencodingEngine(AbstractAutoencoder):
|
103 |
+
"""
|
104 |
+
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
105 |
+
(we also restore them explicitly as special cases for legacy reasons).
|
106 |
+
Regularizations such as KL or VQ are moved to the regularizer class.
|
107 |
+
"""
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
*args,
|
112 |
+
encoder_config: Dict,
|
113 |
+
decoder_config: Dict,
|
114 |
+
loss_config: Dict,
|
115 |
+
regularizer_config: Dict,
|
116 |
+
optimizer_config: Union[Dict, None] = None,
|
117 |
+
lr_g_factor: float = 1.0,
|
118 |
+
trainable_ae_params: Optional[List[List[str]]] = None,
|
119 |
+
ae_optimizer_args: Optional[List[dict]] = None,
|
120 |
+
trainable_disc_params: Optional[List[List[str]]] = None,
|
121 |
+
disc_optimizer_args: Optional[List[dict]] = None,
|
122 |
+
disc_start_iter: int = 0,
|
123 |
+
diff_boost_factor: float = 3.0,
|
124 |
+
ckpt_engine: Union[None, str, dict] = None,
|
125 |
+
ckpt_path: Optional[str] = None,
|
126 |
+
additional_decode_keys: Optional[List[str]] = None,
|
127 |
+
**kwargs,
|
128 |
+
):
|
129 |
+
super().__init__(*args, **kwargs)
|
130 |
+
self.automatic_optimization = False # pytorch lightning
|
131 |
+
|
132 |
+
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
133 |
+
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
134 |
+
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
135 |
+
self.regularization: AbstractRegularizer = instantiate_from_config(
|
136 |
+
regularizer_config
|
137 |
+
)
|
138 |
+
self.optimizer_config = default(
|
139 |
+
optimizer_config, {"target": "torch.optim.Adam"}
|
140 |
+
)
|
141 |
+
self.diff_boost_factor = diff_boost_factor
|
142 |
+
self.disc_start_iter = disc_start_iter
|
143 |
+
self.lr_g_factor = lr_g_factor
|
144 |
+
self.trainable_ae_params = trainable_ae_params
|
145 |
+
if self.trainable_ae_params is not None:
|
146 |
+
self.ae_optimizer_args = default(
|
147 |
+
ae_optimizer_args,
|
148 |
+
[{} for _ in range(len(self.trainable_ae_params))],
|
149 |
+
)
|
150 |
+
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
151 |
+
else:
|
152 |
+
self.ae_optimizer_args = [{}]  # makes type consistent
|
153 |
+
|
154 |
+
self.trainable_disc_params = trainable_disc_params
|
155 |
+
if self.trainable_disc_params is not None:
|
156 |
+
self.disc_optimizer_args = default(
|
157 |
+
disc_optimizer_args,
|
158 |
+
[{} for _ in range(len(self.trainable_disc_params))],
|
159 |
+
)
|
160 |
+
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
161 |
+
else:
|
162 |
+
self.disc_optimizer_args = [{}]  # makes type consistent
|
163 |
+
|
164 |
+
if ckpt_path is not None:
|
165 |
+
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
166 |
+
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
167 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
168 |
+
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
169 |
+
|
170 |
+
def get_input(self, batch: Dict) -> torch.Tensor:
|
171 |
+
# assuming unified data format, dataloader returns a dict.
|
172 |
+
# image tensors should be scaled to -1 ... 1 and in channels-first
|
173 |
+
# format (e.g., bchw instead of bhwc)
|
174 |
+
return batch[self.input_key]
|
175 |
+
|
176 |
+
def get_autoencoder_params(self) -> list:
|
177 |
+
params = []
|
178 |
+
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
179 |
+
params += list(self.loss.get_trainable_autoencoder_parameters())
|
180 |
+
if hasattr(self.regularization, "get_trainable_parameters"):
|
181 |
+
params += list(self.regularization.get_trainable_parameters())
|
182 |
+
params = params + list(self.encoder.parameters())
|
183 |
+
params = params + list(self.decoder.parameters())
|
184 |
+
return params
|
185 |
+
|
186 |
+
def get_discriminator_params(self) -> list:
|
187 |
+
if hasattr(self.loss, "get_trainable_parameters"):
|
188 |
+
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
189 |
+
else:
|
190 |
+
params = []
|
191 |
+
return params
|
192 |
+
|
193 |
+
def get_last_layer(self):
|
194 |
+
return self.decoder.get_last_layer()
|
195 |
+
|
196 |
+
def encode(
|
197 |
+
self,
|
198 |
+
x: torch.Tensor,
|
199 |
+
return_reg_log: bool = False,
|
200 |
+
unregularized: bool = False,
|
201 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
202 |
+
z = self.encoder(x)
|
203 |
+
if unregularized:
|
204 |
+
return z, dict()
|
205 |
+
z, reg_log = self.regularization(z)
|
206 |
+
if return_reg_log:
|
207 |
+
return z, reg_log
|
208 |
+
return z
|
209 |
+
|
210 |
+
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
211 |
+
x = self.decoder(z, **kwargs)
|
212 |
+
return x
|
213 |
+
|
214 |
+
def forward(
|
215 |
+
self, x: torch.Tensor, **additional_decode_kwargs
|
216 |
+
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
217 |
+
z, reg_log = self.encode(x, return_reg_log=True)
|
218 |
+
dec = self.decode(z, **additional_decode_kwargs)
|
219 |
+
return z, dec, reg_log
|
220 |
+
|
221 |
+
def inner_training_step(
|
222 |
+
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
223 |
+
) -> torch.Tensor:
|
224 |
+
x = self.get_input(batch)
|
225 |
+
additional_decode_kwargs = {
|
226 |
+
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
227 |
+
}
|
228 |
+
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
229 |
+
if hasattr(self.loss, "forward_keys"):
|
230 |
+
extra_info = {
|
231 |
+
"z": z,
|
232 |
+
"optimizer_idx": optimizer_idx,
|
233 |
+
"global_step": self.global_step,
|
234 |
+
"last_layer": self.get_last_layer(),
|
235 |
+
"split": "train",
|
236 |
+
"regularization_log": regularization_log,
|
237 |
+
"autoencoder": self,
|
238 |
+
}
|
239 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
240 |
+
else:
|
241 |
+
extra_info = dict()
|
242 |
+
|
243 |
+
if optimizer_idx == 0:
|
244 |
+
# autoencode
|
245 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
246 |
+
if isinstance(out_loss, tuple):
|
247 |
+
aeloss, log_dict_ae = out_loss
|
248 |
+
else:
|
249 |
+
# simple loss function
|
250 |
+
aeloss = out_loss
|
251 |
+
log_dict_ae = {"train/loss/rec": aeloss.detach()}
|
252 |
+
|
253 |
+
self.log_dict(
|
254 |
+
log_dict_ae,
|
255 |
+
prog_bar=False,
|
256 |
+
logger=True,
|
257 |
+
on_step=True,
|
258 |
+
on_epoch=True,
|
259 |
+
sync_dist=False,
|
260 |
+
)
|
261 |
+
self.log(
|
262 |
+
"loss",
|
263 |
+
aeloss.mean().detach(),
|
264 |
+
prog_bar=True,
|
265 |
+
logger=False,
|
266 |
+
on_epoch=False,
|
267 |
+
on_step=True,
|
268 |
+
)
|
269 |
+
return aeloss
|
270 |
+
elif optimizer_idx == 1:
|
271 |
+
# discriminator
|
272 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
273 |
+
# -> discriminator always needs to return a tuple
|
274 |
+
self.log_dict(
|
275 |
+
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
276 |
+
)
|
277 |
+
return discloss
|
278 |
+
else:
|
279 |
+
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
280 |
+
|
281 |
+
def training_step(self, batch: dict, batch_idx: int):
|
282 |
+
opts = self.optimizers()
|
283 |
+
if not isinstance(opts, list):
|
284 |
+
# Non-adversarial case
|
285 |
+
opts = [opts]
|
286 |
+
optimizer_idx = batch_idx % len(opts)
|
287 |
+
if self.global_step < self.disc_start_iter:
|
288 |
+
optimizer_idx = 0
|
289 |
+
opt = opts[optimizer_idx]
|
290 |
+
opt.zero_grad()
|
291 |
+
with opt.toggle_model():
|
292 |
+
loss = self.inner_training_step(
|
293 |
+
batch, batch_idx, optimizer_idx=optimizer_idx
|
294 |
+
)
|
295 |
+
self.manual_backward(loss)
|
296 |
+
opt.step()
|
297 |
+
|
298 |
+
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
|
299 |
+
log_dict = self._validation_step(batch, batch_idx)
|
300 |
+
with self.ema_scope():
|
301 |
+
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
302 |
+
log_dict.update(log_dict_ema)
|
303 |
+
return log_dict
|
304 |
+
|
305 |
+
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
|
306 |
+
x = self.get_input(batch)
|
307 |
+
|
308 |
+
z, xrec, regularization_log = self(x)
|
309 |
+
if hasattr(self.loss, "forward_keys"):
|
310 |
+
extra_info = {
|
311 |
+
"z": z,
|
312 |
+
"optimizer_idx": 0,
|
313 |
+
"global_step": self.global_step,
|
314 |
+
"last_layer": self.get_last_layer(),
|
315 |
+
"split": "val" + postfix,
|
316 |
+
"regularization_log": regularization_log,
|
317 |
+
"autoencoder": self,
|
318 |
+
}
|
319 |
+
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
320 |
+
else:
|
321 |
+
extra_info = dict()
|
322 |
+
out_loss = self.loss(x, xrec, **extra_info)
|
323 |
+
if isinstance(out_loss, tuple):
|
324 |
+
aeloss, log_dict_ae = out_loss
|
325 |
+
else:
|
326 |
+
# simple loss function
|
327 |
+
aeloss = out_loss
|
328 |
+
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
|
329 |
+
full_log_dict = log_dict_ae
|
330 |
+
|
331 |
+
if "optimizer_idx" in extra_info:
|
332 |
+
extra_info["optimizer_idx"] = 1
|
333 |
+
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
334 |
+
full_log_dict.update(log_dict_disc)
|
335 |
+
self.log(
|
336 |
+
f"val{postfix}/loss/rec",
|
337 |
+
log_dict_ae[f"val{postfix}/loss/rec"],
|
338 |
+
sync_dist=True,
|
339 |
+
)
|
340 |
+
self.log_dict(full_log_dict, sync_dist=True)
|
341 |
+
return full_log_dict
|
342 |
+
|
343 |
+
def get_param_groups(
|
344 |
+
self, parameter_names: List[List[str]], optimizer_args: List[dict]
|
345 |
+
) -> Tuple[List[Dict[str, Any]], int]:
|
346 |
+
groups = []
|
347 |
+
num_params = 0
|
348 |
+
for names, args in zip(parameter_names, optimizer_args):
|
349 |
+
params = []
|
350 |
+
for pattern_ in names:
|
351 |
+
pattern_params = []
|
352 |
+
pattern = re.compile(pattern_)
|
353 |
+
for p_name, param in self.named_parameters():
|
354 |
+
if re.match(pattern, p_name):
|
355 |
+
pattern_params.append(param)
|
356 |
+
num_params += param.numel()
|
357 |
+
if len(pattern_params) == 0:
|
358 |
+
logpy.warn(f"Did not find parameters for pattern {pattern_}")
|
359 |
+
params.extend(pattern_params)
|
360 |
+
groups.append({"params": params, **args})
|
361 |
+
return groups, num_params
|
362 |
+
|
363 |
+
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
|
364 |
+
if self.trainable_ae_params is None:
|
365 |
+
ae_params = self.get_autoencoder_params()
|
366 |
+
else:
|
367 |
+
ae_params, num_ae_params = self.get_param_groups(
|
368 |
+
self.trainable_ae_params, self.ae_optimizer_args
|
369 |
+
)
|
370 |
+
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
371 |
+
if self.trainable_disc_params is None:
|
372 |
+
disc_params = self.get_discriminator_params()
|
373 |
+
else:
|
374 |
+
disc_params, num_disc_params = self.get_param_groups(
|
375 |
+
self.trainable_disc_params, self.disc_optimizer_args
|
376 |
+
)
|
377 |
+
logpy.info(
|
378 |
+
f"Number of trainable discriminator parameters: {num_disc_params:,}"
|
379 |
+
)
|
380 |
+
opt_ae = self.instantiate_optimizer_from_config(
|
381 |
+
ae_params,
|
382 |
+
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
383 |
+
self.optimizer_config,
|
384 |
+
)
|
385 |
+
opts = [opt_ae]
|
386 |
+
if len(disc_params) > 0:
|
387 |
+
opt_disc = self.instantiate_optimizer_from_config(
|
388 |
+
disc_params, self.learning_rate, self.optimizer_config
|
389 |
+
)
|
390 |
+
opts.append(opt_disc)
|
391 |
+
|
392 |
+
return opts
|
393 |
+
|
394 |
+
@torch.no_grad()
|
395 |
+
def log_images(
|
396 |
+
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
397 |
+
) -> dict:
|
398 |
+
log = dict()
|
399 |
+
additional_decode_kwargs = {}
|
400 |
+
x = self.get_input(batch)
|
401 |
+
additional_decode_kwargs.update(
|
402 |
+
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
403 |
+
)
|
404 |
+
|
405 |
+
_, xrec, _ = self(x, **additional_decode_kwargs)
|
406 |
+
log["inputs"] = x
|
407 |
+
log["reconstructions"] = xrec
|
408 |
+
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
|
409 |
+
diff.clamp_(0, 1.0)
|
410 |
+
log["diff"] = 2.0 * diff - 1.0
|
411 |
+
# diff_boost shows location of small errors, by boosting their
|
412 |
+
# brightness.
|
413 |
+
log["diff_boost"] = (
|
414 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
415 |
+
)
|
416 |
+
if hasattr(self.loss, "log_images"):
|
417 |
+
log.update(self.loss.log_images(x, xrec))
|
418 |
+
with self.ema_scope():
|
419 |
+
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
|
420 |
+
log["reconstructions_ema"] = xrec_ema
|
421 |
+
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
422 |
+
diff_ema.clamp_(0, 1.0)
|
423 |
+
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
424 |
+
log["diff_boost_ema"] = (
|
425 |
+
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
426 |
+
)
|
427 |
+
if additional_log_kwargs:
|
428 |
+
additional_decode_kwargs.update(additional_log_kwargs)
|
429 |
+
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
430 |
+
log_str = "reconstructions-" + "-".join(
|
431 |
+
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
|
432 |
+
)
|
433 |
+
log[log_str] = xrec_add
|
434 |
+
return log
|
435 |
+
|
436 |
+
|
437 |
+
class AutoencodingEngineLegacy(AutoencodingEngine):
|
438 |
+
def __init__(self, embed_dim: int, **kwargs):
|
439 |
+
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
440 |
+
ddconfig = kwargs.pop("ddconfig")
|
441 |
+
ckpt_path = kwargs.pop("ckpt_path", None)
|
442 |
+
ckpt_engine = kwargs.pop("ckpt_engine", None)
|
443 |
+
super().__init__(
|
444 |
+
encoder_config={
|
445 |
+
"target": "models.svd.sgm.modules.diffusionmodules.model.Encoder",
|
446 |
+
"params": ddconfig,
|
447 |
+
},
|
448 |
+
decoder_config={
|
449 |
+
"target": "models.svd.sgm.modules.diffusionmodules.model.Decoder",
|
450 |
+
"params": ddconfig,
|
451 |
+
},
|
452 |
+
**kwargs,
|
453 |
+
)
|
454 |
+
self.quant_conv = torch.nn.Conv2d(
|
455 |
+
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
456 |
+
(1 + ddconfig["double_z"]) * embed_dim,
|
457 |
+
1,
|
458 |
+
)
|
459 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
460 |
+
self.embed_dim = embed_dim
|
461 |
+
|
462 |
+
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
463 |
+
|
464 |
+
def get_autoencoder_params(self) -> list:
|
465 |
+
params = super().get_autoencoder_params()
|
466 |
+
return params
|
467 |
+
|
468 |
+
def encode(
|
469 |
+
self, x: torch.Tensor, return_reg_log: bool = False
|
470 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
471 |
+
if self.max_batch_size is None:
|
472 |
+
z = self.encoder(x)
|
473 |
+
z = self.quant_conv(z)
|
474 |
+
else:
|
475 |
+
N = x.shape[0]
|
476 |
+
bs = self.max_batch_size
|
477 |
+
n_batches = int(math.ceil(N / bs))
|
478 |
+
z = list()
|
479 |
+
for i_batch in range(n_batches):
|
480 |
+
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
481 |
+
z_batch = self.quant_conv(z_batch)
|
482 |
+
z.append(z_batch)
|
483 |
+
z = torch.cat(z, 0)
|
484 |
+
|
485 |
+
z, reg_log = self.regularization(z)
|
486 |
+
if return_reg_log:
|
487 |
+
return z, reg_log
|
488 |
+
return z
|
489 |
+
|
490 |
+
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
491 |
+
if self.max_batch_size is None:
|
492 |
+
dec = self.post_quant_conv(z)
|
493 |
+
dec = self.decoder(dec, **decoder_kwargs)
|
494 |
+
else:
|
495 |
+
N = z.shape[0]
|
496 |
+
bs = self.max_batch_size
|
497 |
+
n_batches = int(math.ceil(N / bs))
|
498 |
+
dec = list()
|
499 |
+
for i_batch in range(n_batches):
|
500 |
+
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
501 |
+
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
502 |
+
dec.append(dec_batch)
|
503 |
+
dec = torch.cat(dec, 0)
|
504 |
+
|
505 |
+
return dec
|
506 |
+
|
507 |
+
|
508 |
+
class AutoencoderKL(AutoencodingEngineLegacy):
|
509 |
+
def __init__(self, **kwargs):
|
510 |
+
if "lossconfig" in kwargs:
|
511 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
512 |
+
super().__init__(
|
513 |
+
regularizer_config={
|
514 |
+
"target": (
|
515 |
+
"sgm.modules.autoencoding.regularizers"
|
516 |
+
".DiagonalGaussianRegularizer"
|
517 |
+
)
|
518 |
+
},
|
519 |
+
**kwargs,
|
520 |
+
)
|
521 |
+
|
522 |
+
|
523 |
+
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
|
524 |
+
def __init__(
|
525 |
+
self,
|
526 |
+
embed_dim: int,
|
527 |
+
n_embed: int,
|
528 |
+
sane_index_shape: bool = False,
|
529 |
+
**kwargs,
|
530 |
+
):
|
531 |
+
if "lossconfig" in kwargs:
|
532 |
+
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
|
533 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
534 |
+
super().__init__(
|
535 |
+
regularizer_config={
|
536 |
+
"target": (
|
537 |
+
"sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
|
538 |
+
),
|
539 |
+
"params": {
|
540 |
+
"n_e": n_embed,
|
541 |
+
"e_dim": embed_dim,
|
542 |
+
"sane_index_shape": sane_index_shape,
|
543 |
+
},
|
544 |
+
},
|
545 |
+
**kwargs,
|
546 |
+
)
|
547 |
+
|
548 |
+
|
549 |
+
class IdentityFirstStage(AbstractAutoencoder):
|
550 |
+
def __init__(self, *args, **kwargs):
|
551 |
+
super().__init__(*args, **kwargs)
|
552 |
+
|
553 |
+
def get_input(self, x: Any) -> Any:
|
554 |
+
return x
|
555 |
+
|
556 |
+
def encode(self, x: Any, *args, **kwargs) -> Any:
|
557 |
+
return x
|
558 |
+
|
559 |
+
def decode(self, x: Any, *args, **kwargs) -> Any:
|
560 |
+
return x
|
561 |
+
|
562 |
+
|
563 |
+
class AEIntegerWrapper(nn.Module):
|
564 |
+
def __init__(
|
565 |
+
self,
|
566 |
+
model: nn.Module,
|
567 |
+
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
|
568 |
+
regularization_key: str = "regularization",
|
569 |
+
encoder_kwargs: Optional[Dict[str, Any]] = None,
|
570 |
+
):
|
571 |
+
super().__init__()
|
572 |
+
self.model = model
|
573 |
+
assert hasattr(model, "encode") and hasattr(
|
574 |
+
model, "decode"
|
575 |
+
), "Need AE interface"
|
576 |
+
self.regularization = get_nested_attribute(model, regularization_key)
|
577 |
+
self.shape = shape
|
578 |
+
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
|
579 |
+
|
580 |
+
def encode(self, x) -> torch.Tensor:
|
581 |
+
assert (
|
582 |
+
not self.training
|
583 |
+
), f"{self.__class__.__name__} only supports inference currently"
|
584 |
+
_, log = self.model.encode(x, **self.encoder_kwargs)
|
585 |
+
assert isinstance(log, dict)
|
586 |
+
inds = log["min_encoding_indices"]
|
587 |
+
return rearrange(inds, "b ... -> b (...)")
|
588 |
+
|
589 |
+
def decode(
|
590 |
+
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
|
591 |
+
) -> torch.Tensor:
|
592 |
+
# expect inds shape (b, s) with s = h*w
|
593 |
+
shape = default(shape, self.shape) # Optional[(h, w)]
|
594 |
+
if shape is not None:
|
595 |
+
assert len(shape) == 2, f"Unhandled shape {shape}"
|
596 |
+
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
|
597 |
+
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
|
598 |
+
h = rearrange(h, "b h w c -> b c h w")
|
599 |
+
return self.model.decode(h)
|
600 |
+
|
601 |
+
|
602 |
+
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
|
603 |
+
def __init__(self, **kwargs):
|
604 |
+
if "lossconfig" in kwargs:
|
605 |
+
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
606 |
+
super().__init__(
|
607 |
+
regularizer_config={
|
608 |
+
"target": (
|
609 |
+
"models.svd.sgm.modules.autoencoding.regularizers"
|
610 |
+
".DiagonalGaussianRegularizer"
|
611 |
+
),
|
612 |
+
"params": {"sample": False},
|
613 |
+
},
|
614 |
+
**kwargs,
|
615 |
+
)
|
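`AutoencodingEngineLegacy.encode`/`decode` above bound peak memory by slicing the batch into `ceil(N / max_batch_size)` chunks and concatenating the results. Building the full engine needs encoder/decoder configs that are not part of this file, so the sketch below isolates just that chunking pattern with a stand-in function; nothing in it comes from the repo.

import math
from typing import Callable, Optional

import torch


def apply_in_chunks(
    fn: Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    max_batch_size: Optional[int],
) -> torch.Tensor:
    """Apply fn to x in one pass or in ceil(N / max_batch_size) slices."""
    if max_batch_size is None:
        return fn(x)
    n = x.shape[0]
    n_batches = int(math.ceil(n / max_batch_size))
    out = []
    for i in range(n_batches):
        out.append(fn(x[i * max_batch_size : (i + 1) * max_batch_size]))
    return torch.cat(out, 0)


# a 10-sample batch processed 4 samples at a time -> 3 passes,
# identical result to a single pass for any per-sample fn
dummy = torch.randn(10, 3, 64, 64)
assert torch.allclose(apply_in_chunks(lambda t: t * 2.0, dummy, 4), dummy * 2.0)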
models/svd/sgm/models/diffusion.py
ADDED
@@ -0,0 +1,341 @@
1 |
+
import math
|
2 |
+
from contextlib import contextmanager
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import torch
|
7 |
+
from omegaconf import ListConfig, OmegaConf
|
8 |
+
from safetensors.torch import load_file as load_safetensors
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR
|
10 |
+
|
11 |
+
from models.svd.sgm.modules import UNCONDITIONAL_CONFIG
|
12 |
+
from models.svd.sgm.modules.autoencoding.temporal_ae import VideoDecoder
|
13 |
+
from models.svd.sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
14 |
+
from models.svd.sgm.modules.ema import LitEma
|
15 |
+
from models.svd.sgm.util import (default, disabled_train, get_obj_from_str,
|
16 |
+
instantiate_from_config, log_txt_as_img)
|
17 |
+
|
18 |
+
|
19 |
+
class DiffusionEngine(pl.LightningModule):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
network_config,
|
23 |
+
denoiser_config,
|
24 |
+
first_stage_config,
|
25 |
+
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
26 |
+
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
27 |
+
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
28 |
+
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
29 |
+
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
30 |
+
network_wrapper: Union[None, str] = None,
|
31 |
+
ckpt_path: Union[None, str] = None,
|
32 |
+
use_ema: bool = False,
|
33 |
+
ema_decay_rate: float = 0.9999,
|
34 |
+
scale_factor: float = 1.0,
|
35 |
+
disable_first_stage_autocast=False,
|
36 |
+
input_key: str = "jpg",
|
37 |
+
log_keys: Union[List, None] = None,
|
38 |
+
no_cond_log: bool = False,
|
39 |
+
compile_model: bool = False,
|
40 |
+
en_and_decode_n_samples_a_time: Optional[int] = None,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
self.log_keys = log_keys
|
44 |
+
self.input_key = input_key
|
45 |
+
self.optimizer_config = default(
|
46 |
+
optimizer_config, {"target": "torch.optim.AdamW"}
|
47 |
+
)
|
48 |
+
model = instantiate_from_config(network_config)
|
49 |
+
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
50 |
+
model, compile_model=compile_model
|
51 |
+
)
|
52 |
+
|
53 |
+
self.denoiser = instantiate_from_config(denoiser_config)
|
54 |
+
self.sampler = (
|
55 |
+
instantiate_from_config(sampler_config)
|
56 |
+
if sampler_config is not None
|
57 |
+
else None
|
58 |
+
)
|
59 |
+
self.conditioner = instantiate_from_config(
|
60 |
+
default(conditioner_config, UNCONDITIONAL_CONFIG)
|
61 |
+
)
|
62 |
+
self.scheduler_config = scheduler_config
|
63 |
+
self._init_first_stage(first_stage_config)
|
64 |
+
|
65 |
+
self.loss_fn = (
|
66 |
+
instantiate_from_config(loss_fn_config)
|
67 |
+
if loss_fn_config is not None
|
68 |
+
else None
|
69 |
+
)
|
70 |
+
|
71 |
+
self.use_ema = use_ema
|
72 |
+
if self.use_ema:
|
73 |
+
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
74 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
75 |
+
|
76 |
+
self.scale_factor = scale_factor
|
77 |
+
self.disable_first_stage_autocast = disable_first_stage_autocast
|
78 |
+
self.no_cond_log = no_cond_log
|
79 |
+
|
80 |
+
if ckpt_path is not None:
|
81 |
+
self.init_from_ckpt(ckpt_path)
|
82 |
+
|
83 |
+
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
84 |
+
|
85 |
+
def init_from_ckpt(
|
86 |
+
self,
|
87 |
+
path: str,
|
88 |
+
) -> None:
|
89 |
+
if path.endswith("ckpt"):
|
90 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
91 |
+
elif path.endswith("safetensors"):
|
92 |
+
sd = load_safetensors(path)
|
93 |
+
else:
|
94 |
+
raise NotImplementedError
|
95 |
+
|
96 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
97 |
+
print(
|
98 |
+
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
99 |
+
)
|
100 |
+
if len(missing) > 0:
|
101 |
+
print(f"Missing Keys: {missing}")
|
102 |
+
if len(unexpected) > 0:
|
103 |
+
print(f"Unexpected Keys: {unexpected}")
|
104 |
+
|
105 |
+
def _init_first_stage(self, config):
|
106 |
+
model = instantiate_from_config(config).eval()
|
107 |
+
model.train = disabled_train
|
108 |
+
for param in model.parameters():
|
109 |
+
param.requires_grad = False
|
110 |
+
self.first_stage_model = model
|
111 |
+
|
112 |
+
def get_input(self, batch):
|
113 |
+
# assuming unified data format, dataloader returns a dict.
|
114 |
+
# image tensors should be scaled to -1 ... 1 and in bchw format
|
115 |
+
return batch[self.input_key]
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def decode_first_stage(self, z):
|
119 |
+
z = 1.0 / self.scale_factor * z
|
120 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
121 |
+
|
122 |
+
n_rounds = math.ceil(z.shape[0] / n_samples)
|
123 |
+
all_out = []
|
124 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
125 |
+
for n in range(n_rounds):
|
126 |
+
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
127 |
+
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
128 |
+
else:
|
129 |
+
kwargs = {}
|
130 |
+
out = self.first_stage_model.decode(
|
131 |
+
z[n * n_samples : (n + 1) * n_samples], **kwargs
|
132 |
+
)
|
133 |
+
all_out.append(out)
|
134 |
+
out = torch.cat(all_out, dim=0)
|
135 |
+
return out
|
136 |
+
|
137 |
+
@torch.no_grad()
|
138 |
+
def encode_first_stage(self, x):
|
139 |
+
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
140 |
+
n_rounds = math.ceil(x.shape[0] / n_samples)
|
141 |
+
all_out = []
|
142 |
+
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
143 |
+
for n in range(n_rounds):
|
144 |
+
out = self.first_stage_model.encode(
|
145 |
+
x[n * n_samples : (n + 1) * n_samples]
|
146 |
+
)
|
147 |
+
all_out.append(out)
|
148 |
+
z = torch.cat(all_out, dim=0)
|
149 |
+
z = self.scale_factor * z
|
150 |
+
return z
|
151 |
+
|
152 |
+
def forward(self, x, batch):
|
153 |
+
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
|
154 |
+
loss_mean = loss.mean()
|
155 |
+
loss_dict = {"loss": loss_mean}
|
156 |
+
return loss_mean, loss_dict
|
157 |
+
|
158 |
+
def shared_step(self, batch: Dict) -> Any:
|
159 |
+
x = self.get_input(batch)
|
160 |
+
x = self.encode_first_stage(x)
|
161 |
+
batch["global_step"] = self.global_step
|
162 |
+
loss, loss_dict = self(x, batch)
|
163 |
+
return loss, loss_dict
|
164 |
+
|
165 |
+
def training_step(self, batch, batch_idx):
|
166 |
+
loss, loss_dict = self.shared_step(batch)
|
167 |
+
|
168 |
+
self.log_dict(
|
169 |
+
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
170 |
+
)
|
171 |
+
|
172 |
+
self.log(
|
173 |
+
"global_step",
|
174 |
+
self.global_step,
|
175 |
+
prog_bar=True,
|
176 |
+
logger=True,
|
177 |
+
on_step=True,
|
178 |
+
on_epoch=False,
|
179 |
+
)
|
180 |
+
|
181 |
+
if self.scheduler_config is not None:
|
182 |
+
lr = self.optimizers().param_groups[0]["lr"]
|
183 |
+
self.log(
|
184 |
+
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
185 |
+
)
|
186 |
+
|
187 |
+
return loss
|
188 |
+
|
189 |
+
def on_train_start(self, *args, **kwargs):
|
190 |
+
if self.sampler is None or self.loss_fn is None:
|
191 |
+
raise ValueError("Sampler and loss function need to be set for training.")
|
192 |
+
|
193 |
+
def on_train_batch_end(self, *args, **kwargs):
|
194 |
+
if self.use_ema:
|
195 |
+
self.model_ema(self.model)
|
196 |
+
|
197 |
+
@contextmanager
|
198 |
+
def ema_scope(self, context=None):
|
199 |
+
if self.use_ema:
|
200 |
+
self.model_ema.store(self.model.parameters())
|
201 |
+
self.model_ema.copy_to(self.model)
|
202 |
+
if context is not None:
|
203 |
+
print(f"{context}: Switched to EMA weights")
|
204 |
+
try:
|
205 |
+
yield None
|
206 |
+
finally:
|
207 |
+
if self.use_ema:
|
208 |
+
self.model_ema.restore(self.model.parameters())
|
209 |
+
if context is not None:
|
210 |
+
print(f"{context}: Restored training weights")
|
211 |
+
|
212 |
+
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
213 |
+
return get_obj_from_str(cfg["target"])(
|
214 |
+
params, lr=lr, **cfg.get("params", dict())
|
215 |
+
)
|
216 |
+
|
217 |
+
def configure_optimizers(self):
|
218 |
+
lr = self.learning_rate
|
219 |
+
params = list(self.model.parameters())
|
220 |
+
for embedder in self.conditioner.embedders:
|
221 |
+
if embedder.is_trainable:
|
222 |
+
params = params + list(embedder.parameters())
|
223 |
+
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
224 |
+
if self.scheduler_config is not None:
|
225 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
226 |
+
print("Setting up LambdaLR scheduler...")
|
227 |
+
scheduler = [
|
228 |
+
{
|
229 |
+
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
230 |
+
"interval": "step",
|
231 |
+
"frequency": 1,
|
232 |
+
}
|
233 |
+
]
|
234 |
+
return [opt], scheduler
|
235 |
+
return opt
|
236 |
+
|
237 |
+
@torch.no_grad()
|
238 |
+
def sample(
|
239 |
+
self,
|
240 |
+
cond: Dict,
|
241 |
+
uc: Union[Dict, None] = None,
|
242 |
+
batch_size: int = 16,
|
243 |
+
shape: Union[None, Tuple, List] = None,
|
244 |
+
**kwargs,
|
245 |
+
):
|
246 |
+
randn = torch.randn(batch_size, *shape).to(self.device)
|
247 |
+
|
248 |
+
denoiser = lambda input, sigma, c: self.denoiser(
|
249 |
+
self.model, input, sigma, c, **kwargs
|
250 |
+
)
|
251 |
+
samples = self.sampler(denoiser, randn, cond, uc=uc)
|
252 |
+
return samples
|
253 |
+
|
254 |
+
@torch.no_grad()
|
255 |
+
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
256 |
+
"""
|
257 |
+
Defines heuristics to log different conditionings.
|
258 |
+
These can be lists of strings (text-to-image), tensors, ints, ...
|
259 |
+
"""
|
260 |
+
image_h, image_w = batch[self.input_key].shape[2:]
|
261 |
+
log = dict()
|
262 |
+
|
263 |
+
for embedder in self.conditioner.embedders:
|
264 |
+
if (
|
265 |
+
(self.log_keys is None) or (embedder.input_key in self.log_keys)
|
266 |
+
) and not self.no_cond_log:
|
267 |
+
x = batch[embedder.input_key][:n]
|
268 |
+
if isinstance(x, torch.Tensor):
|
269 |
+
if x.dim() == 1:
|
270 |
+
# class-conditional, convert integer to string
|
271 |
+
x = [str(x[i].item()) for i in range(x.shape[0])]
|
272 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
273 |
+
elif x.dim() == 2:
|
274 |
+
# size and crop cond and the like
|
275 |
+
x = [
|
276 |
+
"x".join([str(xx) for xx in x[i].tolist()])
|
277 |
+
for i in range(x.shape[0])
|
278 |
+
]
|
279 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
280 |
+
else:
|
281 |
+
raise NotImplementedError()
|
282 |
+
elif isinstance(x, (List, ListConfig)):
|
283 |
+
if isinstance(x[0], str):
|
284 |
+
# strings
|
285 |
+
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
286 |
+
else:
|
287 |
+
raise NotImplementedError()
|
288 |
+
else:
|
289 |
+
raise NotImplementedError()
|
290 |
+
log[embedder.input_key] = xc
|
291 |
+
return log
|
292 |
+
|
293 |
+
@torch.no_grad()
|
294 |
+
def log_images(
|
295 |
+
self,
|
296 |
+
batch: Dict,
|
297 |
+
N: int = 8,
|
298 |
+
sample: bool = True,
|
299 |
+
ucg_keys: List[str] = None,
|
300 |
+
**kwargs,
|
301 |
+
) -> Dict:
|
302 |
+
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
303 |
+
if ucg_keys:
|
304 |
+
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
305 |
+
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
306 |
+
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
307 |
+
)
|
308 |
+
else:
|
309 |
+
ucg_keys = conditioner_input_keys
|
310 |
+
log = dict()
|
311 |
+
|
312 |
+
x = self.get_input(batch)
|
313 |
+
|
314 |
+
c, uc = self.conditioner.get_unconditional_conditioning(
|
315 |
+
batch,
|
316 |
+
force_uc_zero_embeddings=ucg_keys
|
317 |
+
if len(self.conditioner.embedders) > 0
|
318 |
+
else [],
|
319 |
+
)
|
320 |
+
|
321 |
+
sampling_kwargs = {}
|
322 |
+
|
323 |
+
N = min(x.shape[0], N)
|
324 |
+
x = x.to(self.device)[:N]
|
325 |
+
log["inputs"] = x
|
326 |
+
z = self.encode_first_stage(x)
|
327 |
+
log["reconstructions"] = self.decode_first_stage(z)
|
328 |
+
log.update(self.log_conditionings(batch, N))
|
329 |
+
|
330 |
+
for k in c:
|
331 |
+
if isinstance(c[k], torch.Tensor):
|
332 |
+
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
333 |
+
|
334 |
+
if sample:
|
335 |
+
with self.ema_scope("Plotting"):
|
336 |
+
samples = self.sample(
|
337 |
+
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
338 |
+
)
|
339 |
+
samples = self.decode_first_stage(samples)
|
340 |
+
log["samples"] = samples
|
341 |
+
return log
|
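One detail worth calling out in `DiffusionEngine`: latents are kept pre-multiplied by `scale_factor` (`encode_first_stage` scales up, `decode_first_stage` divides again before calling the VAE decoder), so the two methods are exact inverses of the scaling step. A tiny sketch of that convention with the VAE mocked out; the numeric value is illustrative, not read from this repo's config.

import torch

scale_factor = 0.18215               # illustrative value only
vae_latent = torch.randn(1, 4, 32, 32)

z = scale_factor * vae_latent        # what encode_first_stage returns
z_back = (1.0 / scale_factor) * z    # what decode_first_stage hands to the decoder

assert torch.allclose(z_back, vae_latent, atol=1e-6)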
models/svd/sgm/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
1 |
+
from models.svd.sgm.modules.encoders.modules import GeneralConditioner
|
2 |
+
|
3 |
+
UNCONDITIONAL_CONFIG = {
|
4 |
+
"target": "sgm.modules.GeneralConditioner",
|
5 |
+
"params": {"emb_models": []},
|
6 |
+
}
|
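`UNCONDITIONAL_CONFIG` follows the usual `target`/`params` convention consumed by `instantiate_from_config` (see `DiffusionEngine` above). Below is a simplified stand-in for that resolver, written only to show the convention; it is not the repo's implementation, and the demo target is a stdlib class so the sketch stays self-contained.

import importlib


def instantiate_from_config_sketch(config: dict):
    """Resolve a dotted 'target' string and call it with 'params' (illustrative)."""
    module_path, name = config["target"].rsplit(".", 1)
    obj = getattr(importlib.import_module(module_path), name)
    return obj(**config.get("params", {}))


example = {"target": "collections.OrderedDict", "params": {}}
print(type(instantiate_from_config_sketch(example)))  # <class 'collections.OrderedDict'>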
models/svd/sgm/modules/attention.py
ADDED
@@ -0,0 +1,809 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from inspect import isfunction
|
4 |
+
from typing import Any, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
from packaging import version
|
10 |
+
from torch import nn
|
11 |
+
from torch.utils.checkpoint import checkpoint
|
12 |
+
|
13 |
+
logpy = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
16 |
+
SDP_IS_AVAILABLE = True
|
17 |
+
from torch.backends.cuda import SDPBackend, sdp_kernel
|
18 |
+
|
19 |
+
BACKEND_MAP = {
|
20 |
+
SDPBackend.MATH: {
|
21 |
+
"enable_math": True,
|
22 |
+
"enable_flash": False,
|
23 |
+
"enable_mem_efficient": False,
|
24 |
+
},
|
25 |
+
SDPBackend.FLASH_ATTENTION: {
|
26 |
+
"enable_math": False,
|
27 |
+
"enable_flash": True,
|
28 |
+
"enable_mem_efficient": False,
|
29 |
+
},
|
30 |
+
SDPBackend.EFFICIENT_ATTENTION: {
|
31 |
+
"enable_math": False,
|
32 |
+
"enable_flash": False,
|
33 |
+
"enable_mem_efficient": True,
|
34 |
+
},
|
35 |
+
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
36 |
+
}
|
37 |
+
else:
|
38 |
+
from contextlib import nullcontext
|
39 |
+
|
40 |
+
SDP_IS_AVAILABLE = False
|
41 |
+
sdp_kernel = nullcontext
|
42 |
+
BACKEND_MAP = {}
|
43 |
+
logpy.warn(
|
44 |
+
f"No SDP backend available, likely because you are running in pytorch "
|
45 |
+
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
46 |
+
f"You might want to consider upgrading."
|
47 |
+
)
|
48 |
+
|
49 |
+
try:
|
50 |
+
import xformers
|
51 |
+
import xformers.ops
|
52 |
+
|
53 |
+
XFORMERS_IS_AVAILABLE = True
|
54 |
+
except Exception:
|
55 |
+
XFORMERS_IS_AVAILABLE = False
|
56 |
+
logpy.warn("no module 'xformers'. Processing without...")
|
57 |
+
|
58 |
+
# from .diffusionmodules.util import mixed_checkpoint as checkpoint
|
59 |
+
|
60 |
+
|
61 |
+
def exists(val):
|
62 |
+
return val is not None
|
63 |
+
|
64 |
+
|
65 |
+
def uniq(arr):
|
66 |
+
return {el: True for el in arr}.keys()
|
67 |
+
|
68 |
+
|
69 |
+
def default(val, d):
|
70 |
+
if exists(val):
|
71 |
+
return val
|
72 |
+
return d() if isfunction(d) else d
|
73 |
+
|
74 |
+
|
75 |
+
def max_neg_value(t):
|
76 |
+
return -torch.finfo(t.dtype).max
|
77 |
+
|
78 |
+
|
79 |
+
def init_(tensor):
|
80 |
+
dim = tensor.shape[-1]
|
81 |
+
std = 1 / math.sqrt(dim)
|
82 |
+
tensor.uniform_(-std, std)
|
83 |
+
return tensor
|
84 |
+
|
85 |
+
|
86 |
+
# feedforward
|
87 |
+
class GEGLU(nn.Module):
|
88 |
+
def __init__(self, dim_in, dim_out):
|
89 |
+
super().__init__()
|
90 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
94 |
+
return x * F.gelu(gate)
|
95 |
+
|
96 |
+
|
97 |
+
class FeedForward(nn.Module):
|
98 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
99 |
+
super().__init__()
|
100 |
+
inner_dim = int(dim * mult)
|
101 |
+
dim_out = default(dim_out, dim)
|
102 |
+
project_in = (
|
103 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
104 |
+
if not glu
|
105 |
+
else GEGLU(dim, inner_dim)
|
106 |
+
)
|
107 |
+
|
108 |
+
self.net = nn.Sequential(
|
109 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
return self.net(x)
|
114 |
+
|
115 |
+
|
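As a quick shape check for the GEGLU / FeedForward blocks defined above, the following sketch (dimension values are arbitrary, and it assumes the classes above are importable) shows that the feed-forward preserves the token shape:

import torch

ff = FeedForward(dim=320, mult=4, glu=True)   # GEGLU-gated MLP, as defined above
tokens = torch.randn(2, 77, 320)              # (batch, sequence, dim)
out = ff(tokens)
assert out.shape == tokens.shape              # dim_out defaults to dim, so the shape is unchanged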
116 |
+
def zero_module(module):
|
117 |
+
"""
|
118 |
+
Zero out the parameters of a module and return it.
|
119 |
+
"""
|
120 |
+
for p in module.parameters():
|
121 |
+
p.detach().zero_()
|
122 |
+
return module
|
123 |
+
|
124 |
+
|
125 |
+
def Normalize(in_channels):
|
126 |
+
return torch.nn.GroupNorm(
|
127 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
class LinearAttention(nn.Module):
|
132 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
133 |
+
super().__init__()
|
134 |
+
self.heads = heads
|
135 |
+
hidden_dim = dim_head * heads
|
136 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
137 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
b, c, h, w = x.shape
|
141 |
+
qkv = self.to_qkv(x)
|
142 |
+
q, k, v = rearrange(
|
143 |
+
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
144 |
+
)
|
145 |
+
k = k.softmax(dim=-1)
|
146 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
147 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
148 |
+
out = rearrange(
|
149 |
+
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
150 |
+
)
|
151 |
+
return self.to_out(out)
|
152 |
+
|
153 |
+
|
154 |
+
class SelfAttention(nn.Module):
|
155 |
+
ATTENTION_MODES = ("xformers", "torch", "math")
|
156 |
+
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
dim: int,
|
160 |
+
num_heads: int = 8,
|
161 |
+
qkv_bias: bool = False,
|
162 |
+
qk_scale: Optional[float] = None,
|
163 |
+
attn_drop: float = 0.0,
|
164 |
+
proj_drop: float = 0.0,
|
165 |
+
attn_mode: str = "xformers",
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
self.num_heads = num_heads
|
169 |
+
head_dim = dim // num_heads
|
170 |
+
self.scale = qk_scale or head_dim**-0.5
|
171 |
+
|
172 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
173 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
174 |
+
self.proj = nn.Linear(dim, dim)
|
175 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
176 |
+
assert attn_mode in self.ATTENTION_MODES
|
177 |
+
self.attn_mode = attn_mode
|
178 |
+
|
179 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
180 |
+
B, L, C = x.shape
|
181 |
+
|
182 |
+
qkv = self.qkv(x)
|
183 |
+
if self.attn_mode == "torch":
|
184 |
+
qkv = rearrange(
|
185 |
+
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
186 |
+
).float()
|
187 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
188 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
189 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
190 |
+
elif self.attn_mode == "xformers":
|
191 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
192 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
193 |
+
x = xformers.ops.memory_efficient_attention(q, k, v)
|
194 |
+
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
|
195 |
+
elif self.attn_mode == "math":
|
196 |
+
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
197 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
198 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
199 |
+
attn = attn.softmax(dim=-1)
|
200 |
+
attn = self.attn_drop(attn)
|
201 |
+
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
202 |
+
else:
|
203 |
+
raise NotImplementedError(f"unknown attn_mode: {self.attn_mode}")
|
204 |
+
|
205 |
+
x = self.proj(x)
|
206 |
+
x = self.proj_drop(x)
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
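A minimal usage sketch for the SelfAttention module above, using the pure-PyTorch "torch" mode so it runs without xformers; PyTorch >= 2.0 is assumed for scaled_dot_product_attention, and the sizes are illustrative:

import torch

attn = SelfAttention(dim=512, num_heads=8, attn_mode="torch")
x = torch.randn(4, 196, 512)    # (batch, tokens, channels)
y = attn(x)                     # attention plus output projection, same shape as the input
assert y.shape == x.shape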
210 |
+
class SpatialSelfAttention(nn.Module):
|
211 |
+
def __init__(self, in_channels):
|
212 |
+
super().__init__()
|
213 |
+
self.in_channels = in_channels
|
214 |
+
|
215 |
+
self.norm = Normalize(in_channels)
|
216 |
+
self.q = torch.nn.Conv2d(
|
217 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
218 |
+
)
|
219 |
+
self.k = torch.nn.Conv2d(
|
220 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
221 |
+
)
|
222 |
+
self.v = torch.nn.Conv2d(
|
223 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
224 |
+
)
|
225 |
+
self.proj_out = torch.nn.Conv2d(
|
226 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
227 |
+
)
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
h_ = x
|
231 |
+
h_ = self.norm(h_)
|
232 |
+
q = self.q(h_)
|
233 |
+
k = self.k(h_)
|
234 |
+
v = self.v(h_)
|
235 |
+
|
236 |
+
# compute attention
|
237 |
+
b, c, h, w = q.shape
|
238 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
239 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
240 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
241 |
+
|
242 |
+
w_ = w_ * (int(c) ** (-0.5))
|
243 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
244 |
+
|
245 |
+
# attend to values
|
246 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
247 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
248 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
249 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
250 |
+
h_ = self.proj_out(h_)
|
251 |
+
|
252 |
+
return x + h_
|
253 |
+
|
254 |
+
|
255 |
+
class CrossAttention(nn.Module):
|
256 |
+
def __init__(
|
257 |
+
self,
|
258 |
+
query_dim,
|
259 |
+
context_dim=None,
|
260 |
+
heads=8,
|
261 |
+
dim_head=64,
|
262 |
+
dropout=0.0,
|
263 |
+
backend=None,
|
264 |
+
):
|
265 |
+
super().__init__()
|
266 |
+
inner_dim = dim_head * heads
|
267 |
+
context_dim = default(context_dim, query_dim)
|
268 |
+
|
269 |
+
self.scale = dim_head**-0.5
|
270 |
+
self.heads = heads
|
271 |
+
|
272 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
273 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
274 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
275 |
+
|
276 |
+
self.to_out = nn.Sequential(
|
277 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
278 |
+
)
|
279 |
+
self.backend = backend
|
280 |
+
|
281 |
+
def forward(
|
282 |
+
self,
|
283 |
+
x,
|
284 |
+
context=None,
|
285 |
+
mask=None,
|
286 |
+
additional_tokens=None,
|
287 |
+
n_times_crossframe_attn_in_self=0,
|
288 |
+
):
|
289 |
+
h = self.heads
|
290 |
+
|
291 |
+
if additional_tokens is not None:
|
292 |
+
# get the number of masked tokens at the beginning of the output sequence
|
293 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
294 |
+
# add additional token
|
295 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
296 |
+
|
297 |
+
q = self.to_q(x)
|
298 |
+
context = default(context, x)
|
299 |
+
k = self.to_k(context)
|
300 |
+
v = self.to_v(context)
|
301 |
+
|
302 |
+
if n_times_crossframe_attn_in_self:
|
303 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
304 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
305 |
+
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
|
306 |
+
k = repeat(
|
307 |
+
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
308 |
+
)
|
309 |
+
v = repeat(
|
310 |
+
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
311 |
+
)
|
312 |
+
|
313 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
314 |
+
|
315 |
+
## old
|
316 |
+
"""
|
317 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
318 |
+
del q, k
|
319 |
+
|
320 |
+
if exists(mask):
|
321 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
322 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
323 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
324 |
+
sim.masked_fill_(~mask, max_neg_value)
|
325 |
+
|
326 |
+
# attention, what we cannot get enough of
|
327 |
+
sim = sim.softmax(dim=-1)
|
328 |
+
|
329 |
+
out = einsum('b i j, b j d -> b i d', sim, v)
|
330 |
+
"""
|
331 |
+
## new
|
332 |
+
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
333 |
+
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
334 |
+
out = F.scaled_dot_product_attention(
|
335 |
+
q, k, v, attn_mask=mask
|
336 |
+
) # scale is dim_head ** -0.5 per default
|
337 |
+
|
338 |
+
del q, k, v
|
339 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
340 |
+
|
341 |
+
if additional_tokens is not None:
|
342 |
+
# remove additional token
|
343 |
+
out = out[:, n_tokens_to_mask:]
|
344 |
+
return self.to_out(out)
|
345 |
+
|
346 |
+
|
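A minimal usage sketch for the CrossAttention module above: 4096 flattened spatial tokens attend to a 77-token, 1024-channel context (e.g. text embeddings). The shapes are illustrative and PyTorch >= 2.0 is assumed, since the forward pass dispatches to F.scaled_dot_product_attention:

import torch

xattn = CrossAttention(query_dim=320, context_dim=1024, heads=8, dim_head=64)
x = torch.randn(2, 4096, 320)         # queries: flattened feature-map tokens
context = torch.randn(2, 77, 1024)    # keys/values come from the conditioning
out = xattn(x, context=context)       # projected back to query_dim -> (2, 4096, 320)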
347 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
348 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
349 |
+
def __init__(
|
350 |
+
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
351 |
+
):
|
352 |
+
super().__init__()
|
353 |
+
logpy.debug(
|
354 |
+
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
355 |
+
f"context_dim is {context_dim} and using {heads} heads with a "
|
356 |
+
f"dimension of {dim_head}."
|
357 |
+
)
|
358 |
+
inner_dim = dim_head * heads
|
359 |
+
context_dim = default(context_dim, query_dim)
|
360 |
+
|
361 |
+
self.heads = heads
|
362 |
+
self.dim_head = dim_head
|
363 |
+
|
364 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
365 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
366 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
367 |
+
|
368 |
+
self.to_out = nn.Sequential(
|
369 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
370 |
+
)
|
371 |
+
self.attention_op: Optional[Any] = None
|
372 |
+
|
373 |
+
def forward(
|
374 |
+
self,
|
375 |
+
x,
|
376 |
+
context=None,
|
377 |
+
mask=None,
|
378 |
+
additional_tokens=None,
|
379 |
+
n_times_crossframe_attn_in_self=0,
|
380 |
+
):
|
381 |
+
if additional_tokens is not None:
|
382 |
+
# get the number of masked tokens at the beginning of the output sequence
|
383 |
+
n_tokens_to_mask = additional_tokens.shape[1]
|
384 |
+
# add additional token
|
385 |
+
x = torch.cat([additional_tokens, x], dim=1)
|
386 |
+
q = self.to_q(x)
|
387 |
+
context = default(context, x)
|
388 |
+
k = self.to_k(context)
|
389 |
+
v = self.to_v(context)
|
390 |
+
|
391 |
+
if n_times_crossframe_attn_in_self:
|
392 |
+
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
393 |
+
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
394 |
+
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
|
395 |
+
k = repeat(
|
396 |
+
k[::n_times_crossframe_attn_in_self],
|
397 |
+
"b ... -> (b n) ...",
|
398 |
+
n=n_times_crossframe_attn_in_self,
|
399 |
+
)
|
400 |
+
v = repeat(
|
401 |
+
v[::n_times_crossframe_attn_in_self],
|
402 |
+
"b ... -> (b n) ...",
|
403 |
+
n=n_times_crossframe_attn_in_self,
|
404 |
+
)
|
405 |
+
|
406 |
+
b, _, _ = q.shape
|
407 |
+
q, k, v = map(
|
408 |
+
lambda t: t.unsqueeze(3)
|
409 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
410 |
+
.permute(0, 2, 1, 3)
|
411 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
412 |
+
.contiguous(),
|
413 |
+
(q, k, v),
|
414 |
+
)
|
415 |
+
|
416 |
+
# actually compute the attention, what we cannot get enough of
|
417 |
+
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
|
418 |
+
# NOTE: workaround for
|
419 |
+
# https://github.com/facebookresearch/xformers/issues/845
|
420 |
+
max_bs = 32768
|
421 |
+
N = q.shape[0]
|
422 |
+
n_batches = math.ceil(N / max_bs)
|
423 |
+
out = list()
|
424 |
+
for i_batch in range(n_batches):
|
425 |
+
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
|
426 |
+
out.append(
|
427 |
+
xformers.ops.memory_efficient_attention(
|
428 |
+
q[batch],
|
429 |
+
k[batch],
|
430 |
+
v[batch],
|
431 |
+
attn_bias=None,
|
432 |
+
op=self.attention_op,
|
433 |
+
)
|
434 |
+
)
|
435 |
+
out = torch.cat(out, 0)
|
436 |
+
else:
|
437 |
+
out = xformers.ops.memory_efficient_attention(
|
438 |
+
q, k, v, attn_bias=None, op=self.attention_op
|
439 |
+
)
|
440 |
+
|
441 |
+
# TODO: Use this directly in the attention operation, as a bias
|
442 |
+
if exists(mask):
|
443 |
+
raise NotImplementedError
|
444 |
+
out = (
|
445 |
+
out.unsqueeze(0)
|
446 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
447 |
+
.permute(0, 2, 1, 3)
|
448 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
449 |
+
)
|
450 |
+
if additional_tokens is not None:
|
451 |
+
# remove additional token
|
452 |
+
out = out[:, n_tokens_to_mask:]
|
453 |
+
return self.to_out(out)
|
454 |
+
|
455 |
+
|
456 |
+
|
457 |
+
class BasicTransformerBlock(nn.Module):
|
458 |
+
ATTENTION_MODES = {
|
459 |
+
"softmax": CrossAttention, # vanilla attention
|
460 |
+
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
|
461 |
+
}
|
462 |
+
|
463 |
+
def __init__(
|
464 |
+
self,
|
465 |
+
dim,
|
466 |
+
n_heads,
|
467 |
+
d_head,
|
468 |
+
dropout=0.0,
|
469 |
+
context_dim=None,
|
470 |
+
gated_ff=True,
|
471 |
+
checkpoint=True,
|
472 |
+
disable_self_attn=False,
|
473 |
+
attn_mode="softmax",
|
474 |
+
sdp_backend=None,
|
475 |
+
):
|
476 |
+
super().__init__()
|
477 |
+
assert attn_mode in self.ATTENTION_MODES
|
478 |
+
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
479 |
+
logpy.warn(
|
480 |
+
f"Attention mode '{attn_mode}' is not available. Falling "
|
481 |
+
f"back to native attention. This is not a problem in "
|
482 |
+
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
|
483 |
+
f"version {torch.__version__}."
|
484 |
+
)
|
485 |
+
attn_mode = "softmax"
|
486 |
+
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
487 |
+
logpy.warn(
|
488 |
+
"We do not support vanilla attention anymore, as it is too "
|
489 |
+
"expensive. Sorry."
|
490 |
+
)
|
491 |
+
if not XFORMERS_IS_AVAILABLE:
|
492 |
+
assert (
|
493 |
+
False
|
494 |
+
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
495 |
+
else:
|
496 |
+
logpy.info("Falling back to xformers efficient attention.")
|
497 |
+
attn_mode = "softmax-xformers"
|
498 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
499 |
+
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
500 |
+
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
501 |
+
else:
|
502 |
+
assert sdp_backend is None
|
503 |
+
self.disable_self_attn = disable_self_attn
|
504 |
+
self.attn1 = attn_cls(
|
505 |
+
query_dim=dim,
|
506 |
+
heads=n_heads,
|
507 |
+
dim_head=d_head,
|
508 |
+
dropout=dropout,
|
509 |
+
context_dim=context_dim if self.disable_self_attn else None,
|
510 |
+
backend=sdp_backend,
|
511 |
+
) # is a self-attention if not self.disable_self_attn
|
512 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
513 |
+
self.attn2 = attn_cls(
|
514 |
+
query_dim=dim,
|
515 |
+
context_dim=context_dim,
|
516 |
+
heads=n_heads,
|
517 |
+
dim_head=d_head,
|
518 |
+
dropout=dropout,
|
519 |
+
backend=sdp_backend,
|
520 |
+
) # is self-attn if context is none
|
521 |
+
self.norm1 = nn.LayerNorm(dim)
|
522 |
+
self.norm2 = nn.LayerNorm(dim)
|
523 |
+
self.norm3 = nn.LayerNorm(dim)
|
524 |
+
self.checkpoint = checkpoint
|
525 |
+
if self.checkpoint:
|
526 |
+
logpy.debug(f"{self.__class__.__name__} is using checkpointing")
|
527 |
+
|
528 |
+
|
529 |
+
def forward(
|
530 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
531 |
+
):
|
532 |
+
kwargs = {"x": x}
|
533 |
+
|
534 |
+
if context is not None:
|
535 |
+
kwargs.update({"context": context})
|
536 |
+
|
537 |
+
if additional_tokens is not None:
|
538 |
+
kwargs.update({"additional_tokens": additional_tokens})
|
539 |
+
|
540 |
+
if n_times_crossframe_attn_in_self:
|
541 |
+
kwargs.update(
|
542 |
+
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
|
543 |
+
)
|
544 |
+
|
545 |
+
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
546 |
+
if self.checkpoint:
|
547 |
+
# inputs = {"x": x, "context": context}
|
548 |
+
return checkpoint(self._forward, x, context)
|
549 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
550 |
+
else:
|
551 |
+
return self._forward(**kwargs)
|
552 |
+
|
553 |
+
def _forward(
|
554 |
+
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
555 |
+
):
|
556 |
+
x = (
|
557 |
+
self.attn1(
|
558 |
+
self.norm1(x),
|
559 |
+
context=context if self.disable_self_attn else None,
|
560 |
+
additional_tokens=additional_tokens,
|
561 |
+
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
562 |
+
if not self.disable_self_attn
|
563 |
+
else 0,
|
564 |
+
)
|
565 |
+
+ x
|
566 |
+
)
|
567 |
+
x = (
|
568 |
+
self.attn2(
|
569 |
+
self.norm2(x), context=context, additional_tokens=additional_tokens
|
570 |
+
)
|
571 |
+
+ x
|
572 |
+
)
|
573 |
+
x = self.ff(self.norm3(x)) + x
|
574 |
+
return x
|
575 |
+
|
576 |
+
|
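A minimal usage sketch for BasicTransformerBlock above: self-attention, cross-attention to a conditioning context, and the GEGLU feed-forward, all with residual connections. checkpoint=False is chosen only to keep gradient checkpointing out of the sketch; shapes are illustrative and PyTorch >= 2.0 is assumed:

import torch

block = BasicTransformerBlock(dim=320, n_heads=8, d_head=40,
                              context_dim=1024, checkpoint=False)
x = torch.randn(2, 4096, 320)         # (batch, tokens, channels)
context = torch.randn(2, 77, 1024)    # conditioning tokens for attn2
out = block(x, context=context)       # residuals keep the shape at (2, 4096, 320)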
577 |
+
class BasicTransformerBlockWithAPM(BasicTransformerBlock):
|
578 |
+
|
579 |
+
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False, attn_mode="softmax", sdp_backend=None, use_apm=False):
|
580 |
+
super().__init__(dim, n_heads, d_head, dropout, context_dim, gated_ff, checkpoint, disable_self_attn, attn_mode, sdp_backend)
|
581 |
+
# APM Addition
|
582 |
+
assert disable_self_attn == False
|
583 |
+
self.use_apm = use_apm
|
584 |
+
if use_apm:
|
585 |
+
tokens_apm_clip = 16+1
|
586 |
+
self.apm_conv = torch.nn.Conv1d(
|
587 |
+
tokens_apm_clip, 1, kernel_size=3, padding="same")
|
588 |
+
channel_dim_context = 1024
|
589 |
+
self.apm_ln = nn.LayerNorm(channel_dim_context)
|
590 |
+
self.apm_alpha = nn.Parameter(torch.tensor(0.))
|
591 |
+
|
592 |
+
|
593 |
+
def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
594 |
+
):
|
595 |
+
if context is not None and context.shape[1]>1 and self.use_apm:
|
596 |
+
print("using APM CONTEXT !!!!")
|
597 |
+
context_svd = context[:,:1]
|
598 |
+
context_mixed = self.apm_conv(context)
|
599 |
+
context_mixed = self.apm_ln(context_mixed)
|
600 |
+
context = context_svd + context_mixed * F.silu(self.apm_alpha)
|
601 |
+
return super().forward(x=x,context=context,additional_tokens=additional_tokens,n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self)
|
602 |
+
|
603 |
+
|
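A minimal usage sketch for the APM variant above (BasicTransformerBlockWithAPM): when use_apm is enabled and the context carries more than one token, the first token is kept as the SVD context and the remaining 16 tokens are mixed in through the learned Conv1d, scaled by the silu-gated apm_alpha. The 16+1 token count and the 1024-channel context follow the hard-coded values in __init__; everything else is illustrative:

import torch

blk = BasicTransformerBlockWithAPM(dim=320, n_heads=8, d_head=40,
                                   context_dim=1024, checkpoint=False, use_apm=True)
x = torch.randn(2, 4096, 320)         # (batch, tokens, channels)
context = torch.randn(2, 17, 1024)    # 1 SVD token + 16 additional APM tokens
out = blk(x, context=context)         # mixed context is fed to the parent block -> (2, 4096, 320)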
604 |
+
class BasicTransformerSingleLayerBlock(nn.Module):
|
605 |
+
ATTENTION_MODES = {
|
606 |
+
"softmax": CrossAttention, # vanilla attention
|
607 |
+
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
|
608 |
+
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
609 |
+
}
|
610 |
+
|
611 |
+
def __init__(
|
612 |
+
self,
|
613 |
+
dim,
|
614 |
+
n_heads,
|
615 |
+
d_head,
|
616 |
+
dropout=0.0,
|
617 |
+
context_dim=None,
|
618 |
+
gated_ff=True,
|
619 |
+
checkpoint=True,
|
620 |
+
attn_mode="softmax",
|
621 |
+
):
|
622 |
+
super().__init__()
|
623 |
+
assert attn_mode in self.ATTENTION_MODES
|
624 |
+
attn_cls = self.ATTENTION_MODES[attn_mode]
|
625 |
+
self.attn1 = attn_cls(
|
626 |
+
query_dim=dim,
|
627 |
+
heads=n_heads,
|
628 |
+
dim_head=d_head,
|
629 |
+
dropout=dropout,
|
630 |
+
context_dim=context_dim,
|
631 |
+
)
|
632 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
633 |
+
self.norm1 = nn.LayerNorm(dim)
|
634 |
+
self.norm2 = nn.LayerNorm(dim)
|
635 |
+
self.checkpoint = checkpoint
|
636 |
+
|
637 |
+
def forward(self, x, context=None):
|
638 |
+
# inputs = {"x": x, "context": context}
|
639 |
+
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
640 |
+
return checkpoint(self._forward, x, context)
|
641 |
+
|
642 |
+
def _forward(self, x, context=None):
|
643 |
+
x = self.attn1(self.norm1(x), context=context) + x
|
644 |
+
x = self.ff(self.norm2(x)) + x
|
645 |
+
return x
|
646 |
+
|
647 |
+
|
648 |
+
class SpatialTransformer(nn.Module):
|
649 |
+
"""
|
650 |
+
Transformer block for image-like data.
|
651 |
+
First, project the input (aka embedding)
|
652 |
+
and reshape to b, t, d.
|
653 |
+
Then apply standard transformer action.
|
654 |
+
Finally, reshape to image
|
655 |
+
NEW: use_linear for more efficiency instead of the 1x1 convs
|
656 |
+
"""
|
657 |
+
|
658 |
+
def __init__(
|
659 |
+
self,
|
660 |
+
in_channels,
|
661 |
+
n_heads,
|
662 |
+
d_head,
|
663 |
+
depth=1,
|
664 |
+
dropout=0.0,
|
665 |
+
context_dim=None,
|
666 |
+
disable_self_attn=False,
|
667 |
+
use_linear=False,
|
668 |
+
attn_type="softmax",
|
669 |
+
use_checkpoint=True,
|
670 |
+
# sdp_backend=SDPBackend.FLASH_ATTENTION
|
671 |
+
sdp_backend=None,
|
672 |
+
use_apm: bool = False,
|
673 |
+
):
|
674 |
+
super().__init__()
|
675 |
+
logpy.debug(
|
676 |
+
f"constructing {self.__class__.__name__} of depth {depth} w/ "
|
677 |
+
f"{in_channels} channels and {n_heads} heads."
|
678 |
+
)
|
679 |
+
|
680 |
+
if exists(context_dim) and not isinstance(context_dim, list):
|
681 |
+
context_dim = [context_dim]
|
682 |
+
if exists(context_dim) and isinstance(context_dim, list):
|
683 |
+
if depth != len(context_dim):
|
684 |
+
logpy.warn(
|
685 |
+
f"{self.__class__.__name__}: Found context dims "
|
686 |
+
f"{context_dim} of depth {len(context_dim)}, which does not "
|
687 |
+
f"match the specified 'depth' of {depth}. Setting context_dim "
|
688 |
+
f"to {depth * [context_dim[0]]} now."
|
689 |
+
)
|
690 |
+
# depth does not match context dims.
|
691 |
+
assert all(
|
692 |
+
map(lambda x: x == context_dim[0], context_dim)
|
693 |
+
), "need homogenous context_dim to match depth automatically"
|
694 |
+
context_dim = depth * [context_dim[0]]
|
695 |
+
elif context_dim is None:
|
696 |
+
context_dim = [None] * depth
|
697 |
+
self.in_channels = in_channels
|
698 |
+
inner_dim = n_heads * d_head
|
699 |
+
self.norm = Normalize(in_channels)
|
700 |
+
if not use_linear:
|
701 |
+
self.proj_in = nn.Conv2d(
|
702 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
703 |
+
)
|
704 |
+
else:
|
705 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
706 |
+
|
707 |
+
if use_apm:
|
708 |
+
print("APM TRANSFORMER BLOCK")
|
709 |
+
self.transformer_blocks = nn.ModuleList(
|
710 |
+
[
|
711 |
+
BasicTransformerBlockWithAPM(
|
712 |
+
inner_dim,
|
713 |
+
n_heads,
|
714 |
+
d_head,
|
715 |
+
dropout=dropout,
|
716 |
+
context_dim=context_dim[d],
|
717 |
+
disable_self_attn=disable_self_attn,
|
718 |
+
attn_mode=attn_type,
|
719 |
+
checkpoint=use_checkpoint,
|
720 |
+
sdp_backend=sdp_backend,
|
721 |
+
use_apm=use_apm,
|
722 |
+
)
|
723 |
+
for d in range(depth)
|
724 |
+
]
|
725 |
+
)
|
726 |
+
else:
|
727 |
+
self.transformer_blocks = nn.ModuleList(
|
728 |
+
[
|
729 |
+
BasicTransformerBlock(
|
730 |
+
inner_dim,
|
731 |
+
n_heads,
|
732 |
+
d_head,
|
733 |
+
dropout=dropout,
|
734 |
+
context_dim=context_dim[d],
|
735 |
+
disable_self_attn=disable_self_attn,
|
736 |
+
attn_mode=attn_type,
|
737 |
+
checkpoint=use_checkpoint,
|
738 |
+
sdp_backend=sdp_backend,
|
739 |
+
)
|
740 |
+
for d in range(depth)
|
741 |
+
]
|
742 |
+
)
|
743 |
+
if not use_linear:
|
744 |
+
self.proj_out = zero_module(
|
745 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
746 |
+
)
|
747 |
+
else:
|
748 |
+
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
749 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
750 |
+
self.use_linear = use_linear
|
751 |
+
|
752 |
+
def forward(self, x, context=None):
|
753 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
754 |
+
if not isinstance(context, list):
|
755 |
+
context = [context]
|
756 |
+
b, c, h, w = x.shape
|
757 |
+
x_in = x
|
758 |
+
x = self.norm(x)
|
759 |
+
if not self.use_linear:
|
760 |
+
x = self.proj_in(x)
|
761 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
762 |
+
if self.use_linear:
|
763 |
+
x = self.proj_in(x)
|
764 |
+
for i, block in enumerate(self.transformer_blocks):
|
765 |
+
if i > 0 and len(context) == 1:
|
766 |
+
i = 0 # use same context for each block
|
767 |
+
x = block(x, context=context[i])
|
768 |
+
if self.use_linear:
|
769 |
+
x = self.proj_out(x)
|
770 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
771 |
+
if not self.use_linear:
|
772 |
+
x = self.proj_out(x)
|
773 |
+
return x + x_in
|
774 |
+
|
775 |
+
|
776 |
+
class SimpleTransformer(nn.Module):
|
777 |
+
def __init__(
|
778 |
+
self,
|
779 |
+
dim: int,
|
780 |
+
depth: int,
|
781 |
+
heads: int,
|
782 |
+
dim_head: int,
|
783 |
+
context_dim: Optional[int] = None,
|
784 |
+
dropout: float = 0.0,
|
785 |
+
checkpoint: bool = True,
|
786 |
+
):
|
787 |
+
super().__init__()
|
788 |
+
self.layers = nn.ModuleList([])
|
789 |
+
for _ in range(depth):
|
790 |
+
self.layers.append(
|
791 |
+
BasicTransformerBlock(
|
792 |
+
dim,
|
793 |
+
heads,
|
794 |
+
dim_head,
|
795 |
+
dropout=dropout,
|
796 |
+
context_dim=context_dim,
|
797 |
+
attn_mode="softmax-xformers",
|
798 |
+
checkpoint=checkpoint,
|
799 |
+
)
|
800 |
+
)
|
801 |
+
|
802 |
+
def forward(
|
803 |
+
self,
|
804 |
+
x: torch.Tensor,
|
805 |
+
context: Optional[torch.Tensor] = None,
|
806 |
+
) -> torch.Tensor:
|
807 |
+
for layer in self.layers:
|
808 |
+
x = layer(x, context)
|
809 |
+
return x
|
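Finally, a minimal usage sketch for SpatialTransformer from the file above, which wraps the transformer blocks for image-like (b, c, h, w) features: the input is normalised, flattened to tokens, run through the blocks with an optional conditioning context, projected back, and added residually. Shapes are illustrative and PyTorch >= 2.0 is assumed:

import torch

st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1,
                        context_dim=1024, use_linear=True, use_checkpoint=False)
feat = torch.randn(2, 320, 64, 64)    # (batch, channels, height, width), e.g. one UNet level
context = torch.randn(2, 77, 1024)    # conditioning tokens, reused by every block
out = st(feat, context=context)       # (2, 320, 64, 64), thanks to the residual connection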
models/svd/sgm/modules/autoencoding/__init__.py
ADDED
File without changes
|