working.
Browse files- xora/examples/image_to_video.py +87 -0
- xora/models/autoencoders/causal_video_autoencoder.py +3 -1
- xora/models/autoencoders/vae_encode.py +11 -41
- xora/models/autoencoders/video_autoencoder.py +912 -0
- xora/models/transformers/embeddings.py +125 -0
- xora/models/transformers/transformer3d.py +77 -4
- xora/pipelines/pipeline_video_pixart_alpha.py +181 -13
- xora/schedulers/rf.py +13 -4
- xora/utils/conditioning_method.py +7 -0
- xora/utils/dist_util.py +11 -0
xora/examples/image_to_video.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
3 |
+
from xora.models.transformers.transformer3d import Transformer3DModel
|
4 |
+
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
|
5 |
+
from xora.schedulers.rf import RectifiedFlowScheduler
|
6 |
+
from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
|
7 |
+
from pathlib import Path
|
8 |
+
from transformers import T5EncoderModel
|
9 |
+
|
10 |
+
|
11 |
+
model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
|
12 |
+
vae_local_path = Path("/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000")
|
13 |
+
dtype = torch.float32
|
14 |
+
vae = CausalVideoAutoencoder.from_pretrained(
|
15 |
+
pretrained_model_name_or_path=vae_local_path,
|
16 |
+
revision=False,
|
17 |
+
torch_dtype=torch.bfloat16,
|
18 |
+
load_in_8bit=False,
|
19 |
+
).cuda()
|
20 |
+
transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
|
21 |
+
transformer_config = Transformer3DModel.load_config(transformer_config_path)
|
22 |
+
transformer = Transformer3DModel.from_config(transformer_config)
|
23 |
+
transformer_local_path = Path("/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-first-frame-cond-4k-seq/ckpt/01822000/model.pt")
|
24 |
+
transformer_ckpt_state_dict = torch.load(transformer_local_path)
|
25 |
+
transformer.load_state_dict(transformer_ckpt_state_dict, True)
|
26 |
+
transformer = transformer.cuda()
|
27 |
+
unet = transformer
|
28 |
+
scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
|
29 |
+
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
|
30 |
+
scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
|
31 |
+
patchifier = SymmetricPatchifier(patch_size=1)
|
32 |
+
# text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")
|
33 |
+
|
34 |
+
submodel_dict = {
|
35 |
+
"unet": unet,
|
36 |
+
"transformer": transformer,
|
37 |
+
"patchifier": patchifier,
|
38 |
+
"text_encoder": None,
|
39 |
+
"scheduler": scheduler,
|
40 |
+
"vae": vae,
|
41 |
+
|
42 |
+
}
|
43 |
+
|
44 |
+
pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
|
45 |
+
safety_checker=None,
|
46 |
+
revision=None,
|
47 |
+
torch_dtype=dtype,
|
48 |
+
**submodel_dict,
|
49 |
+
)
|
50 |
+
|
51 |
+
num_inference_steps=20
|
52 |
+
num_images_per_prompt=2
|
53 |
+
guidance_scale=3
|
54 |
+
height=512
|
55 |
+
width=768
|
56 |
+
num_frames=57
|
57 |
+
frame_rate=25
|
58 |
+
# sample = {
|
59 |
+
# "prompt": "A cat", # (B, L, E)
|
60 |
+
# 'prompt_attention_mask': None, # (B , L)
|
61 |
+
# 'negative_prompt': "Ugly deformed",
|
62 |
+
# 'negative_prompt_attention_mask': None # (B , L)
|
63 |
+
# }
|
64 |
+
|
65 |
+
sample = torch.load("/opt/sample.pt")
|
66 |
+
for _, item in sample.items():
|
67 |
+
if item is not None:
|
68 |
+
item = item.cuda()
|
69 |
+
media_items = torch.load("/opt/sample_media.pt")
|
70 |
+
|
71 |
+
images = pipeline(
|
72 |
+
num_inference_steps=num_inference_steps,
|
73 |
+
num_images_per_prompt=num_images_per_prompt,
|
74 |
+
guidance_scale=guidance_scale,
|
75 |
+
generator=None,
|
76 |
+
output_type="pt",
|
77 |
+
callback_on_step_end=None,
|
78 |
+
height=height,
|
79 |
+
width=width,
|
80 |
+
num_frames=num_frames,
|
81 |
+
frame_rate=frame_rate,
|
82 |
+
**sample,
|
83 |
+
is_video=True,
|
84 |
+
vae_per_channel_normalize=True,
|
85 |
+
).images
|
86 |
+
|
87 |
+
print()
|
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -8,11 +8,13 @@ import torch
|
|
8 |
import numpy as np
|
9 |
from einops import rearrange
|
10 |
from torch import nn
|
|
|
11 |
|
12 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
13 |
from xora.models.autoencoders.pixel_norm import PixelNorm
|
14 |
from xora.models.autoencoders.vae import AutoencoderKLWrapper
|
15 |
|
|
|
16 |
|
17 |
class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
18 |
@classmethod
|
@@ -138,7 +140,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
138 |
key = key.replace(k, v)
|
139 |
|
140 |
if "norm" in key and key not in model_keys:
|
141 |
-
|
142 |
continue
|
143 |
|
144 |
converted_state_dict[key] = value
|
|
|
8 |
import numpy as np
|
9 |
from einops import rearrange
|
10 |
from torch import nn
|
11 |
+
from diffusers.utils import logging
|
12 |
|
13 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
14 |
from xora.models.autoencoders.pixel_norm import PixelNorm
|
15 |
from xora.models.autoencoders.vae import AutoencoderKLWrapper
|
16 |
|
17 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
18 |
|
19 |
class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
20 |
@classmethod
|
|
|
140 |
key = key.replace(k, v)
|
141 |
|
142 |
if "norm" in key and key not in model_keys:
|
143 |
+
logger.info(f"Removing key {key} from state_dict as it is not present in the model")
|
144 |
continue
|
145 |
|
146 |
converted_state_dict[key] = value
|
xora/models/autoencoders/vae_encode.py
CHANGED
@@ -1,44 +1,12 @@
|
|
1 |
import torch
|
2 |
-
from torch import nn
|
3 |
from diffusers import AutoencoderKL
|
4 |
from einops import rearrange
|
5 |
from torch import Tensor
|
6 |
-
from torch.nn import functional
|
7 |
|
8 |
|
9 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
10 |
-
|
11 |
-
|
12 |
-
def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
|
13 |
-
super().__init__()
|
14 |
-
stride: int = 2
|
15 |
-
self.padding = padding
|
16 |
-
self.in_channels = in_channels
|
17 |
-
self.dims = dims
|
18 |
-
self.conv = make_conv_nd(
|
19 |
-
dims=dims,
|
20 |
-
in_channels=in_channels,
|
21 |
-
out_channels=out_channels,
|
22 |
-
kernel_size=kernel_size,
|
23 |
-
stride=stride,
|
24 |
-
padding=padding,
|
25 |
-
)
|
26 |
-
|
27 |
-
def forward(self, x, downsample_in_time=True):
|
28 |
-
conv = self.conv
|
29 |
-
if self.padding == 0:
|
30 |
-
if self.dims == 2:
|
31 |
-
padding = (0, 1, 0, 1)
|
32 |
-
else:
|
33 |
-
padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
|
34 |
-
|
35 |
-
x = functional.pad(x, padding, mode="constant", value=0)
|
36 |
-
|
37 |
-
if self.dims == (2, 1) and not downsample_in_time:
|
38 |
-
return conv(x, skip_time_conv=True)
|
39 |
-
|
40 |
-
return conv(x)
|
41 |
-
|
42 |
|
43 |
|
44 |
def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
|
@@ -78,7 +46,7 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
|
|
78 |
if channels != 3:
|
79 |
raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
|
80 |
|
81 |
-
if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
|
82 |
media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
|
83 |
if split_size > 1:
|
84 |
if len(media_items) % split_size != 0:
|
@@ -86,14 +54,16 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
|
|
86 |
encode_bs = len(media_items) // split_size
|
87 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
88 |
latents = []
|
|
|
89 |
for image_batch in media_items.split(encode_bs):
|
90 |
latents.append(vae.encode(image_batch).latent_dist.sample())
|
|
|
91 |
latents = torch.cat(latents, dim=0)
|
92 |
else:
|
93 |
latents = vae.encode(media_items).latent_dist.sample()
|
94 |
|
95 |
latents = normalize_latents(latents, vae, vae_per_channel_normalize)
|
96 |
-
if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
|
97 |
latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
|
98 |
return latents
|
99 |
|
@@ -104,7 +74,7 @@ def vae_decode(
|
|
104 |
is_video_shaped = latents.dim() == 5
|
105 |
batch_size = latents.shape[0]
|
106 |
|
107 |
-
if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
|
108 |
latents = rearrange(latents, "b c n h w -> (b n) c h w")
|
109 |
if split_size > 1:
|
110 |
if len(latents) % split_size != 0:
|
@@ -118,13 +88,13 @@ def vae_decode(
|
|
118 |
else:
|
119 |
images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
|
120 |
|
121 |
-
if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
|
122 |
images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
|
123 |
return images
|
124 |
|
125 |
|
126 |
def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
|
127 |
-
if isinstance(vae, (CausalVideoAutoencoder)):
|
128 |
*_, fl, hl, wl = latents.shape
|
129 |
temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
|
130 |
latents = latents.to(vae.dtype)
|
@@ -148,7 +118,7 @@ def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
|
|
148 |
else:
|
149 |
down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)])
|
150 |
spatial = vae.config.patch_size * 2**down_blocks
|
151 |
-
temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae) else 1
|
152 |
|
153 |
return (temporal, spatial, spatial)
|
154 |
|
@@ -168,4 +138,4 @@ def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_no
|
|
168 |
+ vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
169 |
if vae_per_channel_normalize
|
170 |
else latents / vae.config.scaling_factor
|
171 |
-
)
|
|
|
1 |
import torch
|
|
|
2 |
from diffusers import AutoencoderKL
|
3 |
from einops import rearrange
|
4 |
from torch import Tensor
|
|
|
5 |
|
6 |
|
7 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
8 |
+
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
|
9 |
+
import xora.utils.dist_util
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
|
|
|
46 |
if channels != 3:
|
47 |
raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
|
48 |
|
49 |
+
if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
50 |
media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
|
51 |
if split_size > 1:
|
52 |
if len(media_items) % split_size != 0:
|
|
|
54 |
encode_bs = len(media_items) // split_size
|
55 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
56 |
latents = []
|
57 |
+
dist_util.execute_graph()
|
58 |
for image_batch in media_items.split(encode_bs):
|
59 |
latents.append(vae.encode(image_batch).latent_dist.sample())
|
60 |
+
dist_util.execute_graph()
|
61 |
latents = torch.cat(latents, dim=0)
|
62 |
else:
|
63 |
latents = vae.encode(media_items).latent_dist.sample()
|
64 |
|
65 |
latents = normalize_latents(latents, vae, vae_per_channel_normalize)
|
66 |
+
if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
67 |
latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
|
68 |
return latents
|
69 |
|
|
|
74 |
is_video_shaped = latents.dim() == 5
|
75 |
batch_size = latents.shape[0]
|
76 |
|
77 |
+
if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
78 |
latents = rearrange(latents, "b c n h w -> (b n) c h w")
|
79 |
if split_size > 1:
|
80 |
if len(latents) % split_size != 0:
|
|
|
88 |
else:
|
89 |
images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
|
90 |
|
91 |
+
if is_video_shaped and not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
92 |
images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
|
93 |
return images
|
94 |
|
95 |
|
96 |
def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
|
97 |
+
if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
98 |
*_, fl, hl, wl = latents.shape
|
99 |
temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
|
100 |
latents = latents.to(vae.dtype)
|
|
|
118 |
else:
|
119 |
down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)])
|
120 |
spatial = vae.config.patch_size * 2**down_blocks
|
121 |
+
temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae, VideoAutoencoder) else 1
|
122 |
|
123 |
return (temporal, spatial, spatial)
|
124 |
|
|
|
138 |
+ vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
139 |
if vae_per_channel_normalize
|
140 |
else latents / vae.config.scaling_factor
|
141 |
+
)
|
xora/models/autoencoders/video_autoencoder.py
ADDED
@@ -0,0 +1,912 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from functools import partial
|
4 |
+
from types import SimpleNamespace
|
5 |
+
from typing import Any, Mapping, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from einops import rearrange
|
9 |
+
from torch import nn
|
10 |
+
from torch.nn import functional
|
11 |
+
|
12 |
+
from diffusers.utils import logging
|
13 |
+
|
14 |
+
from txt2img.models.layers.nn import Identity
|
15 |
+
from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
16 |
+
from xora.models.autoencoders.pixel_norm import PixelNorm
|
17 |
+
from xora.models.autoencoders.vae import AutoencoderKLWrapper
|
18 |
+
|
19 |
+
logger = logging.get_logger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class VideoAutoencoder(AutoencoderKLWrapper):
|
23 |
+
@classmethod
|
24 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *args, **kwargs):
|
25 |
+
config_local_path = pretrained_model_name_or_path / "config.json"
|
26 |
+
config = cls.load_config(config_local_path, **kwargs)
|
27 |
+
video_vae = cls.from_config(config)
|
28 |
+
video_vae.to(kwargs["torch_dtype"])
|
29 |
+
|
30 |
+
model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
|
31 |
+
ckpt_state_dict = torch.load(model_local_path)
|
32 |
+
video_vae.load_state_dict(ckpt_state_dict)
|
33 |
+
|
34 |
+
statistics_local_path = pretrained_model_name_or_path / "per_channel_statistics.json"
|
35 |
+
if statistics_local_path.exists():
|
36 |
+
with open(statistics_local_path, "r") as file:
|
37 |
+
data = json.load(file)
|
38 |
+
transposed_data = list(zip(*data["data"]))
|
39 |
+
data_dict = {col: torch.tensor(vals) for col, vals in zip(data["columns"], transposed_data)}
|
40 |
+
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
41 |
+
video_vae.register_buffer(
|
42 |
+
"mean_of_means", data_dict.get("mean-of-means", torch.zeros_like(data_dict["std-of-means"]))
|
43 |
+
)
|
44 |
+
|
45 |
+
return video_vae
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
def from_config(config):
|
49 |
+
assert config["_class_name"] == "VideoAutoencoder", "config must have _class_name=VideoAutoencoder"
|
50 |
+
if isinstance(config["dims"], list):
|
51 |
+
config["dims"] = tuple(config["dims"])
|
52 |
+
|
53 |
+
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
54 |
+
|
55 |
+
double_z = config.get("double_z", True)
|
56 |
+
latent_log_var = config.get("latent_log_var", "per_channel" if double_z else "none")
|
57 |
+
use_quant_conv = config.get("use_quant_conv", True)
|
58 |
+
|
59 |
+
if use_quant_conv and latent_log_var == "uniform":
|
60 |
+
raise ValueError("uniform latent_log_var requires use_quant_conv=False")
|
61 |
+
|
62 |
+
encoder = Encoder(
|
63 |
+
dims=config["dims"],
|
64 |
+
in_channels=config.get("in_channels", 3),
|
65 |
+
out_channels=config["latent_channels"],
|
66 |
+
block_out_channels=config["block_out_channels"],
|
67 |
+
patch_size=config.get("patch_size", 1),
|
68 |
+
latent_log_var=latent_log_var,
|
69 |
+
norm_layer=config.get("norm_layer", "group_norm"),
|
70 |
+
patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
|
71 |
+
add_channel_padding=config.get("add_channel_padding", False),
|
72 |
+
)
|
73 |
+
|
74 |
+
decoder = Decoder(
|
75 |
+
dims=config["dims"],
|
76 |
+
in_channels=config["latent_channels"],
|
77 |
+
out_channels=config.get("out_channels", 3),
|
78 |
+
block_out_channels=config["block_out_channels"],
|
79 |
+
patch_size=config.get("patch_size", 1),
|
80 |
+
norm_layer=config.get("norm_layer", "group_norm"),
|
81 |
+
patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
|
82 |
+
add_channel_padding=config.get("add_channel_padding", False),
|
83 |
+
)
|
84 |
+
|
85 |
+
dims = config["dims"]
|
86 |
+
return VideoAutoencoder(
|
87 |
+
encoder=encoder,
|
88 |
+
decoder=decoder,
|
89 |
+
latent_channels=config["latent_channels"],
|
90 |
+
dims=dims,
|
91 |
+
use_quant_conv=use_quant_conv,
|
92 |
+
)
|
93 |
+
|
94 |
+
@property
|
95 |
+
def config(self):
|
96 |
+
return SimpleNamespace(
|
97 |
+
_class_name="VideoAutoencoder",
|
98 |
+
dims=self.dims,
|
99 |
+
in_channels=self.encoder.conv_in.in_channels // (self.encoder.patch_size_t * self.encoder.patch_size**2),
|
100 |
+
out_channels=self.decoder.conv_out.out_channels // (self.decoder.patch_size_t * self.decoder.patch_size**2),
|
101 |
+
latent_channels=self.decoder.conv_in.in_channels,
|
102 |
+
block_out_channels=[
|
103 |
+
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
|
104 |
+
for i in range(len(self.encoder.down_blocks))
|
105 |
+
],
|
106 |
+
scaling_factor=1.0,
|
107 |
+
norm_layer=self.encoder.norm_layer,
|
108 |
+
patch_size=self.encoder.patch_size,
|
109 |
+
latent_log_var=self.encoder.latent_log_var,
|
110 |
+
use_quant_conv=self.use_quant_conv,
|
111 |
+
patch_size_t=self.encoder.patch_size_t,
|
112 |
+
add_channel_padding=self.encoder.add_channel_padding,
|
113 |
+
)
|
114 |
+
|
115 |
+
@property
|
116 |
+
def is_video_supported(self):
|
117 |
+
"""
|
118 |
+
Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
|
119 |
+
"""
|
120 |
+
return self.dims != 2
|
121 |
+
|
122 |
+
@property
|
123 |
+
def downscale_factor(self):
|
124 |
+
return self.encoder.downsample_factor
|
125 |
+
|
126 |
+
def to_json_string(self) -> str:
|
127 |
+
import json
|
128 |
+
|
129 |
+
return json.dumps(self.config.__dict__)
|
130 |
+
|
131 |
+
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
132 |
+
model_keys = set(name for name, _ in self.named_parameters())
|
133 |
+
|
134 |
+
key_mapping = {
|
135 |
+
".resnets.": ".res_blocks.",
|
136 |
+
"downsamplers.0": "downsample",
|
137 |
+
"upsamplers.0": "upsample",
|
138 |
+
}
|
139 |
+
|
140 |
+
converted_state_dict = {}
|
141 |
+
for key, value in state_dict.items():
|
142 |
+
for k, v in key_mapping.items():
|
143 |
+
key = key.replace(k, v)
|
144 |
+
|
145 |
+
if "norm" in key and key not in model_keys:
|
146 |
+
logger.info(f"Removing key {key} from state_dict as it is not present in the model")
|
147 |
+
continue
|
148 |
+
|
149 |
+
converted_state_dict[key] = value
|
150 |
+
|
151 |
+
super().load_state_dict(converted_state_dict, strict=strict)
|
152 |
+
|
153 |
+
def last_layer(self):
|
154 |
+
if hasattr(self.decoder, "conv_out"):
|
155 |
+
if isinstance(self.decoder.conv_out, nn.Sequential):
|
156 |
+
last_layer = self.decoder.conv_out[-1]
|
157 |
+
else:
|
158 |
+
last_layer = self.decoder.conv_out
|
159 |
+
else:
|
160 |
+
last_layer = self.decoder.layers[-1]
|
161 |
+
return last_layer
|
162 |
+
|
163 |
+
|
164 |
+
class Encoder(nn.Module):
|
165 |
+
r"""
|
166 |
+
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
in_channels (`int`, *optional*, defaults to 3):
|
170 |
+
The number of input channels.
|
171 |
+
out_channels (`int`, *optional*, defaults to 3):
|
172 |
+
The number of output channels.
|
173 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
174 |
+
The number of output channels for each block.
|
175 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
176 |
+
The number of layers per block.
|
177 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
178 |
+
The number of groups for normalization.
|
179 |
+
patch_size (`int`, *optional*, defaults to 1):
|
180 |
+
The patch size to use. Should be a power of 2.
|
181 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
182 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
183 |
+
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
184 |
+
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(
|
188 |
+
self,
|
189 |
+
dims: Union[int, Tuple[int, int]] = 3,
|
190 |
+
in_channels: int = 3,
|
191 |
+
out_channels: int = 3,
|
192 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
193 |
+
layers_per_block: int = 2,
|
194 |
+
norm_num_groups: int = 32,
|
195 |
+
patch_size: Union[int, Tuple[int]] = 1,
|
196 |
+
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
197 |
+
latent_log_var: str = "per_channel",
|
198 |
+
patch_size_t: Optional[int] = None,
|
199 |
+
add_channel_padding: Optional[bool] = False,
|
200 |
+
):
|
201 |
+
super().__init__()
|
202 |
+
self.patch_size = patch_size
|
203 |
+
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
|
204 |
+
self.add_channel_padding = add_channel_padding
|
205 |
+
self.layers_per_block = layers_per_block
|
206 |
+
self.norm_layer = norm_layer
|
207 |
+
self.latent_channels = out_channels
|
208 |
+
self.latent_log_var = latent_log_var
|
209 |
+
if add_channel_padding:
|
210 |
+
in_channels = in_channels * self.patch_size**3
|
211 |
+
else:
|
212 |
+
in_channels = in_channels * self.patch_size_t * self.patch_size**2
|
213 |
+
self.in_channels = in_channels
|
214 |
+
output_channel = block_out_channels[0]
|
215 |
+
|
216 |
+
self.conv_in = make_conv_nd(
|
217 |
+
dims=dims,
|
218 |
+
in_channels=in_channels,
|
219 |
+
out_channels=output_channel,
|
220 |
+
kernel_size=3,
|
221 |
+
stride=1,
|
222 |
+
padding=1,
|
223 |
+
)
|
224 |
+
|
225 |
+
self.down_blocks = nn.ModuleList([])
|
226 |
+
|
227 |
+
for i in range(len(block_out_channels)):
|
228 |
+
input_channel = output_channel
|
229 |
+
output_channel = block_out_channels[i]
|
230 |
+
is_final_block = i == len(block_out_channels) - 1
|
231 |
+
|
232 |
+
down_block = DownEncoderBlock3D(
|
233 |
+
dims=dims,
|
234 |
+
in_channels=input_channel,
|
235 |
+
out_channels=output_channel,
|
236 |
+
num_layers=self.layers_per_block,
|
237 |
+
add_downsample=not is_final_block and 2**i >= patch_size,
|
238 |
+
resnet_eps=1e-6,
|
239 |
+
downsample_padding=0,
|
240 |
+
resnet_groups=norm_num_groups,
|
241 |
+
norm_layer=norm_layer,
|
242 |
+
)
|
243 |
+
self.down_blocks.append(down_block)
|
244 |
+
|
245 |
+
self.mid_block = UNetMidBlock3D(
|
246 |
+
dims=dims,
|
247 |
+
in_channels=block_out_channels[-1],
|
248 |
+
num_layers=self.layers_per_block,
|
249 |
+
resnet_eps=1e-6,
|
250 |
+
resnet_groups=norm_num_groups,
|
251 |
+
norm_layer=norm_layer,
|
252 |
+
)
|
253 |
+
|
254 |
+
# out
|
255 |
+
if norm_layer == "group_norm":
|
256 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
257 |
+
elif norm_layer == "pixel_norm":
|
258 |
+
self.conv_norm_out = PixelNorm()
|
259 |
+
self.conv_act = nn.SiLU()
|
260 |
+
|
261 |
+
conv_out_channels = out_channels
|
262 |
+
if latent_log_var == "per_channel":
|
263 |
+
conv_out_channels *= 2
|
264 |
+
elif latent_log_var == "uniform":
|
265 |
+
conv_out_channels += 1
|
266 |
+
elif latent_log_var != "none":
|
267 |
+
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
268 |
+
self.conv_out = make_conv_nd(dims, block_out_channels[-1], conv_out_channels, 3, padding=1)
|
269 |
+
|
270 |
+
self.gradient_checkpointing = False
|
271 |
+
|
272 |
+
@property
|
273 |
+
def downscale_factor(self):
|
274 |
+
return (
|
275 |
+
2 ** len([block for block in self.down_blocks if isinstance(block.downsample, Downsample3D)])
|
276 |
+
* self.patch_size
|
277 |
+
)
|
278 |
+
|
279 |
+
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
280 |
+
r"""The forward method of the `Encoder` class."""
|
281 |
+
|
282 |
+
downsample_in_time = sample.shape[2] != 1
|
283 |
+
|
284 |
+
# patchify
|
285 |
+
patch_size_t = self.patch_size_t if downsample_in_time else 1
|
286 |
+
sample = patchify(
|
287 |
+
sample,
|
288 |
+
patch_size_hw=self.patch_size,
|
289 |
+
patch_size_t=patch_size_t,
|
290 |
+
add_channel_padding=self.add_channel_padding,
|
291 |
+
)
|
292 |
+
|
293 |
+
sample = self.conv_in(sample)
|
294 |
+
|
295 |
+
checkpoint_fn = (
|
296 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
297 |
+
if self.gradient_checkpointing and self.training
|
298 |
+
else lambda x: x
|
299 |
+
)
|
300 |
+
|
301 |
+
for down_block in self.down_blocks:
|
302 |
+
sample = checkpoint_fn(down_block)(sample, downsample_in_time=downsample_in_time)
|
303 |
+
|
304 |
+
sample = checkpoint_fn(self.mid_block)(sample)
|
305 |
+
|
306 |
+
# post-process
|
307 |
+
sample = self.conv_norm_out(sample)
|
308 |
+
sample = self.conv_act(sample)
|
309 |
+
sample = self.conv_out(sample)
|
310 |
+
|
311 |
+
if self.latent_log_var == "uniform":
|
312 |
+
last_channel = sample[:, -1:, ...]
|
313 |
+
num_dims = sample.dim()
|
314 |
+
|
315 |
+
if num_dims == 4:
|
316 |
+
# For shape (B, C, H, W)
|
317 |
+
repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1)
|
318 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
319 |
+
elif num_dims == 5:
|
320 |
+
# For shape (B, C, F, H, W)
|
321 |
+
repeated_last_channel = last_channel.repeat(1, sample.shape[1] - 2, 1, 1, 1)
|
322 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
323 |
+
else:
|
324 |
+
raise ValueError(f"Invalid input shape: {sample.shape}")
|
325 |
+
|
326 |
+
return sample
|
327 |
+
|
328 |
+
|
329 |
+
class Decoder(nn.Module):
|
330 |
+
r"""
|
331 |
+
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
in_channels (`int`, *optional*, defaults to 3):
|
335 |
+
The number of input channels.
|
336 |
+
out_channels (`int`, *optional*, defaults to 3):
|
337 |
+
The number of output channels.
|
338 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
339 |
+
The number of output channels for each block.
|
340 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
341 |
+
The number of layers per block.
|
342 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
343 |
+
The number of groups for normalization.
|
344 |
+
patch_size (`int`, *optional*, defaults to 1):
|
345 |
+
The patch size to use. Should be a power of 2.
|
346 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
347 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
348 |
+
"""
|
349 |
+
|
350 |
+
def __init__(
|
351 |
+
self,
|
352 |
+
dims,
|
353 |
+
in_channels: int = 3,
|
354 |
+
out_channels: int = 3,
|
355 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
356 |
+
layers_per_block: int = 2,
|
357 |
+
norm_num_groups: int = 32,
|
358 |
+
patch_size: int = 1,
|
359 |
+
norm_layer: str = "group_norm",
|
360 |
+
patch_size_t: Optional[int] = None,
|
361 |
+
add_channel_padding: Optional[bool] = False,
|
362 |
+
):
|
363 |
+
super().__init__()
|
364 |
+
self.patch_size = patch_size
|
365 |
+
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
|
366 |
+
self.add_channel_padding = add_channel_padding
|
367 |
+
self.layers_per_block = layers_per_block
|
368 |
+
if add_channel_padding:
|
369 |
+
out_channels = out_channels * self.patch_size**3
|
370 |
+
else:
|
371 |
+
out_channels = out_channels * self.patch_size_t * self.patch_size**2
|
372 |
+
self.out_channels = out_channels
|
373 |
+
|
374 |
+
self.conv_in = make_conv_nd(
|
375 |
+
dims,
|
376 |
+
in_channels,
|
377 |
+
block_out_channels[-1],
|
378 |
+
kernel_size=3,
|
379 |
+
stride=1,
|
380 |
+
padding=1,
|
381 |
+
)
|
382 |
+
|
383 |
+
self.mid_block = None
|
384 |
+
self.up_blocks = nn.ModuleList([])
|
385 |
+
|
386 |
+
self.mid_block = UNetMidBlock3D(
|
387 |
+
dims=dims,
|
388 |
+
in_channels=block_out_channels[-1],
|
389 |
+
num_layers=self.layers_per_block,
|
390 |
+
resnet_eps=1e-6,
|
391 |
+
resnet_groups=norm_num_groups,
|
392 |
+
norm_layer=norm_layer,
|
393 |
+
)
|
394 |
+
|
395 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
396 |
+
output_channel = reversed_block_out_channels[0]
|
397 |
+
for i in range(len(reversed_block_out_channels)):
|
398 |
+
prev_output_channel = output_channel
|
399 |
+
output_channel = reversed_block_out_channels[i]
|
400 |
+
|
401 |
+
is_final_block = i == len(block_out_channels) - 1
|
402 |
+
|
403 |
+
up_block = UpDecoderBlock3D(
|
404 |
+
dims=dims,
|
405 |
+
num_layers=self.layers_per_block + 1,
|
406 |
+
in_channels=prev_output_channel,
|
407 |
+
out_channels=output_channel,
|
408 |
+
add_upsample=not is_final_block and 2 ** (len(block_out_channels) - i - 1) > patch_size,
|
409 |
+
resnet_eps=1e-6,
|
410 |
+
resnet_groups=norm_num_groups,
|
411 |
+
norm_layer=norm_layer,
|
412 |
+
)
|
413 |
+
self.up_blocks.append(up_block)
|
414 |
+
|
415 |
+
if norm_layer == "group_norm":
|
416 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
417 |
+
elif norm_layer == "pixel_norm":
|
418 |
+
self.conv_norm_out = PixelNorm()
|
419 |
+
|
420 |
+
self.conv_act = nn.SiLU()
|
421 |
+
self.conv_out = make_conv_nd(dims, block_out_channels[0], out_channels, 3, padding=1)
|
422 |
+
|
423 |
+
self.gradient_checkpointing = False
|
424 |
+
|
425 |
+
def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
426 |
+
r"""The forward method of the `Decoder` class."""
|
427 |
+
assert target_shape is not None, "target_shape must be provided"
|
428 |
+
upsample_in_time = sample.shape[2] < target_shape[2]
|
429 |
+
|
430 |
+
sample = self.conv_in(sample)
|
431 |
+
|
432 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
433 |
+
|
434 |
+
checkpoint_fn = (
|
435 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
436 |
+
if self.gradient_checkpointing and self.training
|
437 |
+
else lambda x: x
|
438 |
+
)
|
439 |
+
|
440 |
+
sample = checkpoint_fn(self.mid_block)(sample)
|
441 |
+
sample = sample.to(upscale_dtype)
|
442 |
+
|
443 |
+
for up_block in self.up_blocks:
|
444 |
+
sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)
|
445 |
+
|
446 |
+
# post-process
|
447 |
+
sample = self.conv_norm_out(sample)
|
448 |
+
sample = self.conv_act(sample)
|
449 |
+
sample = self.conv_out(sample)
|
450 |
+
|
451 |
+
# un-patchify
|
452 |
+
patch_size_t = self.patch_size_t if upsample_in_time else 1
|
453 |
+
sample = unpatchify(
|
454 |
+
sample,
|
455 |
+
patch_size_hw=self.patch_size,
|
456 |
+
patch_size_t=patch_size_t,
|
457 |
+
add_channel_padding=self.add_channel_padding,
|
458 |
+
)
|
459 |
+
|
460 |
+
return sample
|
461 |
+
|
462 |
+
|
463 |
+
class DownEncoderBlock3D(nn.Module):
|
464 |
+
def __init__(
|
465 |
+
self,
|
466 |
+
dims: Union[int, Tuple[int, int]],
|
467 |
+
in_channels: int,
|
468 |
+
out_channels: int,
|
469 |
+
dropout: float = 0.0,
|
470 |
+
num_layers: int = 1,
|
471 |
+
resnet_eps: float = 1e-6,
|
472 |
+
resnet_groups: int = 32,
|
473 |
+
add_downsample: bool = True,
|
474 |
+
downsample_padding: int = 1,
|
475 |
+
norm_layer: str = "group_norm",
|
476 |
+
):
|
477 |
+
super().__init__()
|
478 |
+
res_blocks = []
|
479 |
+
|
480 |
+
for i in range(num_layers):
|
481 |
+
in_channels = in_channels if i == 0 else out_channels
|
482 |
+
res_blocks.append(
|
483 |
+
ResnetBlock3D(
|
484 |
+
dims=dims,
|
485 |
+
in_channels=in_channels,
|
486 |
+
out_channels=out_channels,
|
487 |
+
eps=resnet_eps,
|
488 |
+
groups=resnet_groups,
|
489 |
+
dropout=dropout,
|
490 |
+
norm_layer=norm_layer,
|
491 |
+
)
|
492 |
+
)
|
493 |
+
|
494 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
495 |
+
|
496 |
+
if add_downsample:
|
497 |
+
self.downsample = Downsample3D(dims, out_channels, out_channels=out_channels, padding=downsample_padding)
|
498 |
+
else:
|
499 |
+
self.downsample = Identity()
|
500 |
+
|
501 |
+
def forward(self, hidden_states: torch.FloatTensor, downsample_in_time) -> torch.FloatTensor:
|
502 |
+
for resnet in self.res_blocks:
|
503 |
+
hidden_states = resnet(hidden_states)
|
504 |
+
|
505 |
+
hidden_states = self.downsample(hidden_states, downsample_in_time=downsample_in_time)
|
506 |
+
|
507 |
+
return hidden_states
|
508 |
+
|
509 |
+
|
510 |
+
class UNetMidBlock3D(nn.Module):
|
511 |
+
"""
|
512 |
+
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
513 |
+
|
514 |
+
Args:
|
515 |
+
in_channels (`int`): The number of input channels.
|
516 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
517 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
518 |
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
519 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
520 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
521 |
+
|
522 |
+
Returns:
|
523 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
524 |
+
in_channels, height, width)`.
|
525 |
+
|
526 |
+
"""
|
527 |
+
|
528 |
+
def __init__(
|
529 |
+
self,
|
530 |
+
dims: Union[int, Tuple[int, int]],
|
531 |
+
in_channels: int,
|
532 |
+
dropout: float = 0.0,
|
533 |
+
num_layers: int = 1,
|
534 |
+
resnet_eps: float = 1e-6,
|
535 |
+
resnet_groups: int = 32,
|
536 |
+
norm_layer: str = "group_norm",
|
537 |
+
):
|
538 |
+
super().__init__()
|
539 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
540 |
+
|
541 |
+
self.res_blocks = nn.ModuleList(
|
542 |
+
[
|
543 |
+
ResnetBlock3D(
|
544 |
+
dims=dims,
|
545 |
+
in_channels=in_channels,
|
546 |
+
out_channels=in_channels,
|
547 |
+
eps=resnet_eps,
|
548 |
+
groups=resnet_groups,
|
549 |
+
dropout=dropout,
|
550 |
+
norm_layer=norm_layer,
|
551 |
+
)
|
552 |
+
for _ in range(num_layers)
|
553 |
+
]
|
554 |
+
)
|
555 |
+
|
556 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
557 |
+
for resnet in self.res_blocks:
|
558 |
+
hidden_states = resnet(hidden_states)
|
559 |
+
|
560 |
+
return hidden_states
|
561 |
+
|
562 |
+
|
563 |
+
class UpDecoderBlock3D(nn.Module):
|
564 |
+
def __init__(
|
565 |
+
self,
|
566 |
+
dims: Union[int, Tuple[int, int]],
|
567 |
+
in_channels: int,
|
568 |
+
out_channels: int,
|
569 |
+
resolution_idx: Optional[int] = None,
|
570 |
+
dropout: float = 0.0,
|
571 |
+
num_layers: int = 1,
|
572 |
+
resnet_eps: float = 1e-6,
|
573 |
+
resnet_groups: int = 32,
|
574 |
+
add_upsample: bool = True,
|
575 |
+
norm_layer: str = "group_norm",
|
576 |
+
):
|
577 |
+
super().__init__()
|
578 |
+
res_blocks = []
|
579 |
+
|
580 |
+
for i in range(num_layers):
|
581 |
+
input_channels = in_channels if i == 0 else out_channels
|
582 |
+
|
583 |
+
res_blocks.append(
|
584 |
+
ResnetBlock3D(
|
585 |
+
dims=dims,
|
586 |
+
in_channels=input_channels,
|
587 |
+
out_channels=out_channels,
|
588 |
+
eps=resnet_eps,
|
589 |
+
groups=resnet_groups,
|
590 |
+
dropout=dropout,
|
591 |
+
norm_layer=norm_layer,
|
592 |
+
)
|
593 |
+
)
|
594 |
+
|
595 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
596 |
+
|
597 |
+
if add_upsample:
|
598 |
+
self.upsample = Upsample3D(dims=dims, channels=out_channels, out_channels=out_channels)
|
599 |
+
else:
|
600 |
+
self.upsample = Identity()
|
601 |
+
|
602 |
+
self.resolution_idx = resolution_idx
|
603 |
+
|
604 |
+
def forward(self, hidden_states: torch.FloatTensor, upsample_in_time=True) -> torch.FloatTensor:
|
605 |
+
for resnet in self.res_blocks:
|
606 |
+
hidden_states = resnet(hidden_states)
|
607 |
+
|
608 |
+
hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)
|
609 |
+
|
610 |
+
return hidden_states
|
611 |
+
|
612 |
+
|
613 |
+
class ResnetBlock3D(nn.Module):
|
614 |
+
r"""
|
615 |
+
A Resnet block.
|
616 |
+
|
617 |
+
Parameters:
|
618 |
+
in_channels (`int`): The number of channels in the input.
|
619 |
+
out_channels (`int`, *optional*, default to be `None`):
|
620 |
+
The number of output channels for the first conv layer. If None, same as `in_channels`.
|
621 |
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
622 |
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
623 |
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
624 |
+
"""
|
625 |
+
|
626 |
+
def __init__(
|
627 |
+
self,
|
628 |
+
dims: Union[int, Tuple[int, int]],
|
629 |
+
in_channels: int,
|
630 |
+
out_channels: Optional[int] = None,
|
631 |
+
conv_shortcut: bool = False,
|
632 |
+
dropout: float = 0.0,
|
633 |
+
groups: int = 32,
|
634 |
+
eps: float = 1e-6,
|
635 |
+
norm_layer: str = "group_norm",
|
636 |
+
):
|
637 |
+
super().__init__()
|
638 |
+
self.in_channels = in_channels
|
639 |
+
out_channels = in_channels if out_channels is None else out_channels
|
640 |
+
self.out_channels = out_channels
|
641 |
+
self.use_conv_shortcut = conv_shortcut
|
642 |
+
|
643 |
+
if norm_layer == "group_norm":
|
644 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
645 |
+
elif norm_layer == "pixel_norm":
|
646 |
+
self.norm1 = PixelNorm()
|
647 |
+
|
648 |
+
self.non_linearity = nn.SiLU()
|
649 |
+
|
650 |
+
self.conv1 = make_conv_nd(dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
651 |
+
|
652 |
+
if norm_layer == "group_norm":
|
653 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
654 |
+
elif norm_layer == "pixel_norm":
|
655 |
+
self.norm2 = PixelNorm()
|
656 |
+
|
657 |
+
self.dropout = torch.nn.Dropout(dropout)
|
658 |
+
|
659 |
+
self.conv2 = make_conv_nd(dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
660 |
+
|
661 |
+
self.conv_shortcut = (
|
662 |
+
make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels)
|
663 |
+
if in_channels != out_channels
|
664 |
+
else nn.Identity()
|
665 |
+
)
|
666 |
+
|
667 |
+
def forward(
|
668 |
+
self,
|
669 |
+
input_tensor: torch.FloatTensor,
|
670 |
+
) -> torch.FloatTensor:
|
671 |
+
hidden_states = input_tensor
|
672 |
+
|
673 |
+
hidden_states = self.norm1(hidden_states)
|
674 |
+
|
675 |
+
hidden_states = self.non_linearity(hidden_states)
|
676 |
+
|
677 |
+
hidden_states = self.conv1(hidden_states)
|
678 |
+
|
679 |
+
hidden_states = self.norm2(hidden_states)
|
680 |
+
|
681 |
+
hidden_states = self.non_linearity(hidden_states)
|
682 |
+
|
683 |
+
hidden_states = self.dropout(hidden_states)
|
684 |
+
|
685 |
+
hidden_states = self.conv2(hidden_states)
|
686 |
+
|
687 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
688 |
+
|
689 |
+
output_tensor = input_tensor + hidden_states
|
690 |
+
|
691 |
+
return output_tensor
|
692 |
+
|
693 |
+
|
694 |
+
class Downsample3D(nn.Module):
|
695 |
+
def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
|
696 |
+
super().__init__()
|
697 |
+
stride: int = 2
|
698 |
+
self.padding = padding
|
699 |
+
self.in_channels = in_channels
|
700 |
+
self.dims = dims
|
701 |
+
self.conv = make_conv_nd(
|
702 |
+
dims=dims,
|
703 |
+
in_channels=in_channels,
|
704 |
+
out_channels=out_channels,
|
705 |
+
kernel_size=kernel_size,
|
706 |
+
stride=stride,
|
707 |
+
padding=padding,
|
708 |
+
)
|
709 |
+
|
710 |
+
def forward(self, x, downsample_in_time=True):
|
711 |
+
conv = self.conv
|
712 |
+
if self.padding == 0:
|
713 |
+
if self.dims == 2:
|
714 |
+
padding = (0, 1, 0, 1)
|
715 |
+
else:
|
716 |
+
padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
|
717 |
+
|
718 |
+
x = functional.pad(x, padding, mode="constant", value=0)
|
719 |
+
|
720 |
+
if self.dims == (2, 1) and not downsample_in_time:
|
721 |
+
return conv(x, skip_time_conv=True)
|
722 |
+
|
723 |
+
return conv(x)
|
724 |
+
|
725 |
+
|
726 |
+
class Upsample3D(nn.Module):
|
727 |
+
"""
|
728 |
+
An upsampling layer for 3D tensors of shape (B, C, D, H, W).
|
729 |
+
|
730 |
+
:param channels: channels in the inputs and outputs.
|
731 |
+
"""
|
732 |
+
|
733 |
+
def __init__(self, dims, channels, out_channels=None):
|
734 |
+
super().__init__()
|
735 |
+
self.dims = dims
|
736 |
+
self.channels = channels
|
737 |
+
self.out_channels = out_channels or channels
|
738 |
+
self.conv = make_conv_nd(dims, channels, out_channels, kernel_size=3, padding=1, bias=True)
|
739 |
+
|
740 |
+
def forward(self, x, upsample_in_time):
|
741 |
+
if self.dims == 2:
|
742 |
+
x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
|
743 |
+
else:
|
744 |
+
time_scale_factor = 2 if upsample_in_time else 1
|
745 |
+
# print("before:", x.shape)
|
746 |
+
b, c, d, h, w = x.shape
|
747 |
+
x = rearrange(x, "b c d h w -> (b d) c h w")
|
748 |
+
# height and width interpolate
|
749 |
+
x = functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest")
|
750 |
+
_, _, h, w = x.shape
|
751 |
+
|
752 |
+
if not upsample_in_time and self.dims == (2, 1):
|
753 |
+
x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
|
754 |
+
return self.conv(x, skip_time_conv=True)
|
755 |
+
|
756 |
+
# Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension
|
757 |
+
x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)
|
758 |
+
|
759 |
+
# (b h w) c 1 d
|
760 |
+
new_d = x.shape[-1] * time_scale_factor
|
761 |
+
x = functional.interpolate(x, (1, new_d), mode="nearest")
|
762 |
+
# (b h w) c 1 new_d
|
763 |
+
x = rearrange(x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d)
|
764 |
+
# b c d h w
|
765 |
+
|
766 |
+
# x = functional.interpolate(
|
767 |
+
# x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
768 |
+
# )
|
769 |
+
# print("after:", x.shape)
|
770 |
+
|
771 |
+
return self.conv(x)
|
772 |
+
|
773 |
+
|
774 |
+
def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
775 |
+
if patch_size_hw == 1 and patch_size_t == 1:
|
776 |
+
return x
|
777 |
+
if x.dim() == 4:
|
778 |
+
x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw)
|
779 |
+
elif x.dim() == 5:
|
780 |
+
x = rearrange(x, "b c (f p) (h q) (w r) -> b (c p r q) f h w", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
|
781 |
+
else:
|
782 |
+
raise ValueError(f"Invalid input shape: {x.shape}")
|
783 |
+
|
784 |
+
if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
|
785 |
+
channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
|
786 |
+
padding_zeros = torch.zeros(
|
787 |
+
x.shape[0],
|
788 |
+
channels_to_pad,
|
789 |
+
x.shape[2],
|
790 |
+
x.shape[3],
|
791 |
+
x.shape[4],
|
792 |
+
device=x.device,
|
793 |
+
dtype=x.dtype,
|
794 |
+
)
|
795 |
+
x = torch.cat([padding_zeros, x], dim=1)
|
796 |
+
|
797 |
+
return x
|
798 |
+
|
799 |
+
|
800 |
+
def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
801 |
+
if patch_size_hw == 1 and patch_size_t == 1:
|
802 |
+
return x
|
803 |
+
|
804 |
+
if (x.dim() == 5) and (patch_size_hw > patch_size_t) and (patch_size_t > 1 or add_channel_padding):
|
805 |
+
channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
|
806 |
+
x = x[:, :channels_to_keep, :, :, :]
|
807 |
+
|
808 |
+
if x.dim() == 4:
|
809 |
+
x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw)
|
810 |
+
elif x.dim() == 5:
|
811 |
+
x = rearrange(x, "b (c p r q) f h w -> b c (f p) (h q) (w r)", p=patch_size_t, q=patch_size_hw, r=patch_size_hw)
|
812 |
+
|
813 |
+
return x
|
814 |
+
|
815 |
+
|
816 |
+
def create_video_autoencoder_config(
|
817 |
+
latent_channels: int = 4,
|
818 |
+
):
|
819 |
+
config = {
|
820 |
+
"_class_name": "VideoAutoencoder",
|
821 |
+
"dims": (2, 1), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
822 |
+
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
823 |
+
"out_channels": 3, # Number of output color channels
|
824 |
+
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
825 |
+
"block_out_channels": [128, 256, 512, 512], # Number of output channels of each encoder / decoder inner block
|
826 |
+
"patch_size": 1,
|
827 |
+
}
|
828 |
+
|
829 |
+
return config
|
830 |
+
|
831 |
+
|
832 |
+
def create_video_autoencoder_pathify4x4x4_config(
|
833 |
+
latent_channels: int = 4,
|
834 |
+
):
|
835 |
+
config = {
|
836 |
+
"_class_name": "VideoAutoencoder",
|
837 |
+
"dims": (2, 1), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
838 |
+
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
839 |
+
"out_channels": 3, # Number of output color channels
|
840 |
+
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
841 |
+
"block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block
|
842 |
+
"patch_size": 4,
|
843 |
+
"latent_log_var": "uniform",
|
844 |
+
}
|
845 |
+
|
846 |
+
return config
|
847 |
+
|
848 |
+
|
849 |
+
def create_video_autoencoder_pathify4x4_config(
|
850 |
+
latent_channels: int = 4,
|
851 |
+
):
|
852 |
+
config = {
|
853 |
+
"_class_name": "VideoAutoencoder",
|
854 |
+
"dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
855 |
+
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
856 |
+
"out_channels": 3, # Number of output color channels
|
857 |
+
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
858 |
+
"block_out_channels": [512] * 4, # Number of output channels of each encoder / decoder inner block
|
859 |
+
"patch_size": 4,
|
860 |
+
"norm_layer": "pixel_norm",
|
861 |
+
}
|
862 |
+
|
863 |
+
return config
|
864 |
+
|
865 |
+
|
866 |
+
def test_vae_patchify_unpatchify():
|
867 |
+
import torch
|
868 |
+
|
869 |
+
x = torch.randn(2, 3, 8, 64, 64)
|
870 |
+
x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
|
871 |
+
x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
|
872 |
+
assert torch.allclose(x, x_unpatched)
|
873 |
+
|
874 |
+
|
875 |
+
def demo_video_autoencoder_forward_backward():
|
876 |
+
# Configuration for the VideoAutoencoder
|
877 |
+
config = create_video_autoencoder_pathify4x4x4_config()
|
878 |
+
|
879 |
+
# Instantiate the VideoAutoencoder with the specified configuration
|
880 |
+
video_autoencoder = VideoAutoencoder.from_config(config)
|
881 |
+
|
882 |
+
print(video_autoencoder)
|
883 |
+
|
884 |
+
# Print the total number of parameters in the video autoencoder
|
885 |
+
total_params = sum(p.numel() for p in video_autoencoder.parameters())
|
886 |
+
print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
|
887 |
+
|
888 |
+
# Create a mock input tensor simulating a batch of videos
|
889 |
+
# Shape: (batch_size, channels, depth, height, width)
|
890 |
+
# E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame
|
891 |
+
input_videos = torch.randn(2, 3, 8, 64, 64)
|
892 |
+
|
893 |
+
# Forward pass: encode and decode the input videos
|
894 |
+
latent = video_autoencoder.encode(input_videos).latent_dist.mode()
|
895 |
+
print(f"input shape={input_videos.shape}")
|
896 |
+
print(f"latent shape={latent.shape}")
|
897 |
+
reconstructed_videos = video_autoencoder.decode(latent, target_shape=input_videos.shape).sample
|
898 |
+
|
899 |
+
print(f"reconstructed shape={reconstructed_videos.shape}")
|
900 |
+
|
901 |
+
# Calculate the loss (e.g., mean squared error)
|
902 |
+
loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
|
903 |
+
|
904 |
+
# Perform backward pass
|
905 |
+
loss.backward()
|
906 |
+
|
907 |
+
print(f"Demo completed with loss: {loss.item()}")
|
908 |
+
|
909 |
+
|
910 |
+
# Ensure to call the demo function to execute the forward and backward pass
|
911 |
+
if __name__ == "__main__":
|
912 |
+
demo_video_autoencoder_forward_backward()
|
xora/models/transformers/embeddings.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
|
2 |
+
import math
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from einops import rearrange
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
def get_timestep_embedding(
|
11 |
+
timesteps: torch.Tensor,
|
12 |
+
embedding_dim: int,
|
13 |
+
flip_sin_to_cos: bool = False,
|
14 |
+
downscale_freq_shift: float = 1,
|
15 |
+
scale: float = 1,
|
16 |
+
max_period: int = 10000,
|
17 |
+
):
|
18 |
+
"""
|
19 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
20 |
+
|
21 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
22 |
+
These may be fractional.
|
23 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
24 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
25 |
+
"""
|
26 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
27 |
+
|
28 |
+
half_dim = embedding_dim // 2
|
29 |
+
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
30 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
31 |
+
|
32 |
+
emb = torch.exp(exponent)
|
33 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
34 |
+
|
35 |
+
# scale embeddings
|
36 |
+
emb = scale * emb
|
37 |
+
|
38 |
+
# concat sine and cosine embeddings
|
39 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
40 |
+
|
41 |
+
# flip sine and cosine embeddings
|
42 |
+
if flip_sin_to_cos:
|
43 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
44 |
+
|
45 |
+
# zero pad
|
46 |
+
if embedding_dim % 2 == 1:
|
47 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
48 |
+
return emb
|
49 |
+
|
50 |
+
|
51 |
+
def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
|
52 |
+
"""
|
53 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
54 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
55 |
+
"""
|
56 |
+
grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
|
57 |
+
grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
|
58 |
+
grid = grid.reshape([3, 1, w, h, f])
|
59 |
+
pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
|
60 |
+
pos_embed = pos_embed.transpose(1, 0, 2, 3)
|
61 |
+
return rearrange(pos_embed, "h w f c -> (f h w) c")
|
62 |
+
|
63 |
+
|
64 |
+
def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
|
65 |
+
if embed_dim % 3 != 0:
|
66 |
+
raise ValueError("embed_dim must be divisible by 3")
|
67 |
+
|
68 |
+
# use half of dimensions to encode grid_h
|
69 |
+
emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3)
|
70 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3)
|
71 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3)
|
72 |
+
|
73 |
+
emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D)
|
74 |
+
return emb
|
75 |
+
|
76 |
+
|
77 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
78 |
+
"""
|
79 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
80 |
+
"""
|
81 |
+
if embed_dim % 2 != 0:
|
82 |
+
raise ValueError("embed_dim must be divisible by 2")
|
83 |
+
|
84 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
85 |
+
omega /= embed_dim / 2.0
|
86 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
87 |
+
|
88 |
+
pos_shape = pos.shape
|
89 |
+
|
90 |
+
pos = pos.reshape(-1)
|
91 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
92 |
+
out = out.reshape([*pos_shape, -1])[0]
|
93 |
+
|
94 |
+
emb_sin = np.sin(out) # (M, D/2)
|
95 |
+
emb_cos = np.cos(out) # (M, D/2)
|
96 |
+
|
97 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D)
|
98 |
+
return emb
|
99 |
+
|
100 |
+
|
101 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
102 |
+
"""Apply positional information to a sequence of embeddings.
|
103 |
+
|
104 |
+
Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
|
105 |
+
them
|
106 |
+
|
107 |
+
Args:
|
108 |
+
embed_dim: (int): Dimension of the positional embedding.
|
109 |
+
max_seq_length: Maximum sequence length to apply positional embeddings
|
110 |
+
|
111 |
+
"""
|
112 |
+
|
113 |
+
def __init__(self, embed_dim: int, max_seq_length: int = 32):
|
114 |
+
super().__init__()
|
115 |
+
position = torch.arange(max_seq_length).unsqueeze(1)
|
116 |
+
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
|
117 |
+
pe = torch.zeros(1, max_seq_length, embed_dim)
|
118 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
119 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
120 |
+
self.register_buffer("pe", pe)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
_, seq_length, _ = x.shape
|
124 |
+
x = x + self.pe[:, :seq_length]
|
125 |
+
return x
|
xora/models/transformers/transformer3d.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
|
2 |
import math
|
3 |
from dataclasses import dataclass
|
4 |
-
from typing import Any, Dict, List, Optional
|
5 |
|
6 |
import torch
|
7 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
@@ -9,10 +9,13 @@ from diffusers.models.embeddings import PixArtAlphaTextProjection
|
|
9 |
from diffusers.models.modeling_utils import ModelMixin
|
10 |
from diffusers.models.normalization import AdaLayerNormSingle
|
11 |
from diffusers.utils import BaseOutput, is_torch_version
|
|
|
12 |
from torch import nn
|
13 |
|
14 |
from xora.models.transformers.attention import BasicTransformerBlock
|
|
|
15 |
|
|
|
16 |
|
17 |
@dataclass
|
18 |
class Transformer3DModelOutput(BaseOutput):
|
@@ -143,6 +146,61 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
143 |
|
144 |
self.gradient_checkpointing = False
|
145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
def _set_gradient_checkpointing(self, module, value=False):
|
147 |
if hasattr(module, "gradient_checkpointing"):
|
148 |
module.gradient_checkpointing = value
|
@@ -287,10 +345,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
287 |
if self.timestep_scale_multiplier:
|
288 |
timestep = self.timestep_scale_multiplier * timestep
|
289 |
|
290 |
-
if self.positional_embedding_type == "
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
292 |
-
else:
|
293 |
-
raise NotImplementedError("Only rope pos embed supported.")
|
294 |
|
295 |
batch_size = hidden_states.shape[0]
|
296 |
timestep, embedded_timestep = self.adaln_single(
|
@@ -358,3 +420,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
358 |
|
359 |
return Transformer3DModelOutput(sample=hidden_states)
|
360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
|
2 |
import math
|
3 |
from dataclasses import dataclass
|
4 |
+
from typing import Any, Dict, List, Optional, Literal
|
5 |
|
6 |
import torch
|
7 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
|
9 |
from diffusers.models.modeling_utils import ModelMixin
|
10 |
from diffusers.models.normalization import AdaLayerNormSingle
|
11 |
from diffusers.utils import BaseOutput, is_torch_version
|
12 |
+
from diffusers.utils import logging
|
13 |
from torch import nn
|
14 |
|
15 |
from xora.models.transformers.attention import BasicTransformerBlock
|
16 |
+
from xora.models.transformers.embeddings import get_3d_sincos_pos_embed
|
17 |
|
18 |
+
logger = logging.get_logger(__name__)
|
19 |
|
20 |
@dataclass
|
21 |
class Transformer3DModelOutput(BaseOutput):
|
|
|
146 |
|
147 |
self.gradient_checkpointing = False
|
148 |
|
149 |
+
def set_use_tpu_flash_attention(self):
|
150 |
+
r"""
|
151 |
+
Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
|
152 |
+
attention kernel.
|
153 |
+
"""
|
154 |
+
logger.info(" ENABLE TPU FLASH ATTENTION -> TRUE")
|
155 |
+
# if using TPU -> configure components to use TPU flash attention
|
156 |
+
if dist_util.acceleration_type() == dist_util.AccelerationType.TPU:
|
157 |
+
self.use_tpu_flash_attention = True
|
158 |
+
# push config down to the attention modules
|
159 |
+
for block in self.transformer_blocks:
|
160 |
+
block.set_use_tpu_flash_attention()
|
161 |
+
|
162 |
+
def initialize(self, embedding_std: float, mode: Literal["xora", "pixart"]):
|
163 |
+
def _basic_init(module):
|
164 |
+
if isinstance(module, nn.Linear):
|
165 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
166 |
+
if module.bias is not None:
|
167 |
+
nn.init.constant_(module.bias, 0)
|
168 |
+
|
169 |
+
self.apply(_basic_init)
|
170 |
+
|
171 |
+
# Initialize timestep embedding MLP:
|
172 |
+
nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std)
|
173 |
+
nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std)
|
174 |
+
nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
|
175 |
+
|
176 |
+
if hasattr(self.adaln_single.emb, "resolution_embedder"):
|
177 |
+
nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_1.weight, std=embedding_std)
|
178 |
+
nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_2.weight, std=embedding_std)
|
179 |
+
if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
|
180 |
+
nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight, std=embedding_std)
|
181 |
+
nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight, std=embedding_std)
|
182 |
+
|
183 |
+
# Initialize caption embedding MLP:
|
184 |
+
nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
|
185 |
+
nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
|
186 |
+
|
187 |
+
# Zero-out adaLN modulation layers in PixArt blocks:
|
188 |
+
for block in self.transformer_blocks:
|
189 |
+
if mode == "xora":
|
190 |
+
nn.init.constant_(block.attn1.to_out[0].weight, 0)
|
191 |
+
nn.init.constant_(block.attn1.to_out[0].bias, 0)
|
192 |
+
|
193 |
+
nn.init.constant_(block.attn2.to_out[0].weight, 0)
|
194 |
+
nn.init.constant_(block.attn2.to_out[0].bias, 0)
|
195 |
+
|
196 |
+
if mode == "xora":
|
197 |
+
nn.init.constant_(block.ff.net[2].weight, 0)
|
198 |
+
nn.init.constant_(block.ff.net[2].bias, 0)
|
199 |
+
|
200 |
+
# Zero-out output layers:
|
201 |
+
nn.init.constant_(self.proj_out.weight, 0)
|
202 |
+
nn.init.constant_(self.proj_out.bias, 0)
|
203 |
+
|
204 |
def _set_gradient_checkpointing(self, module, value=False):
|
205 |
if hasattr(module, "gradient_checkpointing"):
|
206 |
module.gradient_checkpointing = value
|
|
|
345 |
if self.timestep_scale_multiplier:
|
346 |
timestep = self.timestep_scale_multiplier * timestep
|
347 |
|
348 |
+
if self.positional_embedding_type == "absolute":
|
349 |
+
pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(hidden_states.device)
|
350 |
+
if self.project_to_2d_pos:
|
351 |
+
pos_embed = self.to_2d_proj(pos_embed_3d)
|
352 |
+
hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
|
353 |
+
freqs_cis = None
|
354 |
+
elif self.positional_embedding_type == "rope":
|
355 |
freqs_cis = self.precompute_freqs_cis(indices_grid)
|
|
|
|
|
356 |
|
357 |
batch_size = hidden_states.shape[0]
|
358 |
timestep, embedded_timestep = self.adaln_single(
|
|
|
420 |
|
421 |
return Transformer3DModelOutput(sample=hidden_states)
|
422 |
|
423 |
+
def get_absolute_pos_embed(self, grid):
|
424 |
+
grid_np = grid[0].cpu().numpy()
|
425 |
+
embed_dim_3d = math.ceil((self.inner_dim / 2) * 3) if self.project_to_2d_pos else self.inner_dim
|
426 |
+
pos_embed = get_3d_sincos_pos_embed( # (f h w)
|
427 |
+
embed_dim_3d,
|
428 |
+
grid_np,
|
429 |
+
h=int(max(grid_np[1]) + 1),
|
430 |
+
w=int(max(grid_np[2]) + 1),
|
431 |
+
f=int(max(grid_np[0] + 1)),
|
432 |
+
)
|
433 |
+
return torch.from_numpy(pos_embed).float().unsqueeze(0)
|
xora/pipelines/pipeline_video_pixart_alpha.py
CHANGED
@@ -32,16 +32,106 @@ from xora.models.transformers.symmetric_patchifier import Patchifier
|
|
32 |
from xora.models.autoencoders.vae_encode import get_vae_size_scale_factor, vae_decode, vae_encode
|
33 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
34 |
from xora.schedulers.rf import TimestepShifter
|
|
|
35 |
|
36 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
37 |
|
38 |
-
|
39 |
if is_bs4_available():
|
40 |
from bs4 import BeautifulSoup
|
41 |
|
42 |
if is_ftfy_available():
|
43 |
import ftfy
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def retrieve_timesteps(
|
46 |
scheduler,
|
47 |
num_inference_steps: Optional[int] = None,
|
@@ -520,14 +610,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
520 |
|
521 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
522 |
def prepare_latents(
|
523 |
-
self,
|
524 |
-
batch_size,
|
525 |
-
num_latent_channels,
|
526 |
-
num_patches,
|
527 |
-
dtype,
|
528 |
-
device,
|
529 |
-
generator,
|
530 |
-
latents=None,
|
531 |
):
|
532 |
shape = (
|
533 |
batch_size,
|
@@ -543,6 +626,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
543 |
|
544 |
if latents is None:
|
545 |
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
|
|
|
|
|
|
546 |
else:
|
547 |
latents = latents.to(device)
|
548 |
|
@@ -582,8 +668,8 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
582 |
|
583 |
return samples
|
584 |
|
585 |
-
|
586 |
@torch.no_grad()
|
|
|
587 |
def __call__(
|
588 |
self,
|
589 |
height: int,
|
@@ -607,6 +693,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
607 |
return_dict: bool = True,
|
608 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
609 |
clean_caption: bool = True,
|
|
|
610 |
**kwargs,
|
611 |
) -> Union[ImagePipelineOutput, Tuple]:
|
612 |
"""
|
@@ -736,8 +823,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
736 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
737 |
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
738 |
|
739 |
-
#
|
740 |
self.video_scale_factor = self.video_scale_factor if is_video else 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
741 |
latent_height = height // self.vae_scale_factor
|
742 |
latent_width = width // self.vae_scale_factor
|
743 |
latent_num_frames = num_frames // self.video_scale_factor
|
@@ -752,7 +846,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
752 |
dtype=prompt_embeds.dtype,
|
753 |
device=device,
|
754 |
generator=generator,
|
|
|
|
|
755 |
)
|
|
|
|
|
|
|
756 |
|
757 |
# 5. Prepare timesteps
|
758 |
retrieve_timesteps_kwargs = {}
|
@@ -790,7 +889,7 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
790 |
elif len(current_timestep.shape) == 0:
|
791 |
current_timestep = current_timestep[None].to(latent_model_input.device)
|
792 |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
793 |
-
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
794 |
scale_grid = (
|
795 |
(1 / latent_frame_rates, self.vae_scale_factor, self.vae_scale_factor)
|
796 |
if self.transformer.use_rope
|
@@ -805,6 +904,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
805 |
device=latents.device,
|
806 |
)
|
807 |
|
|
|
|
|
|
|
808 |
# predict noise model_output
|
809 |
noise_pred = self.transformer(
|
810 |
latent_model_input.to(self.transformer.dtype),
|
@@ -819,13 +921,20 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
819 |
if do_classifier_free_guidance:
|
820 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
821 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
|
822 |
|
823 |
# learned sigma
|
824 |
if self.transformer.config.out_channels // 2 == self.transformer.config.in_channels:
|
825 |
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
826 |
|
827 |
# compute previous image: x_t -> x_t-1
|
828 |
-
latents = self.scheduler.step(
|
|
|
|
|
|
|
|
|
|
|
|
|
829 |
|
830 |
# call the callback, if provided
|
831 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
@@ -857,3 +966,62 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
857 |
return (image,)
|
858 |
|
859 |
return ImagePipelineOutput(images=image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
from xora.models.autoencoders.vae_encode import get_vae_size_scale_factor, vae_decode, vae_encode
|
33 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
34 |
from xora.schedulers.rf import TimestepShifter
|
35 |
+
from xora.utils.conditioning_method import ConditioningMethod
|
36 |
|
37 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
38 |
|
|
|
39 |
if is_bs4_available():
|
40 |
from bs4 import BeautifulSoup
|
41 |
|
42 |
if is_ftfy_available():
|
43 |
import ftfy
|
44 |
|
45 |
+
EXAMPLE_DOC_STRING = """
|
46 |
+
Examples:
|
47 |
+
```py
|
48 |
+
>>> import torch
|
49 |
+
>>> from diffusers import PixArtAlphaPipeline
|
50 |
+
|
51 |
+
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
|
52 |
+
>>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
|
53 |
+
>>> # Enable memory optimizations.
|
54 |
+
>>> pipe.enable_model_cpu_offload()
|
55 |
+
|
56 |
+
>>> prompt = "A small cactus with a happy face in the Sahara desert."
|
57 |
+
>>> image = pipe(prompt).images[0]
|
58 |
+
```
|
59 |
+
"""
|
60 |
+
|
61 |
+
ASPECT_RATIO_1024_BIN = {
|
62 |
+
"0.25": [512.0, 2048.0],
|
63 |
+
"0.28": [512.0, 1856.0],
|
64 |
+
"0.32": [576.0, 1792.0],
|
65 |
+
"0.33": [576.0, 1728.0],
|
66 |
+
"0.35": [576.0, 1664.0],
|
67 |
+
"0.4": [640.0, 1600.0],
|
68 |
+
"0.42": [640.0, 1536.0],
|
69 |
+
"0.48": [704.0, 1472.0],
|
70 |
+
"0.5": [704.0, 1408.0],
|
71 |
+
"0.52": [704.0, 1344.0],
|
72 |
+
"0.57": [768.0, 1344.0],
|
73 |
+
"0.6": [768.0, 1280.0],
|
74 |
+
"0.68": [832.0, 1216.0],
|
75 |
+
"0.72": [832.0, 1152.0],
|
76 |
+
"0.78": [896.0, 1152.0],
|
77 |
+
"0.82": [896.0, 1088.0],
|
78 |
+
"0.88": [960.0, 1088.0],
|
79 |
+
"0.94": [960.0, 1024.0],
|
80 |
+
"1.0": [1024.0, 1024.0],
|
81 |
+
"1.07": [1024.0, 960.0],
|
82 |
+
"1.13": [1088.0, 960.0],
|
83 |
+
"1.21": [1088.0, 896.0],
|
84 |
+
"1.29": [1152.0, 896.0],
|
85 |
+
"1.38": [1152.0, 832.0],
|
86 |
+
"1.46": [1216.0, 832.0],
|
87 |
+
"1.67": [1280.0, 768.0],
|
88 |
+
"1.75": [1344.0, 768.0],
|
89 |
+
"2.0": [1408.0, 704.0],
|
90 |
+
"2.09": [1472.0, 704.0],
|
91 |
+
"2.4": [1536.0, 640.0],
|
92 |
+
"2.5": [1600.0, 640.0],
|
93 |
+
"3.0": [1728.0, 576.0],
|
94 |
+
"4.0": [2048.0, 512.0],
|
95 |
+
}
|
96 |
+
|
97 |
+
ASPECT_RATIO_512_BIN = {
|
98 |
+
"0.25": [256.0, 1024.0],
|
99 |
+
"0.28": [256.0, 928.0],
|
100 |
+
"0.32": [288.0, 896.0],
|
101 |
+
"0.33": [288.0, 864.0],
|
102 |
+
"0.35": [288.0, 832.0],
|
103 |
+
"0.4": [320.0, 800.0],
|
104 |
+
"0.42": [320.0, 768.0],
|
105 |
+
"0.48": [352.0, 736.0],
|
106 |
+
"0.5": [352.0, 704.0],
|
107 |
+
"0.52": [352.0, 672.0],
|
108 |
+
"0.57": [384.0, 672.0],
|
109 |
+
"0.6": [384.0, 640.0],
|
110 |
+
"0.68": [416.0, 608.0],
|
111 |
+
"0.72": [416.0, 576.0],
|
112 |
+
"0.78": [448.0, 576.0],
|
113 |
+
"0.82": [448.0, 544.0],
|
114 |
+
"0.88": [480.0, 544.0],
|
115 |
+
"0.94": [480.0, 512.0],
|
116 |
+
"1.0": [512.0, 512.0],
|
117 |
+
"1.07": [512.0, 480.0],
|
118 |
+
"1.13": [544.0, 480.0],
|
119 |
+
"1.21": [544.0, 448.0],
|
120 |
+
"1.29": [576.0, 448.0],
|
121 |
+
"1.38": [576.0, 416.0],
|
122 |
+
"1.46": [608.0, 416.0],
|
123 |
+
"1.67": [640.0, 384.0],
|
124 |
+
"1.75": [672.0, 384.0],
|
125 |
+
"2.0": [704.0, 352.0],
|
126 |
+
"2.09": [736.0, 352.0],
|
127 |
+
"2.4": [768.0, 320.0],
|
128 |
+
"2.5": [800.0, 320.0],
|
129 |
+
"3.0": [864.0, 288.0],
|
130 |
+
"4.0": [1024.0, 256.0],
|
131 |
+
}
|
132 |
+
|
133 |
+
|
134 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
135 |
def retrieve_timesteps(
|
136 |
scheduler,
|
137 |
num_inference_steps: Optional[int] = None,
|
|
|
610 |
|
611 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
612 |
def prepare_latents(
|
613 |
+
self, batch_size, num_latent_channels, num_patches, dtype, device, generator, latents=None, latents_mask=None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
614 |
):
|
615 |
shape = (
|
616 |
batch_size,
|
|
|
626 |
|
627 |
if latents is None:
|
628 |
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
629 |
+
elif latents_mask is not None:
|
630 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
631 |
+
latents = latents * latents_mask[..., None] + noise * (1 - latents_mask[..., None])
|
632 |
else:
|
633 |
latents = latents.to(device)
|
634 |
|
|
|
668 |
|
669 |
return samples
|
670 |
|
|
|
671 |
@torch.no_grad()
|
672 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
673 |
def __call__(
|
674 |
self,
|
675 |
height: int,
|
|
|
693 |
return_dict: bool = True,
|
694 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
695 |
clean_caption: bool = True,
|
696 |
+
media_items: Optional[torch.FloatTensor] = None,
|
697 |
**kwargs,
|
698 |
) -> Union[ImagePipelineOutput, Tuple]:
|
699 |
"""
|
|
|
823 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
824 |
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
825 |
|
826 |
+
# 3b. Encode and prepare conditioning data
|
827 |
self.video_scale_factor = self.video_scale_factor if is_video else 1
|
828 |
+
conditioning_method = kwargs.get("conditioning_method", None)
|
829 |
+
vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
|
830 |
+
init_latents, conditioning_mask = self.prepare_conditioning(
|
831 |
+
media_items, num_frames, height, width, conditioning_method, vae_per_channel_normalize
|
832 |
+
)
|
833 |
+
|
834 |
+
# 4. Prepare latents.
|
835 |
latent_height = height // self.vae_scale_factor
|
836 |
latent_width = width // self.vae_scale_factor
|
837 |
latent_num_frames = num_frames // self.video_scale_factor
|
|
|
846 |
dtype=prompt_embeds.dtype,
|
847 |
device=device,
|
848 |
generator=generator,
|
849 |
+
latents=init_latents,
|
850 |
+
latents_mask=conditioning_mask,
|
851 |
)
|
852 |
+
if conditioning_mask is not None and is_video:
|
853 |
+
assert num_images_per_prompt == 1
|
854 |
+
conditioning_mask = torch.cat([conditioning_mask] * 2) if do_classifier_free_guidance else conditioning_mask
|
855 |
|
856 |
# 5. Prepare timesteps
|
857 |
retrieve_timesteps_kwargs = {}
|
|
|
889 |
elif len(current_timestep.shape) == 0:
|
890 |
current_timestep = current_timestep[None].to(latent_model_input.device)
|
891 |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
892 |
+
current_timestep = current_timestep.expand(latent_model_input.shape[0]).unsqueeze(-1)
|
893 |
scale_grid = (
|
894 |
(1 / latent_frame_rates, self.vae_scale_factor, self.vae_scale_factor)
|
895 |
if self.transformer.use_rope
|
|
|
904 |
device=latents.device,
|
905 |
)
|
906 |
|
907 |
+
if conditioning_mask is not None:
|
908 |
+
current_timestep = current_timestep * (1 - conditioning_mask)
|
909 |
+
|
910 |
# predict noise model_output
|
911 |
noise_pred = self.transformer(
|
912 |
latent_model_input.to(self.transformer.dtype),
|
|
|
921 |
if do_classifier_free_guidance:
|
922 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
923 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
924 |
+
current_timestep, _ = current_timestep.chunk(2)
|
925 |
|
926 |
# learned sigma
|
927 |
if self.transformer.config.out_channels // 2 == self.transformer.config.in_channels:
|
928 |
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
929 |
|
930 |
# compute previous image: x_t -> x_t-1
|
931 |
+
latents = self.scheduler.step(
|
932 |
+
noise_pred,
|
933 |
+
t if current_timestep is None else current_timestep,
|
934 |
+
latents,
|
935 |
+
**extra_step_kwargs,
|
936 |
+
return_dict=False,
|
937 |
+
)[0]
|
938 |
|
939 |
# call the callback, if provided
|
940 |
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
|
|
966 |
return (image,)
|
967 |
|
968 |
return ImagePipelineOutput(images=image)
|
969 |
+
|
970 |
+
def prepare_conditioning(
|
971 |
+
self,
|
972 |
+
media_items: torch.Tensor,
|
973 |
+
num_frames: int,
|
974 |
+
height: int,
|
975 |
+
width: int,
|
976 |
+
method: ConditioningMethod = ConditioningMethod.UNCONDITIONAL,
|
977 |
+
vae_per_channel_normalize: bool = False,
|
978 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
979 |
+
"""
|
980 |
+
Prepare the conditioning data for the video generation. If an input media item is provided, encode it
|
981 |
+
and set the conditioning_mask to indicate which tokens to condition on. Input media item should have
|
982 |
+
the same height and width as the generated video.
|
983 |
+
|
984 |
+
Args:
|
985 |
+
media_items (torch.Tensor): media items to condition on (images or videos)
|
986 |
+
num_frames (int): number of frames to generate
|
987 |
+
height (int): height of the generated video
|
988 |
+
width (int): width of the generated video
|
989 |
+
method (ConditioningMethod, optional): conditioning method to use. Defaults to ConditioningMethod.UNCONDITIONAL.
|
990 |
+
vae_per_channel_normalize (bool, optional): whether to normalize the input to the VAE per channel. Defaults to False.
|
991 |
+
|
992 |
+
Returns:
|
993 |
+
Tuple[torch.Tensor, torch.Tensor]: the conditioning latents and the conditioning mask
|
994 |
+
"""
|
995 |
+
if media_items is None or method == ConditioningMethod.UNCONDITIONAL:
|
996 |
+
return None, None
|
997 |
+
|
998 |
+
assert media_items.ndim == 5
|
999 |
+
assert height == media_items.shape[-2] and width == media_items.shape[-1]
|
1000 |
+
|
1001 |
+
# Encode the input video and repeat to the required number of frame-tokens
|
1002 |
+
init_latents = vae_encode(
|
1003 |
+
media_items.to(dtype=self.vae.dtype, device=self.vae.device),
|
1004 |
+
self.vae,
|
1005 |
+
vae_per_channel_normalize=vae_per_channel_normalize,
|
1006 |
+
).float()
|
1007 |
+
|
1008 |
+
init_len, target_len = init_latents.shape[2], num_frames // self.video_scale_factor
|
1009 |
+
if isinstance(self.vae, CausalVideoAutoencoder):
|
1010 |
+
target_len += 1
|
1011 |
+
init_latents = init_latents[:, :, :target_len]
|
1012 |
+
if target_len > init_len:
|
1013 |
+
repeat_factor = (target_len + init_len - 1) // init_len # Ceiling division
|
1014 |
+
init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[:, :, :target_len]
|
1015 |
+
|
1016 |
+
# Prepare the conditioning mask (1.0 = condition on this token)
|
1017 |
+
b, n, f, h, w = init_latents.shape
|
1018 |
+
conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
|
1019 |
+
if method in [ConditioningMethod.FIRST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
|
1020 |
+
conditioning_mask[:, :, 0] = 1.0
|
1021 |
+
if method in [ConditioningMethod.LAST_FRAME, ConditioningMethod.FIRST_AND_LAST_FRAME]:
|
1022 |
+
conditioning_mask[:, :, -1] = 1.0
|
1023 |
+
|
1024 |
+
# Patchify the init latents and the mask
|
1025 |
+
conditioning_mask = self.patchifier.patchify(conditioning_mask).squeeze(-1)
|
1026 |
+
init_latents = self.patchifier.patchify(latents=init_latents)
|
1027 |
+
return init_latents, conditioning_mask
|
xora/schedulers/rf.py
CHANGED
@@ -9,7 +9,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
|
9 |
from diffusers.utils import BaseOutput
|
10 |
from torch import Tensor
|
11 |
|
12 |
-
from
|
13 |
|
14 |
|
15 |
def simple_diffusion_resolution_dependent_timestep_shift(
|
@@ -199,8 +199,17 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
|
|
199 |
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
200 |
)
|
201 |
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
prev_sample = sample - dt * model_output
|
206 |
|
@@ -219,4 +228,4 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
|
|
219 |
sigmas = append_dims(sigmas, original_samples.ndim)
|
220 |
alphas = 1 - sigmas
|
221 |
noisy_samples = alphas * original_samples + sigmas * noise
|
222 |
-
return noisy_samples
|
|
|
9 |
from diffusers.utils import BaseOutput
|
10 |
from torch import Tensor
|
11 |
|
12 |
+
from txt2img.common.torch_utils import append_dims
|
13 |
|
14 |
|
15 |
def simple_diffusion_resolution_dependent_timestep_shift(
|
|
|
199 |
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
200 |
)
|
201 |
|
202 |
+
if timestep.ndim == 0:
|
203 |
+
# Global timestep
|
204 |
+
current_index = (self.timesteps - timestep).abs().argmin()
|
205 |
+
dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0))
|
206 |
+
else:
|
207 |
+
# Timestep per token
|
208 |
+
assert timestep.ndim == 2
|
209 |
+
current_index = (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
|
210 |
+
dt = self.delta_timesteps[current_index]
|
211 |
+
# Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
|
212 |
+
dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]
|
213 |
|
214 |
prev_sample = sample - dt * model_output
|
215 |
|
|
|
228 |
sigmas = append_dims(sigmas, original_samples.ndim)
|
229 |
alphas = 1 - sigmas
|
230 |
noisy_samples = alphas * original_samples + sigmas * noise
|
231 |
+
return noisy_samples
|
xora/utils/conditioning_method.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
class ConditioningMethod(Enum):
|
4 |
+
UNCONDITIONAL = "unconditional"
|
5 |
+
FIRST_FRAME = "first_frame"
|
6 |
+
LAST_FRAME = "last_frame"
|
7 |
+
FIRST_AND_LAST_FRAME = "first_and_last_frame"
|
xora/utils/dist_util.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
class AccelerationType(Enum):
|
4 |
+
CPU = "cpu"
|
5 |
+
GPU = "gpu"
|
6 |
+
TPU = "tpu"
|
7 |
+
MPS = "mps"
|
8 |
+
|
9 |
+
def execute_graph() -> None:
|
10 |
+
if _acceleration_type == AccelerationType.TPU:
|
11 |
+
xm.mark_step()
|