Lint: added ruff.
Browse files- .github/workflows/pylint.yml +27 -0
- .gitignore +1 -1
- .pre-commit-config.yaml +16 -0
- scripts/to_safetensors.py +64 -30
- setup.py +11 -7
- xora/__init__.py +0 -1
- xora/examples/image_to_video.py +94 -44
- xora/examples/text_to_video.py +29 -13
- xora/models/autoencoders/causal_conv3d.py +9 -3
- xora/models/autoencoders/causal_video_autoencoder.py +133 -33
- xora/models/autoencoders/conv_nd_factory.py +6 -2
- xora/models/autoencoders/dual_conv3d.py +36 -6
- xora/models/autoencoders/vae.py +74 -24
- xora/models/autoencoders/vae_encode.py +62 -17
- xora/models/autoencoders/video_autoencoder.py +170 -46
- xora/models/transformers/attention.py +174 -53
- xora/models/transformers/embeddings.py +6 -2
- xora/models/transformers/symmetric_patchifier.py +19 -4
- xora/models/transformers/transformer3d.py +86 -23
- xora/pipelines/pipeline_video_pixart_alpha.py +205 -63
- xora/schedulers/rf.py +43 -13
- xora/utils/conditioning_method.py +2 -1
- xora/utils/torch_utils.py +5 -1
.github/workflows/pylint.yml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Ruff
|
2 |
+
|
3 |
+
on: [push]
|
4 |
+
|
5 |
+
jobs:
|
6 |
+
build:
|
7 |
+
runs-on: ubuntu-latest
|
8 |
+
strategy:
|
9 |
+
matrix:
|
10 |
+
python-version: ["3.10"]
|
11 |
+
steps:
|
12 |
+
- name: Checkout repository and submodules
|
13 |
+
uses: actions/checkout@v3
|
14 |
+
- name: Set up Python ${{ matrix.python-version }}
|
15 |
+
uses: actions/setup-python@v3
|
16 |
+
with:
|
17 |
+
python-version: ${{ matrix.python-version }}
|
18 |
+
- name: Install dependencies
|
19 |
+
run: |
|
20 |
+
python -m pip install --upgrade pip
|
21 |
+
pip install ruff==0.2.2 black==24.2.0
|
22 |
+
- name: Analyzing the code with ruff
|
23 |
+
run: |
|
24 |
+
ruff $(git ls-files '*.py')
|
25 |
+
- name: Verify that no Black changes are required
|
26 |
+
run: |
|
27 |
+
black --check $(git ls-files '*.py')
|
.gitignore
CHANGED
@@ -159,4 +159,4 @@ cython_debug/
|
|
159 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
-
|
|
|
159 |
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
.idea/
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
3 |
+
# Ruff version.
|
4 |
+
rev: v0.2.2
|
5 |
+
hooks:
|
6 |
+
# Run the linter.
|
7 |
+
- id: ruff
|
8 |
+
args: [--fix] # Automatically fix issues if possible.
|
9 |
+
types: [python] # Ensure it only runs on .py files.
|
10 |
+
|
11 |
+
- repo: https://github.com/psf/black
|
12 |
+
rev: 24.2.0 # Specify the version of Black you want
|
13 |
+
hooks:
|
14 |
+
- id: black
|
15 |
+
name: Black code formatter
|
16 |
+
language_version: python3 # Use the Python version you're targeting (e.g., 3.10)
|
scripts/to_safetensors.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import argparse
|
2 |
from pathlib import Path
|
3 |
-
from typing import
|
4 |
import safetensors.torch
|
5 |
import torch
|
6 |
import json
|
@@ -8,12 +8,14 @@ import shutil
|
|
8 |
|
9 |
|
10 |
def load_text_encoder(index_path: Path) -> Dict:
|
11 |
-
with open(index_path,
|
12 |
index: Dict = json.load(f)
|
13 |
|
14 |
loaded_tensors = {}
|
15 |
for part_file in set(index.get("weight_map", {}).values()):
|
16 |
-
tensors = safetensors.torch.load_file(
|
|
|
|
|
17 |
for tensor_name in tensors:
|
18 |
loaded_tensors[tensor_name] = tensors[tensor_name]
|
19 |
|
@@ -30,23 +32,30 @@ def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
|
|
30 |
state_dict = torch.load(vae_path / "autoencoder.pth", weights_only=True)
|
31 |
stats_path = vae_path / "per_channel_statistics.json"
|
32 |
if stats_path.exists():
|
33 |
-
with open(stats_path,
|
34 |
data = json.load(f)
|
35 |
transposed_data = list(zip(*data["data"]))
|
36 |
data_dict = {
|
37 |
-
f"{'vae.' if add_prefix else ''}per_channel_statistics.{col}": torch.tensor(
|
|
|
|
|
38 |
for col, vals in zip(data["columns"], transposed_data)
|
39 |
}
|
40 |
else:
|
41 |
data_dict = {}
|
42 |
|
43 |
-
result = {
|
|
|
|
|
44 |
result.update(data_dict)
|
45 |
return result
|
46 |
|
47 |
|
48 |
def convert_encoder(encoder: Dict) -> Dict:
|
49 |
-
return {
|
|
|
|
|
|
|
50 |
|
51 |
|
52 |
def save_config(config_src: str, config_dst: str):
|
@@ -60,50 +69,75 @@ def load_vae_config(vae_path: Path) -> str:
|
|
60 |
return str(config_path)
|
61 |
|
62 |
|
63 |
-
def main(
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
# Load VAE from directory and config
|
68 |
-
vae = convert_vae(Path(vae_path), add_prefix=(mode ==
|
69 |
vae_config_path = load_vae_config(Path(vae_path))
|
70 |
|
71 |
-
if mode ==
|
72 |
result = {**unet, **vae}
|
73 |
safetensors.torch.save_file(result, out_path)
|
74 |
-
elif mode ==
|
75 |
# Create directories for unet, vae, and scheduler
|
76 |
-
unet_dir = Path(out_path) /
|
77 |
-
vae_dir = Path(out_path) /
|
78 |
-
scheduler_dir = Path(out_path) /
|
79 |
|
80 |
unet_dir.mkdir(parents=True, exist_ok=True)
|
81 |
vae_dir.mkdir(parents=True, exist_ok=True)
|
82 |
scheduler_dir.mkdir(parents=True, exist_ok=True)
|
83 |
|
84 |
# Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
|
85 |
-
safetensors.torch.save_file(
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
|
88 |
# Save config files for unet, vae, and scheduler
|
89 |
if unet_config_path:
|
90 |
-
save_config(unet_config_path, unet_dir /
|
91 |
if vae_config_path:
|
92 |
-
save_config(vae_config_path, vae_dir /
|
93 |
if scheduler_config_path:
|
94 |
-
save_config(scheduler_config_path, scheduler_dir /
|
95 |
|
96 |
|
97 |
-
if __name__ ==
|
98 |
parser = argparse.ArgumentParser()
|
99 |
-
parser.add_argument(
|
100 |
-
parser.add_argument(
|
101 |
-
parser.add_argument(
|
102 |
-
parser.add_argument(
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
args = parser.parse_args()
|
109 |
main(**args.__dict__)
|
|
|
1 |
import argparse
|
2 |
from pathlib import Path
|
3 |
+
from typing import Dict
|
4 |
import safetensors.torch
|
5 |
import torch
|
6 |
import json
|
|
|
8 |
|
9 |
|
10 |
def load_text_encoder(index_path: Path) -> Dict:
|
11 |
+
with open(index_path, "r") as f:
|
12 |
index: Dict = json.load(f)
|
13 |
|
14 |
loaded_tensors = {}
|
15 |
for part_file in set(index.get("weight_map", {}).values()):
|
16 |
+
tensors = safetensors.torch.load_file(
|
17 |
+
index_path.parent / part_file, device="cpu"
|
18 |
+
)
|
19 |
for tensor_name in tensors:
|
20 |
loaded_tensors[tensor_name] = tensors[tensor_name]
|
21 |
|
|
|
32 |
state_dict = torch.load(vae_path / "autoencoder.pth", weights_only=True)
|
33 |
stats_path = vae_path / "per_channel_statistics.json"
|
34 |
if stats_path.exists():
|
35 |
+
with open(stats_path, "r") as f:
|
36 |
data = json.load(f)
|
37 |
transposed_data = list(zip(*data["data"]))
|
38 |
data_dict = {
|
39 |
+
f"{'vae.' if add_prefix else ''}per_channel_statistics.{col}": torch.tensor(
|
40 |
+
vals
|
41 |
+
)
|
42 |
for col, vals in zip(data["columns"], transposed_data)
|
43 |
}
|
44 |
else:
|
45 |
data_dict = {}
|
46 |
|
47 |
+
result = {
|
48 |
+
("vae." if add_prefix else "") + key: value for key, value in state_dict.items()
|
49 |
+
}
|
50 |
result.update(data_dict)
|
51 |
return result
|
52 |
|
53 |
|
54 |
def convert_encoder(encoder: Dict) -> Dict:
|
55 |
+
return {
|
56 |
+
"text_encoders.t5xxl.transformer." + key: value
|
57 |
+
for key, value in encoder.items()
|
58 |
+
}
|
59 |
|
60 |
|
61 |
def save_config(config_src: str, config_dst: str):
|
|
|
69 |
return str(config_path)
|
70 |
|
71 |
|
72 |
+
def main(
|
73 |
+
unet_path: str,
|
74 |
+
vae_path: str,
|
75 |
+
out_path: str,
|
76 |
+
mode: str,
|
77 |
+
unet_config_path: str = None,
|
78 |
+
scheduler_config_path: str = None,
|
79 |
+
) -> None:
|
80 |
+
unet = convert_unet(
|
81 |
+
torch.load(unet_path, weights_only=True), add_prefix=(mode == "single")
|
82 |
+
)
|
83 |
|
84 |
# Load VAE from directory and config
|
85 |
+
vae = convert_vae(Path(vae_path), add_prefix=(mode == "single"))
|
86 |
vae_config_path = load_vae_config(Path(vae_path))
|
87 |
|
88 |
+
if mode == "single":
|
89 |
result = {**unet, **vae}
|
90 |
safetensors.torch.save_file(result, out_path)
|
91 |
+
elif mode == "separate":
|
92 |
# Create directories for unet, vae, and scheduler
|
93 |
+
unet_dir = Path(out_path) / "unet"
|
94 |
+
vae_dir = Path(out_path) / "vae"
|
95 |
+
scheduler_dir = Path(out_path) / "scheduler"
|
96 |
|
97 |
unet_dir.mkdir(parents=True, exist_ok=True)
|
98 |
vae_dir.mkdir(parents=True, exist_ok=True)
|
99 |
scheduler_dir.mkdir(parents=True, exist_ok=True)
|
100 |
|
101 |
# Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
|
102 |
+
safetensors.torch.save_file(
|
103 |
+
unet, unet_dir / "diffusion_pytorch_model.safetensors"
|
104 |
+
)
|
105 |
+
safetensors.torch.save_file(
|
106 |
+
vae, vae_dir / "diffusion_pytorch_model.safetensors"
|
107 |
+
)
|
108 |
|
109 |
# Save config files for unet, vae, and scheduler
|
110 |
if unet_config_path:
|
111 |
+
save_config(unet_config_path, unet_dir / "config.json")
|
112 |
if vae_config_path:
|
113 |
+
save_config(vae_config_path, vae_dir / "config.json")
|
114 |
if scheduler_config_path:
|
115 |
+
save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")
|
116 |
|
117 |
|
118 |
+
if __name__ == "__main__":
|
119 |
parser = argparse.ArgumentParser()
|
120 |
+
parser.add_argument("--unet_path", "-u", type=str, default="unet/ema-002.pt")
|
121 |
+
parser.add_argument("--vae_path", "-v", type=str, default="vae/")
|
122 |
+
parser.add_argument("--out_path", "-o", type=str, default="xora.safetensors")
|
123 |
+
parser.add_argument(
|
124 |
+
"--mode",
|
125 |
+
"-m",
|
126 |
+
type=str,
|
127 |
+
choices=["single", "separate"],
|
128 |
+
default="single",
|
129 |
+
help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--unet_config_path",
|
133 |
+
type=str,
|
134 |
+
help="Path to the UNet config file (for separate mode)",
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--scheduler_config_path",
|
138 |
+
type=str,
|
139 |
+
help="Path to the Scheduler config file (for separate mode)",
|
140 |
+
)
|
141 |
|
142 |
args = parser.parse_args()
|
143 |
main(**args.__dict__)
|
setup.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
from setuptools import setup, find_packages
|
|
|
|
|
2 |
def parse_requirements(filename):
|
3 |
"""Load requirements from a pip requirements file."""
|
4 |
-
with open(filename,
|
5 |
return file.read().splitlines()
|
6 |
|
7 |
|
@@ -13,11 +15,13 @@ setup(
|
|
13 |
author_email="sapir@lightricks.com", # Your email
|
14 |
url="https://github.com/LightricksResearch/xora-core", # URL for the project (GitHub, etc.)
|
15 |
packages=find_packages(), # Automatically find all packages inside `xora`
|
16 |
-
install_requires=parse_requirements(
|
|
|
|
|
17 |
classifiers=[
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
],
|
22 |
-
python_requires=
|
23 |
-
)
|
|
|
1 |
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
|
4 |
def parse_requirements(filename):
|
5 |
"""Load requirements from a pip requirements file."""
|
6 |
+
with open(filename, "r") as file:
|
7 |
return file.read().splitlines()
|
8 |
|
9 |
|
|
|
15 |
author_email="sapir@lightricks.com", # Your email
|
16 |
url="https://github.com/LightricksResearch/xora-core", # URL for the project (GitHub, etc.)
|
17 |
packages=find_packages(), # Automatically find all packages inside `xora`
|
18 |
+
install_requires=parse_requirements(
|
19 |
+
"requirements.txt"
|
20 |
+
), # Install dependencies from requirements.txt
|
21 |
classifiers=[
|
22 |
+
"Programming Language :: Python :: 3",
|
23 |
+
"License :: OSI Approved :: MIT License",
|
24 |
+
"Operating System :: OS Independent",
|
25 |
],
|
26 |
+
python_requires=">=3.10", # Specify Python version compatibility
|
27 |
+
)
|
xora/__init__.py
CHANGED
@@ -1 +0,0 @@
|
|
1 |
-
from .pipelines import *
|
|
|
|
xora/examples/image_to_video.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import time
|
2 |
import torch
|
3 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
4 |
from xora.models.transformers.transformer3d import Transformer3DModel
|
@@ -15,19 +14,20 @@ import os
|
|
15 |
import numpy as np
|
16 |
import cv2
|
17 |
from PIL import Image
|
18 |
-
from tqdm import tqdm
|
19 |
import random
|
20 |
|
|
|
21 |
def load_vae(vae_dir):
|
22 |
vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
23 |
vae_config_path = vae_dir / "config.json"
|
24 |
-
with open(vae_config_path,
|
25 |
vae_config = json.load(f)
|
26 |
vae = CausalVideoAutoencoder.from_config(vae_config)
|
27 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
28 |
vae.load_state_dict(vae_state_dict)
|
29 |
return vae.cuda().to(torch.bfloat16)
|
30 |
|
|
|
31 |
def load_unet(unet_dir):
|
32 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
33 |
unet_config_path = unet_dir / "config.json"
|
@@ -37,11 +37,13 @@ def load_unet(unet_dir):
|
|
37 |
transformer.load_state_dict(unet_state_dict, strict=True)
|
38 |
return transformer.cuda()
|
39 |
|
|
|
40 |
def load_scheduler(scheduler_dir):
|
41 |
scheduler_config_path = scheduler_dir / "scheduler_config.json"
|
42 |
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
|
43 |
return RectifiedFlowScheduler.from_config(scheduler_config)
|
44 |
|
|
|
45 |
def center_crop_and_resize(frame, target_height, target_width):
|
46 |
h, w, _ = frame.shape
|
47 |
aspect_ratio_target = target_width / target_height
|
@@ -49,14 +51,15 @@ def center_crop_and_resize(frame, target_height, target_width):
|
|
49 |
if aspect_ratio_frame > aspect_ratio_target:
|
50 |
new_width = int(h * aspect_ratio_target)
|
51 |
x_start = (w - new_width) // 2
|
52 |
-
frame_cropped = frame[:, x_start:x_start + new_width]
|
53 |
else:
|
54 |
new_height = int(w / aspect_ratio_target)
|
55 |
y_start = (h - new_height) // 2
|
56 |
-
frame_cropped = frame[y_start:y_start + new_height, :]
|
57 |
frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
|
58 |
return frame_resized
|
59 |
|
|
|
60 |
def load_video_to_tensor_with_resize(video_path, target_height=512, target_width=768):
|
61 |
cap = cv2.VideoCapture(video_path)
|
62 |
frames = []
|
@@ -72,6 +75,7 @@ def load_video_to_tensor_with_resize(video_path, target_height=512, target_width
|
|
72 |
video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
|
73 |
return video_tensor
|
74 |
|
|
|
75 |
def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
|
76 |
image = Image.open(image_path).convert("RGB")
|
77 |
image_np = np.array(image)
|
@@ -81,51 +85,90 @@ def load_image_to_tensor_with_resize(image_path, target_height=512, target_width
|
|
81 |
# Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
|
82 |
return frame_tensor.unsqueeze(0).unsqueeze(2)
|
83 |
|
|
|
84 |
def main():
|
85 |
-
parser = argparse.ArgumentParser(
|
|
|
|
|
86 |
|
87 |
# Directories
|
88 |
-
parser.add_argument(
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
95 |
|
96 |
# Pipeline parameters
|
97 |
-
parser.add_argument(
|
98 |
-
|
99 |
-
|
100 |
-
parser.add_argument(
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
# Prompts
|
106 |
-
parser.add_argument(
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
args = parser.parse_args()
|
114 |
|
115 |
# Paths for the separate mode directories
|
116 |
ckpt_dir = Path(args.ckpt_dir)
|
117 |
-
unet_dir = ckpt_dir /
|
118 |
-
vae_dir = ckpt_dir /
|
119 |
-
scheduler_dir = ckpt_dir /
|
120 |
|
121 |
# Load models
|
122 |
vae = load_vae(vae_dir)
|
123 |
unet = load_unet(unet_dir)
|
124 |
scheduler = load_scheduler(scheduler_dir)
|
125 |
patchifier = SymmetricPatchifier(patch_size=1)
|
126 |
-
text_encoder = T5EncoderModel.from_pretrained(
|
127 |
-
"
|
128 |
-
|
|
|
|
|
|
|
129 |
|
130 |
# Use submodels for the pipeline
|
131 |
submodel_dict = {
|
@@ -141,22 +184,25 @@ def main():
|
|
141 |
|
142 |
# Load media (video or image)
|
143 |
if args.video_path:
|
144 |
-
media_items = load_video_to_tensor_with_resize(
|
|
|
|
|
145 |
elif args.image_path:
|
146 |
-
media_items = load_image_to_tensor_with_resize(
|
|
|
|
|
147 |
else:
|
148 |
raise ValueError("Either --video_path or --image_path must be provided.")
|
149 |
|
150 |
# Prepare input for the pipeline
|
151 |
sample = {
|
152 |
"prompt": args.prompt,
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
}
|
158 |
|
159 |
-
start_time = time.time()
|
160 |
random.seed(args.seed)
|
161 |
np.random.seed(args.seed)
|
162 |
torch.manual_seed(args.seed)
|
@@ -177,16 +223,18 @@ def main():
|
|
177 |
**sample,
|
178 |
is_video=True,
|
179 |
vae_per_channel_normalize=True,
|
180 |
-
conditioning_method=ConditioningMethod.FIRST_FRAME
|
181 |
).images
|
|
|
182 |
# Save output video
|
183 |
-
def get_unique_filename(base, ext, dir=
|
184 |
for i in range(index_range):
|
185 |
filename = os.path.join(dir, f"{base}_{i}{ext}")
|
186 |
if not os.path.exists(filename):
|
187 |
return filename
|
188 |
-
raise FileExistsError(
|
189 |
-
|
|
|
190 |
|
191 |
for i in range(images.shape[0]):
|
192 |
video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
|
@@ -195,7 +243,9 @@ def main():
|
|
195 |
height, width = video_np.shape[1:3]
|
196 |
output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")
|
197 |
|
198 |
-
out = cv2.VideoWriter(
|
|
|
|
|
199 |
|
200 |
for frame in video_np[..., ::-1]:
|
201 |
out.write(frame)
|
|
|
|
|
1 |
import torch
|
2 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
3 |
from xora.models.transformers.transformer3d import Transformer3DModel
|
|
|
14 |
import numpy as np
|
15 |
import cv2
|
16 |
from PIL import Image
|
|
|
17 |
import random
|
18 |
|
19 |
+
|
20 |
def load_vae(vae_dir):
|
21 |
vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
22 |
vae_config_path = vae_dir / "config.json"
|
23 |
+
with open(vae_config_path, "r") as f:
|
24 |
vae_config = json.load(f)
|
25 |
vae = CausalVideoAutoencoder.from_config(vae_config)
|
26 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
27 |
vae.load_state_dict(vae_state_dict)
|
28 |
return vae.cuda().to(torch.bfloat16)
|
29 |
|
30 |
+
|
31 |
def load_unet(unet_dir):
|
32 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
33 |
unet_config_path = unet_dir / "config.json"
|
|
|
37 |
transformer.load_state_dict(unet_state_dict, strict=True)
|
38 |
return transformer.cuda()
|
39 |
|
40 |
+
|
41 |
def load_scheduler(scheduler_dir):
|
42 |
scheduler_config_path = scheduler_dir / "scheduler_config.json"
|
43 |
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
|
44 |
return RectifiedFlowScheduler.from_config(scheduler_config)
|
45 |
|
46 |
+
|
47 |
def center_crop_and_resize(frame, target_height, target_width):
|
48 |
h, w, _ = frame.shape
|
49 |
aspect_ratio_target = target_width / target_height
|
|
|
51 |
if aspect_ratio_frame > aspect_ratio_target:
|
52 |
new_width = int(h * aspect_ratio_target)
|
53 |
x_start = (w - new_width) // 2
|
54 |
+
frame_cropped = frame[:, x_start : x_start + new_width]
|
55 |
else:
|
56 |
new_height = int(w / aspect_ratio_target)
|
57 |
y_start = (h - new_height) // 2
|
58 |
+
frame_cropped = frame[y_start : y_start + new_height, :]
|
59 |
frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
|
60 |
return frame_resized
|
61 |
|
62 |
+
|
63 |
def load_video_to_tensor_with_resize(video_path, target_height=512, target_width=768):
|
64 |
cap = cv2.VideoCapture(video_path)
|
65 |
frames = []
|
|
|
75 |
video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
|
76 |
return video_tensor
|
77 |
|
78 |
+
|
79 |
def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
|
80 |
image = Image.open(image_path).convert("RGB")
|
81 |
image_np = np.array(image)
|
|
|
85 |
# Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
|
86 |
return frame_tensor.unsqueeze(0).unsqueeze(2)
|
87 |
|
88 |
+
|
89 |
def main():
|
90 |
+
parser = argparse.ArgumentParser(
|
91 |
+
description="Load models from separate directories and run the pipeline."
|
92 |
+
)
|
93 |
|
94 |
# Directories
|
95 |
+
parser.add_argument(
|
96 |
+
"--ckpt_dir",
|
97 |
+
type=str,
|
98 |
+
required=True,
|
99 |
+
help="Path to the directory containing unet, vae, and scheduler subdirectories",
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--video_path", type=str, help="Path to the input video file (first frame used)"
|
103 |
+
)
|
104 |
+
parser.add_argument("--image_path", type=str, help="Path to the input image file")
|
105 |
+
parser.add_argument("--seed", type=int, default="171198")
|
106 |
|
107 |
# Pipeline parameters
|
108 |
+
parser.add_argument(
|
109 |
+
"--num_inference_steps", type=int, default=40, help="Number of inference steps"
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"--num_images_per_prompt",
|
113 |
+
type=int,
|
114 |
+
default=1,
|
115 |
+
help="Number of images per prompt",
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"--guidance_scale",
|
119 |
+
type=float,
|
120 |
+
default=3,
|
121 |
+
help="Guidance scale for the pipeline",
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--height", type=int, default=512, help="Height of the output video frames"
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--width", type=int, default=768, help="Width of the output video frames"
|
128 |
+
)
|
129 |
+
parser.add_argument(
|
130 |
+
"--num_frames",
|
131 |
+
type=int,
|
132 |
+
default=121,
|
133 |
+
help="Number of frames to generate in the output video",
|
134 |
+
)
|
135 |
+
parser.add_argument(
|
136 |
+
"--frame_rate", type=int, default=25, help="Frame rate for the output video"
|
137 |
+
)
|
138 |
|
139 |
# Prompts
|
140 |
+
parser.add_argument(
|
141 |
+
"--prompt",
|
142 |
+
type=str,
|
143 |
+
default='A man wearing a black leather jacket and blue jeans is riding a Harley Davidson motorcycle down a paved road. The man has short brown hair and is wearing a black helmet. The motorcycle is a dark red color with a large front fairing. The road is surrounded by green grass and trees. There is a gas station on the left side of the road with a red and white sign that says "Oil" and "Diner".',
|
144 |
+
help="Text prompt to guide generation",
|
145 |
+
)
|
146 |
+
parser.add_argument(
|
147 |
+
"--negative_prompt",
|
148 |
+
type=str,
|
149 |
+
default="worst quality, inconsistent motion, blurry, jittery, distorted",
|
150 |
+
help="Negative prompt for undesired features",
|
151 |
+
)
|
152 |
|
153 |
args = parser.parse_args()
|
154 |
|
155 |
# Paths for the separate mode directories
|
156 |
ckpt_dir = Path(args.ckpt_dir)
|
157 |
+
unet_dir = ckpt_dir / "unet"
|
158 |
+
vae_dir = ckpt_dir / "vae"
|
159 |
+
scheduler_dir = ckpt_dir / "scheduler"
|
160 |
|
161 |
# Load models
|
162 |
vae = load_vae(vae_dir)
|
163 |
unet = load_unet(unet_dir)
|
164 |
scheduler = load_scheduler(scheduler_dir)
|
165 |
patchifier = SymmetricPatchifier(patch_size=1)
|
166 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
167 |
+
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
|
168 |
+
).to("cuda")
|
169 |
+
tokenizer = T5Tokenizer.from_pretrained(
|
170 |
+
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
|
171 |
+
)
|
172 |
|
173 |
# Use submodels for the pipeline
|
174 |
submodel_dict = {
|
|
|
184 |
|
185 |
# Load media (video or image)
|
186 |
if args.video_path:
|
187 |
+
media_items = load_video_to_tensor_with_resize(
|
188 |
+
args.video_path, args.height, args.width
|
189 |
+
).unsqueeze(0)
|
190 |
elif args.image_path:
|
191 |
+
media_items = load_image_to_tensor_with_resize(
|
192 |
+
args.image_path, args.height, args.width
|
193 |
+
)
|
194 |
else:
|
195 |
raise ValueError("Either --video_path or --image_path must be provided.")
|
196 |
|
197 |
# Prepare input for the pipeline
|
198 |
sample = {
|
199 |
"prompt": args.prompt,
|
200 |
+
"prompt_attention_mask": None,
|
201 |
+
"negative_prompt": args.negative_prompt,
|
202 |
+
"negative_prompt_attention_mask": None,
|
203 |
+
"media_items": media_items,
|
204 |
}
|
205 |
|
|
|
206 |
random.seed(args.seed)
|
207 |
np.random.seed(args.seed)
|
208 |
torch.manual_seed(args.seed)
|
|
|
223 |
**sample,
|
224 |
is_video=True,
|
225 |
vae_per_channel_normalize=True,
|
226 |
+
conditioning_method=ConditioningMethod.FIRST_FRAME,
|
227 |
).images
|
228 |
+
|
229 |
# Save output video
|
230 |
+
def get_unique_filename(base, ext, dir=".", index_range=1000):
|
231 |
for i in range(index_range):
|
232 |
filename = os.path.join(dir, f"{base}_{i}{ext}")
|
233 |
if not os.path.exists(filename):
|
234 |
return filename
|
235 |
+
raise FileExistsError(
|
236 |
+
f"Could not find a unique filename after {index_range} attempts."
|
237 |
+
)
|
238 |
|
239 |
for i in range(images.shape[0]):
|
240 |
video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
|
|
|
243 |
height, width = video_np.shape[1:3]
|
244 |
output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")
|
245 |
|
246 |
+
out = cv2.VideoWriter(
|
247 |
+
output_filename, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
|
248 |
+
)
|
249 |
|
250 |
for frame in video_np[..., ::-1]:
|
251 |
out.write(frame)
|
xora/examples/text_to_video.py
CHANGED
@@ -10,16 +10,18 @@ import safetensors.torch
|
|
10 |
import json
|
11 |
import argparse
|
12 |
|
|
|
13 |
def load_vae(vae_dir):
|
14 |
vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
15 |
vae_config_path = vae_dir / "config.json"
|
16 |
-
with open(vae_config_path,
|
17 |
vae_config = json.load(f)
|
18 |
vae = CausalVideoAutoencoder.from_config(vae_config)
|
19 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
20 |
vae.load_state_dict(vae_state_dict)
|
21 |
return vae.cuda().to(torch.bfloat16)
|
22 |
|
|
|
23 |
def load_unet(unet_dir):
|
24 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
25 |
unet_config_path = unet_dir / "config.json"
|
@@ -29,22 +31,31 @@ def load_unet(unet_dir):
|
|
29 |
transformer.load_state_dict(unet_state_dict, strict=True)
|
30 |
return transformer.cuda()
|
31 |
|
|
|
32 |
def load_scheduler(scheduler_dir):
|
33 |
scheduler_config_path = scheduler_dir / "scheduler_config.json"
|
34 |
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
|
35 |
return RectifiedFlowScheduler.from_config(scheduler_config)
|
36 |
|
|
|
37 |
def main():
|
38 |
# Parse command line arguments
|
39 |
-
parser = argparse.ArgumentParser(
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
args = parser.parse_args()
|
42 |
|
43 |
# Paths for the separate mode directories
|
44 |
separate_dir = Path(args.separate_dir)
|
45 |
-
unet_dir = separate_dir /
|
46 |
-
vae_dir = separate_dir /
|
47 |
-
scheduler_dir = separate_dir /
|
48 |
|
49 |
# Load models
|
50 |
vae = load_vae(vae_dir)
|
@@ -54,8 +65,12 @@ def main():
|
|
54 |
# Patchifier (remains the same)
|
55 |
patchifier = SymmetricPatchifier(patch_size=1)
|
56 |
|
57 |
-
text_encoder = T5EncoderModel.from_pretrained(
|
58 |
-
|
|
|
|
|
|
|
|
|
59 |
|
60 |
# Use submodels for the pipeline
|
61 |
submodel_dict = {
|
@@ -79,14 +94,14 @@ def main():
|
|
79 |
frame_rate = 25
|
80 |
sample = {
|
81 |
"prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
}
|
87 |
|
88 |
# Generate images (video frames)
|
89 |
-
|
90 |
num_inference_steps=num_inference_steps,
|
91 |
num_images_per_prompt=num_images_per_prompt,
|
92 |
guidance_scale=guidance_scale,
|
@@ -104,5 +119,6 @@ def main():
|
|
104 |
|
105 |
print("Generated images (video frames).")
|
106 |
|
|
|
107 |
if __name__ == "__main__":
|
108 |
main()
|
|
|
10 |
import json
|
11 |
import argparse
|
12 |
|
13 |
+
|
14 |
def load_vae(vae_dir):
|
15 |
vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
16 |
vae_config_path = vae_dir / "config.json"
|
17 |
+
with open(vae_config_path, "r") as f:
|
18 |
vae_config = json.load(f)
|
19 |
vae = CausalVideoAutoencoder.from_config(vae_config)
|
20 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
21 |
vae.load_state_dict(vae_state_dict)
|
22 |
return vae.cuda().to(torch.bfloat16)
|
23 |
|
24 |
+
|
25 |
def load_unet(unet_dir):
|
26 |
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
27 |
unet_config_path = unet_dir / "config.json"
|
|
|
31 |
transformer.load_state_dict(unet_state_dict, strict=True)
|
32 |
return transformer.cuda()
|
33 |
|
34 |
+
|
35 |
def load_scheduler(scheduler_dir):
|
36 |
scheduler_config_path = scheduler_dir / "scheduler_config.json"
|
37 |
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
|
38 |
return RectifiedFlowScheduler.from_config(scheduler_config)
|
39 |
|
40 |
+
|
41 |
def main():
|
42 |
# Parse command line arguments
|
43 |
+
parser = argparse.ArgumentParser(
|
44 |
+
description="Load models from separate directories"
|
45 |
+
)
|
46 |
+
parser.add_argument(
|
47 |
+
"--separate_dir",
|
48 |
+
type=str,
|
49 |
+
required=True,
|
50 |
+
help="Path to the directory containing unet, vae, and scheduler subdirectories",
|
51 |
+
)
|
52 |
args = parser.parse_args()
|
53 |
|
54 |
# Paths for the separate mode directories
|
55 |
separate_dir = Path(args.separate_dir)
|
56 |
+
unet_dir = separate_dir / "unet"
|
57 |
+
vae_dir = separate_dir / "vae"
|
58 |
+
scheduler_dir = separate_dir / "scheduler"
|
59 |
|
60 |
# Load models
|
61 |
vae = load_vae(vae_dir)
|
|
|
65 |
# Patchifier (remains the same)
|
66 |
patchifier = SymmetricPatchifier(patch_size=1)
|
67 |
|
68 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
69 |
+
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
|
70 |
+
).to("cuda")
|
71 |
+
tokenizer = T5Tokenizer.from_pretrained(
|
72 |
+
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
|
73 |
+
)
|
74 |
|
75 |
# Use submodels for the pipeline
|
76 |
submodel_dict = {
|
|
|
94 |
frame_rate = 25
|
95 |
sample = {
|
96 |
"prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
|
97 |
+
"The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
|
98 |
+
"prompt_attention_mask": None, # Adjust attention masks as needed
|
99 |
+
"negative_prompt": "Ugly deformed",
|
100 |
+
"negative_prompt_attention_mask": None,
|
101 |
}
|
102 |
|
103 |
# Generate images (video frames)
|
104 |
+
_ = pipeline(
|
105 |
num_inference_steps=num_inference_steps,
|
106 |
num_images_per_prompt=num_images_per_prompt,
|
107 |
guidance_scale=guidance_scale,
|
|
|
119 |
|
120 |
print("Generated images (video frames).")
|
121 |
|
122 |
+
|
123 |
if __name__ == "__main__":
|
124 |
main()
|
xora/models/autoencoders/causal_conv3d.py
CHANGED
@@ -40,11 +40,17 @@ class CausalConv3d(nn.Module):
|
|
40 |
|
41 |
def forward(self, x, causal: bool = True):
|
42 |
if causal:
|
43 |
-
first_frame_pad = x[:, :, :1, :, :].repeat(
|
|
|
|
|
44 |
x = torch.concatenate((first_frame_pad, x), dim=2)
|
45 |
else:
|
46 |
-
first_frame_pad = x[:, :, :1, :, :].repeat(
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
49 |
x = self.conv(x)
|
50 |
return x
|
|
|
40 |
|
41 |
def forward(self, x, causal: bool = True):
|
42 |
if causal:
|
43 |
+
first_frame_pad = x[:, :, :1, :, :].repeat(
|
44 |
+
(1, 1, self.time_kernel_size - 1, 1, 1)
|
45 |
+
)
|
46 |
x = torch.concatenate((first_frame_pad, x), dim=2)
|
47 |
else:
|
48 |
+
first_frame_pad = x[:, :, :1, :, :].repeat(
|
49 |
+
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
50 |
+
)
|
51 |
+
last_frame_pad = x[:, :, -1:, :, :].repeat(
|
52 |
+
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
|
53 |
+
)
|
54 |
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
|
55 |
x = self.conv(x)
|
56 |
return x
|
xora/models/autoencoders/causal_video_autoencoder.py
CHANGED
@@ -16,9 +16,15 @@ from xora.models.autoencoders.vae import AutoencoderKLWrapper
|
|
16 |
|
17 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
18 |
|
|
|
19 |
class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
20 |
@classmethod
|
21 |
-
def from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
22 |
config_local_path = pretrained_model_name_or_path / "config.json"
|
23 |
config = cls.load_config(config_local_path, **kwargs)
|
24 |
video_vae = cls.from_config(config)
|
@@ -28,29 +34,41 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
28 |
ckpt_state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
|
29 |
video_vae.load_state_dict(ckpt_state_dict)
|
30 |
|
31 |
-
statistics_local_path =
|
|
|
|
|
32 |
if statistics_local_path.exists():
|
33 |
with open(statistics_local_path, "r") as file:
|
34 |
data = json.load(file)
|
35 |
transposed_data = list(zip(*data["data"]))
|
36 |
-
data_dict = {
|
|
|
|
|
|
|
37 |
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
38 |
video_vae.register_buffer(
|
39 |
-
"mean_of_means",
|
|
|
|
|
|
|
40 |
)
|
41 |
|
42 |
return video_vae
|
43 |
|
44 |
@staticmethod
|
45 |
def from_config(config):
|
46 |
-
assert
|
|
|
|
|
47 |
if isinstance(config["dims"], list):
|
48 |
config["dims"] = tuple(config["dims"])
|
49 |
|
50 |
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
51 |
|
52 |
double_z = config.get("double_z", True)
|
53 |
-
latent_log_var = config.get(
|
|
|
|
|
54 |
use_quant_conv = config.get("use_quant_conv", True)
|
55 |
|
56 |
if use_quant_conv and latent_log_var == "uniform":
|
@@ -91,7 +109,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
91 |
_class_name="CausalVideoAutoencoder",
|
92 |
dims=self.dims,
|
93 |
in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
|
94 |
-
out_channels=self.decoder.conv_out.out_channels
|
|
|
95 |
latent_channels=self.decoder.conv_in.in_channels,
|
96 |
blocks=self.encoder.blocks_desc,
|
97 |
scaling_factor=1.0,
|
@@ -112,13 +131,26 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
112 |
@property
|
113 |
def spatial_downscale_factor(self):
|
114 |
return (
|
115 |
-
2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
* self.encoder.patch_size
|
117 |
)
|
118 |
|
119 |
@property
|
120 |
def temporal_downscale_factor(self):
|
121 |
-
return 2 ** len(
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
def to_json_string(self) -> str:
|
124 |
import json
|
@@ -146,7 +178,9 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
|
146 |
key = key.replace(k, v)
|
147 |
|
148 |
if "norm" in key and key not in model_keys:
|
149 |
-
logger.info(
|
|
|
|
|
150 |
continue
|
151 |
|
152 |
converted_state_dict[key] = value
|
@@ -293,7 +327,9 @@ class Encoder(nn.Module):
|
|
293 |
|
294 |
# out
|
295 |
if norm_layer == "group_norm":
|
296 |
-
self.conv_norm_out = nn.GroupNorm(
|
|
|
|
|
297 |
elif norm_layer == "pixel_norm":
|
298 |
self.conv_norm_out = PixelNorm()
|
299 |
elif norm_layer == "layer_norm":
|
@@ -308,7 +344,9 @@ class Encoder(nn.Module):
|
|
308 |
conv_out_channels += 1
|
309 |
elif latent_log_var != "none":
|
310 |
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
311 |
-
self.conv_out = make_conv_nd(
|
|
|
|
|
312 |
|
313 |
self.gradient_checkpointing = False
|
314 |
|
@@ -337,11 +375,15 @@ class Encoder(nn.Module):
|
|
337 |
|
338 |
if num_dims == 4:
|
339 |
# For shape (B, C, H, W)
|
340 |
-
repeated_last_channel = last_channel.repeat(
|
|
|
|
|
341 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
342 |
elif num_dims == 5:
|
343 |
# For shape (B, C, F, H, W)
|
344 |
-
repeated_last_channel = last_channel.repeat(
|
|
|
|
|
345 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
346 |
else:
|
347 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
@@ -430,25 +472,35 @@ class Decoder(nn.Module):
|
|
430 |
norm_layer=norm_layer,
|
431 |
)
|
432 |
elif block_name == "compress_time":
|
433 |
-
block = DepthToSpaceUpsample(
|
|
|
|
|
434 |
elif block_name == "compress_space":
|
435 |
-
block = DepthToSpaceUpsample(
|
|
|
|
|
436 |
elif block_name == "compress_all":
|
437 |
-
block = DepthToSpaceUpsample(
|
|
|
|
|
438 |
else:
|
439 |
raise ValueError(f"unknown layer: {block_name}")
|
440 |
|
441 |
self.up_blocks.append(block)
|
442 |
|
443 |
if norm_layer == "group_norm":
|
444 |
-
self.conv_norm_out = nn.GroupNorm(
|
|
|
|
|
445 |
elif norm_layer == "pixel_norm":
|
446 |
self.conv_norm_out = PixelNorm()
|
447 |
elif norm_layer == "layer_norm":
|
448 |
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
449 |
|
450 |
self.conv_act = nn.SiLU()
|
451 |
-
self.conv_out = make_conv_nd(
|
|
|
|
|
452 |
|
453 |
self.gradient_checkpointing = False
|
454 |
|
@@ -509,7 +561,9 @@ class UNetMidBlock3D(nn.Module):
|
|
509 |
norm_layer: str = "group_norm",
|
510 |
):
|
511 |
super().__init__()
|
512 |
-
resnet_groups =
|
|
|
|
|
513 |
|
514 |
self.res_blocks = nn.ModuleList(
|
515 |
[
|
@@ -526,7 +580,9 @@ class UNetMidBlock3D(nn.Module):
|
|
526 |
]
|
527 |
)
|
528 |
|
529 |
-
def forward(
|
|
|
|
|
530 |
for resnet in self.res_blocks:
|
531 |
hidden_states = resnet(hidden_states, causal=causal)
|
532 |
|
@@ -604,7 +660,9 @@ class ResnetBlock3D(nn.Module):
|
|
604 |
self.use_conv_shortcut = conv_shortcut
|
605 |
|
606 |
if norm_layer == "group_norm":
|
607 |
-
self.norm1 = nn.GroupNorm(
|
|
|
|
|
608 |
elif norm_layer == "pixel_norm":
|
609 |
self.norm1 = PixelNorm()
|
610 |
elif norm_layer == "layer_norm":
|
@@ -612,10 +670,20 @@ class ResnetBlock3D(nn.Module):
|
|
612 |
|
613 |
self.non_linearity = nn.SiLU()
|
614 |
|
615 |
-
self.conv1 = make_conv_nd(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
616 |
|
617 |
if norm_layer == "group_norm":
|
618 |
-
self.norm2 = nn.GroupNorm(
|
|
|
|
|
619 |
elif norm_layer == "pixel_norm":
|
620 |
self.norm2 = PixelNorm()
|
621 |
elif norm_layer == "layer_norm":
|
@@ -623,16 +691,28 @@ class ResnetBlock3D(nn.Module):
|
|
623 |
|
624 |
self.dropout = torch.nn.Dropout(dropout)
|
625 |
|
626 |
-
self.conv2 = make_conv_nd(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
627 |
|
628 |
self.conv_shortcut = (
|
629 |
-
make_linear_nd(
|
|
|
|
|
630 |
if in_channels != out_channels
|
631 |
else nn.Identity()
|
632 |
)
|
633 |
|
634 |
self.norm3 = (
|
635 |
-
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
|
|
|
|
636 |
)
|
637 |
|
638 |
def forward(
|
@@ -669,9 +749,17 @@ def patchify(x, patch_size_hw, patch_size_t=1):
|
|
669 |
if patch_size_hw == 1 and patch_size_t == 1:
|
670 |
return x
|
671 |
if x.dim() == 4:
|
672 |
-
x = rearrange(
|
|
|
|
|
673 |
elif x.dim() == 5:
|
674 |
-
x = rearrange(
|
|
|
|
|
|
|
|
|
|
|
|
|
675 |
else:
|
676 |
raise ValueError(f"Invalid input shape: {x.shape}")
|
677 |
|
@@ -683,9 +771,17 @@ def unpatchify(x, patch_size_hw, patch_size_t=1):
|
|
683 |
return x
|
684 |
|
685 |
if x.dim() == 4:
|
686 |
-
x = rearrange(
|
|
|
|
|
687 |
elif x.dim() == 5:
|
688 |
-
x = rearrange(
|
|
|
|
|
|
|
|
|
|
|
|
|
689 |
|
690 |
return x
|
691 |
|
@@ -755,14 +851,18 @@ def demo_video_autoencoder_forward_backward():
|
|
755 |
print(f"input shape={input_videos.shape}")
|
756 |
print(f"latent shape={latent.shape}")
|
757 |
|
758 |
-
reconstructed_videos = video_autoencoder.decode(
|
|
|
|
|
759 |
|
760 |
print(f"reconstructed shape={reconstructed_videos.shape}")
|
761 |
|
762 |
# Validate that single image gets treated the same way as first frame
|
763 |
input_image = input_videos[:, :, :1, :, :]
|
764 |
image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
|
765 |
-
reconstructed_image = video_autoencoder.decode(
|
|
|
|
|
766 |
|
767 |
first_frame_latent = latent[:, :, :1, :, :]
|
768 |
|
|
|
16 |
|
17 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
18 |
|
19 |
+
|
20 |
class CausalVideoAutoencoder(AutoencoderKLWrapper):
|
21 |
@classmethod
|
22 |
+
def from_pretrained(
|
23 |
+
cls,
|
24 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
25 |
+
*args,
|
26 |
+
**kwargs,
|
27 |
+
):
|
28 |
config_local_path = pretrained_model_name_or_path / "config.json"
|
29 |
config = cls.load_config(config_local_path, **kwargs)
|
30 |
video_vae = cls.from_config(config)
|
|
|
34 |
ckpt_state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
|
35 |
video_vae.load_state_dict(ckpt_state_dict)
|
36 |
|
37 |
+
statistics_local_path = (
|
38 |
+
pretrained_model_name_or_path / "per_channel_statistics.json"
|
39 |
+
)
|
40 |
if statistics_local_path.exists():
|
41 |
with open(statistics_local_path, "r") as file:
|
42 |
data = json.load(file)
|
43 |
transposed_data = list(zip(*data["data"]))
|
44 |
+
data_dict = {
|
45 |
+
col: torch.tensor(vals)
|
46 |
+
for col, vals in zip(data["columns"], transposed_data)
|
47 |
+
}
|
48 |
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
49 |
video_vae.register_buffer(
|
50 |
+
"mean_of_means",
|
51 |
+
data_dict.get(
|
52 |
+
"mean-of-means", torch.zeros_like(data_dict["std-of-means"])
|
53 |
+
),
|
54 |
)
|
55 |
|
56 |
return video_vae
|
57 |
|
58 |
@staticmethod
|
59 |
def from_config(config):
|
60 |
+
assert (
|
61 |
+
config["_class_name"] == "CausalVideoAutoencoder"
|
62 |
+
), "config must have _class_name=CausalVideoAutoencoder"
|
63 |
if isinstance(config["dims"], list):
|
64 |
config["dims"] = tuple(config["dims"])
|
65 |
|
66 |
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
67 |
|
68 |
double_z = config.get("double_z", True)
|
69 |
+
latent_log_var = config.get(
|
70 |
+
"latent_log_var", "per_channel" if double_z else "none"
|
71 |
+
)
|
72 |
use_quant_conv = config.get("use_quant_conv", True)
|
73 |
|
74 |
if use_quant_conv and latent_log_var == "uniform":
|
|
|
109 |
_class_name="CausalVideoAutoencoder",
|
110 |
dims=self.dims,
|
111 |
in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
|
112 |
+
out_channels=self.decoder.conv_out.out_channels
|
113 |
+
// self.decoder.patch_size**2,
|
114 |
latent_channels=self.decoder.conv_in.in_channels,
|
115 |
blocks=self.encoder.blocks_desc,
|
116 |
scaling_factor=1.0,
|
|
|
131 |
@property
|
132 |
def spatial_downscale_factor(self):
|
133 |
return (
|
134 |
+
2
|
135 |
+
** len(
|
136 |
+
[
|
137 |
+
block
|
138 |
+
for block in self.encoder.blocks_desc
|
139 |
+
if block[0] in ["compress_space", "compress_all"]
|
140 |
+
]
|
141 |
+
)
|
142 |
* self.encoder.patch_size
|
143 |
)
|
144 |
|
145 |
@property
|
146 |
def temporal_downscale_factor(self):
|
147 |
+
return 2 ** len(
|
148 |
+
[
|
149 |
+
block
|
150 |
+
for block in self.encoder.blocks_desc
|
151 |
+
if block[0] in ["compress_time", "compress_all"]
|
152 |
+
]
|
153 |
+
)
|
154 |
|
155 |
def to_json_string(self) -> str:
|
156 |
import json
|
|
|
178 |
key = key.replace(k, v)
|
179 |
|
180 |
if "norm" in key and key not in model_keys:
|
181 |
+
logger.info(
|
182 |
+
f"Removing key {key} from state_dict as it is not present in the model"
|
183 |
+
)
|
184 |
continue
|
185 |
|
186 |
converted_state_dict[key] = value
|
|
|
327 |
|
328 |
# out
|
329 |
if norm_layer == "group_norm":
|
330 |
+
self.conv_norm_out = nn.GroupNorm(
|
331 |
+
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
332 |
+
)
|
333 |
elif norm_layer == "pixel_norm":
|
334 |
self.conv_norm_out = PixelNorm()
|
335 |
elif norm_layer == "layer_norm":
|
|
|
344 |
conv_out_channels += 1
|
345 |
elif latent_log_var != "none":
|
346 |
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
347 |
+
self.conv_out = make_conv_nd(
|
348 |
+
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
|
349 |
+
)
|
350 |
|
351 |
self.gradient_checkpointing = False
|
352 |
|
|
|
375 |
|
376 |
if num_dims == 4:
|
377 |
# For shape (B, C, H, W)
|
378 |
+
repeated_last_channel = last_channel.repeat(
|
379 |
+
1, sample.shape[1] - 2, 1, 1
|
380 |
+
)
|
381 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
382 |
elif num_dims == 5:
|
383 |
# For shape (B, C, F, H, W)
|
384 |
+
repeated_last_channel = last_channel.repeat(
|
385 |
+
1, sample.shape[1] - 2, 1, 1, 1
|
386 |
+
)
|
387 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
388 |
else:
|
389 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
|
|
472 |
norm_layer=norm_layer,
|
473 |
)
|
474 |
elif block_name == "compress_time":
|
475 |
+
block = DepthToSpaceUpsample(
|
476 |
+
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
|
477 |
+
)
|
478 |
elif block_name == "compress_space":
|
479 |
+
block = DepthToSpaceUpsample(
|
480 |
+
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
481 |
+
)
|
482 |
elif block_name == "compress_all":
|
483 |
+
block = DepthToSpaceUpsample(
|
484 |
+
dims=dims, in_channels=input_channel, stride=(2, 2, 2)
|
485 |
+
)
|
486 |
else:
|
487 |
raise ValueError(f"unknown layer: {block_name}")
|
488 |
|
489 |
self.up_blocks.append(block)
|
490 |
|
491 |
if norm_layer == "group_norm":
|
492 |
+
self.conv_norm_out = nn.GroupNorm(
|
493 |
+
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
494 |
+
)
|
495 |
elif norm_layer == "pixel_norm":
|
496 |
self.conv_norm_out = PixelNorm()
|
497 |
elif norm_layer == "layer_norm":
|
498 |
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
499 |
|
500 |
self.conv_act = nn.SiLU()
|
501 |
+
self.conv_out = make_conv_nd(
|
502 |
+
dims, output_channel, out_channels, 3, padding=1, causal=True
|
503 |
+
)
|
504 |
|
505 |
self.gradient_checkpointing = False
|
506 |
|
|
|
561 |
norm_layer: str = "group_norm",
|
562 |
):
|
563 |
super().__init__()
|
564 |
+
resnet_groups = (
|
565 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
566 |
+
)
|
567 |
|
568 |
self.res_blocks = nn.ModuleList(
|
569 |
[
|
|
|
580 |
]
|
581 |
)
|
582 |
|
583 |
+
def forward(
|
584 |
+
self, hidden_states: torch.FloatTensor, causal: bool = True
|
585 |
+
) -> torch.FloatTensor:
|
586 |
for resnet in self.res_blocks:
|
587 |
hidden_states = resnet(hidden_states, causal=causal)
|
588 |
|
|
|
660 |
self.use_conv_shortcut = conv_shortcut
|
661 |
|
662 |
if norm_layer == "group_norm":
|
663 |
+
self.norm1 = nn.GroupNorm(
|
664 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
665 |
+
)
|
666 |
elif norm_layer == "pixel_norm":
|
667 |
self.norm1 = PixelNorm()
|
668 |
elif norm_layer == "layer_norm":
|
|
|
670 |
|
671 |
self.non_linearity = nn.SiLU()
|
672 |
|
673 |
+
self.conv1 = make_conv_nd(
|
674 |
+
dims,
|
675 |
+
in_channels,
|
676 |
+
out_channels,
|
677 |
+
kernel_size=3,
|
678 |
+
stride=1,
|
679 |
+
padding=1,
|
680 |
+
causal=True,
|
681 |
+
)
|
682 |
|
683 |
if norm_layer == "group_norm":
|
684 |
+
self.norm2 = nn.GroupNorm(
|
685 |
+
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
686 |
+
)
|
687 |
elif norm_layer == "pixel_norm":
|
688 |
self.norm2 = PixelNorm()
|
689 |
elif norm_layer == "layer_norm":
|
|
|
691 |
|
692 |
self.dropout = torch.nn.Dropout(dropout)
|
693 |
|
694 |
+
self.conv2 = make_conv_nd(
|
695 |
+
dims,
|
696 |
+
out_channels,
|
697 |
+
out_channels,
|
698 |
+
kernel_size=3,
|
699 |
+
stride=1,
|
700 |
+
padding=1,
|
701 |
+
causal=True,
|
702 |
+
)
|
703 |
|
704 |
self.conv_shortcut = (
|
705 |
+
make_linear_nd(
|
706 |
+
dims=dims, in_channels=in_channels, out_channels=out_channels
|
707 |
+
)
|
708 |
if in_channels != out_channels
|
709 |
else nn.Identity()
|
710 |
)
|
711 |
|
712 |
self.norm3 = (
|
713 |
+
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
714 |
+
if in_channels != out_channels
|
715 |
+
else nn.Identity()
|
716 |
)
|
717 |
|
718 |
def forward(
|
|
|
749 |
if patch_size_hw == 1 and patch_size_t == 1:
|
750 |
return x
|
751 |
if x.dim() == 4:
|
752 |
+
x = rearrange(
|
753 |
+
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
|
754 |
+
)
|
755 |
elif x.dim() == 5:
|
756 |
+
x = rearrange(
|
757 |
+
x,
|
758 |
+
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
|
759 |
+
p=patch_size_t,
|
760 |
+
q=patch_size_hw,
|
761 |
+
r=patch_size_hw,
|
762 |
+
)
|
763 |
else:
|
764 |
raise ValueError(f"Invalid input shape: {x.shape}")
|
765 |
|
|
|
771 |
return x
|
772 |
|
773 |
if x.dim() == 4:
|
774 |
+
x = rearrange(
|
775 |
+
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
|
776 |
+
)
|
777 |
elif x.dim() == 5:
|
778 |
+
x = rearrange(
|
779 |
+
x,
|
780 |
+
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
|
781 |
+
p=patch_size_t,
|
782 |
+
q=patch_size_hw,
|
783 |
+
r=patch_size_hw,
|
784 |
+
)
|
785 |
|
786 |
return x
|
787 |
|
|
|
851 |
print(f"input shape={input_videos.shape}")
|
852 |
print(f"latent shape={latent.shape}")
|
853 |
|
854 |
+
reconstructed_videos = video_autoencoder.decode(
|
855 |
+
latent, target_shape=input_videos.shape
|
856 |
+
).sample
|
857 |
|
858 |
print(f"reconstructed shape={reconstructed_videos.shape}")
|
859 |
|
860 |
# Validate that single image gets treated the same way as first frame
|
861 |
input_image = input_videos[:, :, :1, :, :]
|
862 |
image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
|
863 |
+
reconstructed_image = video_autoencoder.decode(
|
864 |
+
image_latent, target_shape=image_latent.shape
|
865 |
+
).sample
|
866 |
|
867 |
first_frame_latent = latent[:, :, :1, :, :]
|
868 |
|
xora/models/autoencoders/conv_nd_factory.py
CHANGED
@@ -71,8 +71,12 @@ def make_linear_nd(
|
|
71 |
bias=True,
|
72 |
):
|
73 |
if dims == 2:
|
74 |
-
return torch.nn.Conv2d(
|
|
|
|
|
75 |
elif dims == 3 or dims == (2, 1):
|
76 |
-
return torch.nn.Conv3d(
|
|
|
|
|
77 |
else:
|
78 |
raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
71 |
bias=True,
|
72 |
):
|
73 |
if dims == 2:
|
74 |
+
return torch.nn.Conv2d(
|
75 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
76 |
+
)
|
77 |
elif dims == 3 or dims == (2, 1):
|
78 |
+
return torch.nn.Conv3d(
|
79 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
80 |
+
)
|
81 |
else:
|
82 |
raise ValueError(f"unsupported dimensions: {dims}")
|
xora/models/autoencoders/dual_conv3d.py
CHANGED
@@ -27,7 +27,9 @@ class DualConv3d(nn.Module):
|
|
27 |
if isinstance(kernel_size, int):
|
28 |
kernel_size = (kernel_size, kernel_size, kernel_size)
|
29 |
if kernel_size == (1, 1, 1):
|
30 |
-
raise ValueError(
|
|
|
|
|
31 |
if isinstance(stride, int):
|
32 |
stride = (stride, stride, stride)
|
33 |
if isinstance(padding, int):
|
@@ -40,11 +42,19 @@ class DualConv3d(nn.Module):
|
|
40 |
self.bias = bias
|
41 |
|
42 |
# Define the size of the channels after the first convolution
|
43 |
-
intermediate_channels =
|
|
|
|
|
44 |
|
45 |
# Define parameters for the first convolution
|
46 |
self.weight1 = nn.Parameter(
|
47 |
-
torch.Tensor(
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
)
|
49 |
self.stride1 = (1, stride[1], stride[2])
|
50 |
self.padding1 = (0, padding[1], padding[2])
|
@@ -55,7 +65,11 @@ class DualConv3d(nn.Module):
|
|
55 |
self.register_parameter("bias1", None)
|
56 |
|
57 |
# Define parameters for the second convolution
|
58 |
-
self.weight2 = nn.Parameter(
|
|
|
|
|
|
|
|
|
59 |
self.stride2 = (stride[0], 1, 1)
|
60 |
self.padding2 = (padding[0], 0, 0)
|
61 |
self.dilation2 = (dilation[0], 1, 1)
|
@@ -86,13 +100,29 @@ class DualConv3d(nn.Module):
|
|
86 |
|
87 |
def forward_with_3d(self, x, skip_time_conv):
|
88 |
# First convolution
|
89 |
-
x = F.conv3d(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
if skip_time_conv:
|
92 |
return x
|
93 |
|
94 |
# Second convolution
|
95 |
-
x = F.conv3d(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
return x
|
98 |
|
|
|
27 |
if isinstance(kernel_size, int):
|
28 |
kernel_size = (kernel_size, kernel_size, kernel_size)
|
29 |
if kernel_size == (1, 1, 1):
|
30 |
+
raise ValueError(
|
31 |
+
"kernel_size must be greater than 1. Use make_linear_nd instead."
|
32 |
+
)
|
33 |
if isinstance(stride, int):
|
34 |
stride = (stride, stride, stride)
|
35 |
if isinstance(padding, int):
|
|
|
42 |
self.bias = bias
|
43 |
|
44 |
# Define the size of the channels after the first convolution
|
45 |
+
intermediate_channels = (
|
46 |
+
out_channels if in_channels < out_channels else in_channels
|
47 |
+
)
|
48 |
|
49 |
# Define parameters for the first convolution
|
50 |
self.weight1 = nn.Parameter(
|
51 |
+
torch.Tensor(
|
52 |
+
intermediate_channels,
|
53 |
+
in_channels // groups,
|
54 |
+
1,
|
55 |
+
kernel_size[1],
|
56 |
+
kernel_size[2],
|
57 |
+
)
|
58 |
)
|
59 |
self.stride1 = (1, stride[1], stride[2])
|
60 |
self.padding1 = (0, padding[1], padding[2])
|
|
|
65 |
self.register_parameter("bias1", None)
|
66 |
|
67 |
# Define parameters for the second convolution
|
68 |
+
self.weight2 = nn.Parameter(
|
69 |
+
torch.Tensor(
|
70 |
+
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
|
71 |
+
)
|
72 |
+
)
|
73 |
self.stride2 = (stride[0], 1, 1)
|
74 |
self.padding2 = (padding[0], 0, 0)
|
75 |
self.dilation2 = (dilation[0], 1, 1)
|
|
|
100 |
|
101 |
def forward_with_3d(self, x, skip_time_conv):
|
102 |
# First convolution
|
103 |
+
x = F.conv3d(
|
104 |
+
x,
|
105 |
+
self.weight1,
|
106 |
+
self.bias1,
|
107 |
+
self.stride1,
|
108 |
+
self.padding1,
|
109 |
+
self.dilation1,
|
110 |
+
self.groups,
|
111 |
+
)
|
112 |
|
113 |
if skip_time_conv:
|
114 |
return x
|
115 |
|
116 |
# Second convolution
|
117 |
+
x = F.conv3d(
|
118 |
+
x,
|
119 |
+
self.weight2,
|
120 |
+
self.bias2,
|
121 |
+
self.stride2,
|
122 |
+
self.padding2,
|
123 |
+
self.dilation2,
|
124 |
+
self.groups,
|
125 |
+
)
|
126 |
|
127 |
return x
|
128 |
|
xora/models/autoencoders/vae.py
CHANGED
@@ -4,7 +4,10 @@ import torch
|
|
4 |
import math
|
5 |
import torch.nn as nn
|
6 |
from diffusers import ConfigMixin, ModelMixin
|
7 |
-
from diffusers.models.autoencoders.vae import
|
|
|
|
|
|
|
8 |
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
9 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd
|
10 |
|
@@ -43,8 +46,12 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
43 |
quant_dims = 2 if dims == 2 else 3
|
44 |
self.decoder = decoder
|
45 |
if use_quant_conv:
|
46 |
-
self.quant_conv = make_conv_nd(
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
else:
|
49 |
self.quant_conv = nn.Identity()
|
50 |
self.post_quant_conv = nn.Identity()
|
@@ -104,7 +111,13 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
104 |
for i in range(0, x.shape[3], overlap_size):
|
105 |
row = []
|
106 |
for j in range(0, x.shape[4], overlap_size):
|
107 |
-
tile = x[
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
tile = self.encoder(tile)
|
109 |
tile = self.quant_conv(tile)
|
110 |
row.append(tile)
|
@@ -125,42 +138,58 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
125 |
moments = torch.cat(result_rows, dim=3)
|
126 |
return moments
|
127 |
|
128 |
-
def blend_z(
|
|
|
|
|
129 |
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
130 |
for z in range(blend_extent):
|
131 |
-
b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
|
132 |
-
z / blend_extent
|
133 |
-
)
|
134 |
return b
|
135 |
|
136 |
-
def blend_v(
|
|
|
|
|
137 |
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
138 |
for y in range(blend_extent):
|
139 |
-
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
|
140 |
-
y / blend_extent
|
141 |
-
)
|
142 |
return b
|
143 |
|
144 |
-
def blend_h(
|
|
|
|
|
145 |
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
146 |
for x in range(blend_extent):
|
147 |
-
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
|
148 |
-
x / blend_extent
|
149 |
-
)
|
150 |
return b
|
151 |
|
152 |
def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
|
153 |
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
154 |
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
155 |
row_limit = self.tile_sample_min_size - blend_extent
|
156 |
-
tile_target_shape = (
|
|
|
|
|
|
|
|
|
157 |
# Split z into overlapping 64x64 tiles and decode them separately.
|
158 |
# The tiles have an overlap to avoid seams between tiles.
|
159 |
rows = []
|
160 |
for i in range(0, z.shape[3], overlap_size):
|
161 |
row = []
|
162 |
for j in range(0, z.shape[4], overlap_size):
|
163 |
-
tile = z[
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
tile = self.post_quant_conv(tile)
|
165 |
decoded = self.decoder(tile, target_shape=tile_target_shape)
|
166 |
row.append(decoded)
|
@@ -181,20 +210,34 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
181 |
dec = torch.cat(result_rows, dim=3)
|
182 |
return dec
|
183 |
|
184 |
-
def encode(
|
|
|
|
|
185 |
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
186 |
num_splits = z.shape[2] // self.z_sample_size
|
187 |
sizes = [self.z_sample_size] * num_splits
|
188 |
-
sizes =
|
|
|
|
|
|
|
|
|
189 |
tiles = z.split(sizes, dim=2)
|
190 |
moments_tiles = [
|
191 |
-
|
|
|
|
|
|
|
|
|
192 |
for z_tile in tiles
|
193 |
]
|
194 |
moments = torch.cat(moments_tiles, dim=2)
|
195 |
|
196 |
else:
|
197 |
-
moments =
|
|
|
|
|
|
|
|
|
198 |
|
199 |
posterior = DiagonalGaussianDistribution(moments)
|
200 |
if not return_dict:
|
@@ -207,7 +250,9 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
207 |
moments = self.quant_conv(h)
|
208 |
return moments
|
209 |
|
210 |
-
def _decode(
|
|
|
|
|
211 |
z = self.post_quant_conv(z)
|
212 |
dec = self.decoder(z, target_shape=target_shape)
|
213 |
return dec
|
@@ -219,7 +264,12 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
219 |
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
220 |
reduction_factor = int(
|
221 |
self.encoder.patch_size_t
|
222 |
-
* 2
|
|
|
|
|
|
|
|
|
|
|
223 |
)
|
224 |
split_size = self.z_sample_size // reduction_factor
|
225 |
num_splits = z.shape[2] // split_size
|
|
|
4 |
import math
|
5 |
import torch.nn as nn
|
6 |
from diffusers import ConfigMixin, ModelMixin
|
7 |
+
from diffusers.models.autoencoders.vae import (
|
8 |
+
DecoderOutput,
|
9 |
+
DiagonalGaussianDistribution,
|
10 |
+
)
|
11 |
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
12 |
from xora.models.autoencoders.conv_nd_factory import make_conv_nd
|
13 |
|
|
|
46 |
quant_dims = 2 if dims == 2 else 3
|
47 |
self.decoder = decoder
|
48 |
if use_quant_conv:
|
49 |
+
self.quant_conv = make_conv_nd(
|
50 |
+
quant_dims, 2 * latent_channels, 2 * latent_channels, 1
|
51 |
+
)
|
52 |
+
self.post_quant_conv = make_conv_nd(
|
53 |
+
quant_dims, latent_channels, latent_channels, 1
|
54 |
+
)
|
55 |
else:
|
56 |
self.quant_conv = nn.Identity()
|
57 |
self.post_quant_conv = nn.Identity()
|
|
|
111 |
for i in range(0, x.shape[3], overlap_size):
|
112 |
row = []
|
113 |
for j in range(0, x.shape[4], overlap_size):
|
114 |
+
tile = x[
|
115 |
+
:,
|
116 |
+
:,
|
117 |
+
:,
|
118 |
+
i : i + self.tile_sample_min_size,
|
119 |
+
j : j + self.tile_sample_min_size,
|
120 |
+
]
|
121 |
tile = self.encoder(tile)
|
122 |
tile = self.quant_conv(tile)
|
123 |
row.append(tile)
|
|
|
138 |
moments = torch.cat(result_rows, dim=3)
|
139 |
return moments
|
140 |
|
141 |
+
def blend_z(
|
142 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
143 |
+
) -> torch.Tensor:
|
144 |
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
145 |
for z in range(blend_extent):
|
146 |
+
b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
|
147 |
+
1 - z / blend_extent
|
148 |
+
) + b[:, :, z, :, :] * (z / blend_extent)
|
149 |
return b
|
150 |
|
151 |
+
def blend_v(
|
152 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
153 |
+
) -> torch.Tensor:
|
154 |
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
155 |
for y in range(blend_extent):
|
156 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
|
157 |
+
1 - y / blend_extent
|
158 |
+
) + b[:, :, :, y, :] * (y / blend_extent)
|
159 |
return b
|
160 |
|
161 |
+
def blend_h(
|
162 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
163 |
+
) -> torch.Tensor:
|
164 |
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
165 |
for x in range(blend_extent):
|
166 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
|
167 |
+
1 - x / blend_extent
|
168 |
+
) + b[:, :, :, :, x] * (x / blend_extent)
|
169 |
return b
|
170 |
|
171 |
def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
|
172 |
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
173 |
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
174 |
row_limit = self.tile_sample_min_size - blend_extent
|
175 |
+
tile_target_shape = (
|
176 |
+
*target_shape[:3],
|
177 |
+
self.tile_sample_min_size,
|
178 |
+
self.tile_sample_min_size,
|
179 |
+
)
|
180 |
# Split z into overlapping 64x64 tiles and decode them separately.
|
181 |
# The tiles have an overlap to avoid seams between tiles.
|
182 |
rows = []
|
183 |
for i in range(0, z.shape[3], overlap_size):
|
184 |
row = []
|
185 |
for j in range(0, z.shape[4], overlap_size):
|
186 |
+
tile = z[
|
187 |
+
:,
|
188 |
+
:,
|
189 |
+
:,
|
190 |
+
i : i + self.tile_latent_min_size,
|
191 |
+
j : j + self.tile_latent_min_size,
|
192 |
+
]
|
193 |
tile = self.post_quant_conv(tile)
|
194 |
decoded = self.decoder(tile, target_shape=tile_target_shape)
|
195 |
row.append(decoded)
|
|
|
210 |
dec = torch.cat(result_rows, dim=3)
|
211 |
return dec
|
212 |
|
213 |
+
def encode(
|
214 |
+
self, z: torch.FloatTensor, return_dict: bool = True
|
215 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
216 |
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
217 |
num_splits = z.shape[2] // self.z_sample_size
|
218 |
sizes = [self.z_sample_size] * num_splits
|
219 |
+
sizes = (
|
220 |
+
sizes + [z.shape[2] - sum(sizes)]
|
221 |
+
if z.shape[2] - sum(sizes) > 0
|
222 |
+
else sizes
|
223 |
+
)
|
224 |
tiles = z.split(sizes, dim=2)
|
225 |
moments_tiles = [
|
226 |
+
(
|
227 |
+
self._hw_tiled_encode(z_tile, return_dict)
|
228 |
+
if self.use_hw_tiling
|
229 |
+
else self._encode(z_tile)
|
230 |
+
)
|
231 |
for z_tile in tiles
|
232 |
]
|
233 |
moments = torch.cat(moments_tiles, dim=2)
|
234 |
|
235 |
else:
|
236 |
+
moments = (
|
237 |
+
self._hw_tiled_encode(z, return_dict)
|
238 |
+
if self.use_hw_tiling
|
239 |
+
else self._encode(z)
|
240 |
+
)
|
241 |
|
242 |
posterior = DiagonalGaussianDistribution(moments)
|
243 |
if not return_dict:
|
|
|
250 |
moments = self.quant_conv(h)
|
251 |
return moments
|
252 |
|
253 |
+
def _decode(
|
254 |
+
self, z: torch.FloatTensor, target_shape=None
|
255 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
256 |
z = self.post_quant_conv(z)
|
257 |
dec = self.decoder(z, target_shape=target_shape)
|
258 |
return dec
|
|
|
264 |
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
265 |
reduction_factor = int(
|
266 |
self.encoder.patch_size_t
|
267 |
+
* 2
|
268 |
+
** (
|
269 |
+
len(self.encoder.down_blocks)
|
270 |
+
- 1
|
271 |
+
- math.sqrt(self.encoder.patch_size)
|
272 |
+
)
|
273 |
)
|
274 |
split_size = self.z_sample_size // reduction_factor
|
275 |
num_splits = z.shape[2] // split_size
|
xora/models/autoencoders/vae_encode.py
CHANGED
@@ -6,12 +6,19 @@ from torch import Tensor
|
|
6 |
|
7 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
8 |
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
|
|
|
9 |
try:
|
10 |
import torch_xla.core.xla_model as xm
|
11 |
-
except:
|
12 |
-
|
|
|
13 |
|
14 |
-
def vae_encode(
|
|
|
|
|
|
|
|
|
|
|
15 |
"""
|
16 |
Encodes media items (images or videos) into latent representations using a specified VAE model.
|
17 |
The function supports processing batches of images or video frames and can handle the processing
|
@@ -48,11 +55,15 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
|
|
48 |
if channels != 3:
|
49 |
raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
|
50 |
|
51 |
-
if is_video_shaped and not isinstance(
|
|
|
|
|
52 |
media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
|
53 |
if split_size > 1:
|
54 |
if len(media_items) % split_size != 0:
|
55 |
-
raise ValueError(
|
|
|
|
|
56 |
encode_bs = len(media_items) // split_size
|
57 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
58 |
latents = []
|
@@ -67,22 +78,32 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
|
|
67 |
latents = vae.encode(media_items).latent_dist.sample()
|
68 |
|
69 |
latents = normalize_latents(latents, vae, vae_per_channel_normalize)
|
70 |
-
if is_video_shaped and not isinstance(
|
|
|
|
|
71 |
latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
|
72 |
return latents
|
73 |
|
74 |
|
75 |
def vae_decode(
|
76 |
-
latents: Tensor,
|
|
|
|
|
|
|
|
|
77 |
) -> Tensor:
|
78 |
is_video_shaped = latents.dim() == 5
|
79 |
batch_size = latents.shape[0]
|
80 |
|
81 |
-
if is_video_shaped and not isinstance(
|
|
|
|
|
82 |
latents = rearrange(latents, "b c n h w -> (b n) c h w")
|
83 |
if split_size > 1:
|
84 |
if len(latents) % split_size != 0:
|
85 |
-
raise ValueError(
|
|
|
|
|
86 |
encode_bs = len(latents) // split_size
|
87 |
image_batch = [
|
88 |
_run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
|
@@ -92,12 +113,16 @@ def vae_decode(
|
|
92 |
else:
|
93 |
images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
|
94 |
|
95 |
-
if is_video_shaped and not isinstance(
|
|
|
|
|
96 |
images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
|
97 |
return images
|
98 |
|
99 |
|
100 |
-
def _run_decoder(
|
|
|
|
|
101 |
if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
102 |
*_, fl, hl, wl = latents.shape
|
103 |
temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
|
@@ -105,7 +130,13 @@ def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_ch
|
|
105 |
image = vae.decode(
|
106 |
un_normalize_latents(latents, vae, vae_per_channel_normalize),
|
107 |
return_dict=False,
|
108 |
-
target_shape=(
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
)[0]
|
110 |
else:
|
111 |
image = vae.decode(
|
@@ -120,14 +151,26 @@ def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
|
|
120 |
spatial = vae.spatial_downscale_factor
|
121 |
temporal = vae.temporal_downscale_factor
|
122 |
else:
|
123 |
-
down_blocks = len(
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
spatial = vae.config.patch_size * 2**down_blocks
|
125 |
-
temporal =
|
|
|
|
|
|
|
|
|
126 |
|
127 |
return (temporal, spatial, spatial)
|
128 |
|
129 |
|
130 |
-
def normalize_latents(
|
|
|
|
|
131 |
return (
|
132 |
(latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
|
133 |
/ vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
@@ -136,10 +179,12 @@ def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_norma
|
|
136 |
)
|
137 |
|
138 |
|
139 |
-
def un_normalize_latents(
|
|
|
|
|
140 |
return (
|
141 |
latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
142 |
+ vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
143 |
if vae_per_channel_normalize
|
144 |
else latents / vae.config.scaling_factor
|
145 |
-
)
|
|
|
6 |
|
7 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
8 |
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
|
9 |
+
|
10 |
try:
|
11 |
import torch_xla.core.xla_model as xm
|
12 |
+
except ImportError:
|
13 |
+
xm = None
|
14 |
+
|
15 |
|
16 |
+
def vae_encode(
|
17 |
+
media_items: Tensor,
|
18 |
+
vae: AutoencoderKL,
|
19 |
+
split_size: int = 1,
|
20 |
+
vae_per_channel_normalize=False,
|
21 |
+
) -> Tensor:
|
22 |
"""
|
23 |
Encodes media items (images or videos) into latent representations using a specified VAE model.
|
24 |
The function supports processing batches of images or video frames and can handle the processing
|
|
|
55 |
if channels != 3:
|
56 |
raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
|
57 |
|
58 |
+
if is_video_shaped and not isinstance(
|
59 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
60 |
+
):
|
61 |
media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
|
62 |
if split_size > 1:
|
63 |
if len(media_items) % split_size != 0:
|
64 |
+
raise ValueError(
|
65 |
+
"Error: The batch size must be divisible by 'train.vae_bs_split"
|
66 |
+
)
|
67 |
encode_bs = len(media_items) // split_size
|
68 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
69 |
latents = []
|
|
|
78 |
latents = vae.encode(media_items).latent_dist.sample()
|
79 |
|
80 |
latents = normalize_latents(latents, vae, vae_per_channel_normalize)
|
81 |
+
if is_video_shaped and not isinstance(
|
82 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
83 |
+
):
|
84 |
latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
|
85 |
return latents
|
86 |
|
87 |
|
88 |
def vae_decode(
|
89 |
+
latents: Tensor,
|
90 |
+
vae: AutoencoderKL,
|
91 |
+
is_video: bool = True,
|
92 |
+
split_size: int = 1,
|
93 |
+
vae_per_channel_normalize=False,
|
94 |
) -> Tensor:
|
95 |
is_video_shaped = latents.dim() == 5
|
96 |
batch_size = latents.shape[0]
|
97 |
|
98 |
+
if is_video_shaped and not isinstance(
|
99 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
100 |
+
):
|
101 |
latents = rearrange(latents, "b c n h w -> (b n) c h w")
|
102 |
if split_size > 1:
|
103 |
if len(latents) % split_size != 0:
|
104 |
+
raise ValueError(
|
105 |
+
"Error: The batch size must be divisible by 'train.vae_bs_split"
|
106 |
+
)
|
107 |
encode_bs = len(latents) // split_size
|
108 |
image_batch = [
|
109 |
_run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
|
|
|
113 |
else:
|
114 |
images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
|
115 |
|
116 |
+
if is_video_shaped and not isinstance(
|
117 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
118 |
+
):
|
119 |
images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
|
120 |
return images
|
121 |
|
122 |
|
123 |
+
def _run_decoder(
|
124 |
+
latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False
|
125 |
+
) -> Tensor:
|
126 |
if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
127 |
*_, fl, hl, wl = latents.shape
|
128 |
temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
|
|
|
130 |
image = vae.decode(
|
131 |
un_normalize_latents(latents, vae, vae_per_channel_normalize),
|
132 |
return_dict=False,
|
133 |
+
target_shape=(
|
134 |
+
1,
|
135 |
+
3,
|
136 |
+
fl * temporal_scale if is_video else 1,
|
137 |
+
hl * spatial_scale,
|
138 |
+
wl * spatial_scale,
|
139 |
+
),
|
140 |
)[0]
|
141 |
else:
|
142 |
image = vae.decode(
|
|
|
151 |
spatial = vae.spatial_downscale_factor
|
152 |
temporal = vae.temporal_downscale_factor
|
153 |
else:
|
154 |
+
down_blocks = len(
|
155 |
+
[
|
156 |
+
block
|
157 |
+
for block in vae.encoder.down_blocks
|
158 |
+
if isinstance(block.downsample, Downsample3D)
|
159 |
+
]
|
160 |
+
)
|
161 |
spatial = vae.config.patch_size * 2**down_blocks
|
162 |
+
temporal = (
|
163 |
+
vae.config.patch_size_t * 2**down_blocks
|
164 |
+
if isinstance(vae, VideoAutoencoder)
|
165 |
+
else 1
|
166 |
+
)
|
167 |
|
168 |
return (temporal, spatial, spatial)
|
169 |
|
170 |
|
171 |
+
def normalize_latents(
|
172 |
+
latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
|
173 |
+
) -> Tensor:
|
174 |
return (
|
175 |
(latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
|
176 |
/ vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
|
|
179 |
)
|
180 |
|
181 |
|
182 |
+
def un_normalize_latents(
|
183 |
+
latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
|
184 |
+
) -> Tensor:
|
185 |
return (
|
186 |
latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
187 |
+ vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
188 |
if vae_per_channel_normalize
|
189 |
else latents / vae.config.scaling_factor
|
190 |
+
)
|
xora/models/autoencoders/video_autoencoder.py
CHANGED
@@ -21,7 +21,12 @@ logger = logging.get_logger(__name__)
|
|
21 |
|
22 |
class VideoAutoencoder(AutoencoderKLWrapper):
|
23 |
@classmethod
|
24 |
-
def from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
25 |
config_local_path = pretrained_model_name_or_path / "config.json"
|
26 |
config = cls.load_config(config_local_path, **kwargs)
|
27 |
video_vae = cls.from_config(config)
|
@@ -31,29 +36,41 @@ class VideoAutoencoder(AutoencoderKLWrapper):
|
|
31 |
ckpt_state_dict = torch.load(model_local_path)
|
32 |
video_vae.load_state_dict(ckpt_state_dict)
|
33 |
|
34 |
-
statistics_local_path =
|
|
|
|
|
35 |
if statistics_local_path.exists():
|
36 |
with open(statistics_local_path, "r") as file:
|
37 |
data = json.load(file)
|
38 |
transposed_data = list(zip(*data["data"]))
|
39 |
-
data_dict = {
|
|
|
|
|
|
|
40 |
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
41 |
video_vae.register_buffer(
|
42 |
-
"mean_of_means",
|
|
|
|
|
|
|
43 |
)
|
44 |
|
45 |
return video_vae
|
46 |
|
47 |
@staticmethod
|
48 |
def from_config(config):
|
49 |
-
assert
|
|
|
|
|
50 |
if isinstance(config["dims"], list):
|
51 |
config["dims"] = tuple(config["dims"])
|
52 |
|
53 |
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
54 |
|
55 |
double_z = config.get("double_z", True)
|
56 |
-
latent_log_var = config.get(
|
|
|
|
|
57 |
use_quant_conv = config.get("use_quant_conv", True)
|
58 |
|
59 |
if use_quant_conv and latent_log_var == "uniform":
|
@@ -96,8 +113,10 @@ class VideoAutoencoder(AutoencoderKLWrapper):
|
|
96 |
return SimpleNamespace(
|
97 |
_class_name="VideoAutoencoder",
|
98 |
dims=self.dims,
|
99 |
-
in_channels=self.encoder.conv_in.in_channels
|
100 |
-
|
|
|
|
|
101 |
latent_channels=self.decoder.conv_in.in_channels,
|
102 |
block_out_channels=[
|
103 |
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
|
@@ -143,7 +162,9 @@ class VideoAutoencoder(AutoencoderKLWrapper):
|
|
143 |
key = key.replace(k, v)
|
144 |
|
145 |
if "norm" in key and key not in model_keys:
|
146 |
-
logger.info(
|
|
|
|
|
147 |
continue
|
148 |
|
149 |
converted_state_dict[key] = value
|
@@ -253,7 +274,11 @@ class Encoder(nn.Module):
|
|
253 |
|
254 |
# out
|
255 |
if norm_layer == "group_norm":
|
256 |
-
self.conv_norm_out = nn.GroupNorm(
|
|
|
|
|
|
|
|
|
257 |
elif norm_layer == "pixel_norm":
|
258 |
self.conv_norm_out = PixelNorm()
|
259 |
self.conv_act = nn.SiLU()
|
@@ -265,14 +290,23 @@ class Encoder(nn.Module):
|
|
265 |
conv_out_channels += 1
|
266 |
elif latent_log_var != "none":
|
267 |
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
268 |
-
self.conv_out = make_conv_nd(
|
|
|
|
|
269 |
|
270 |
self.gradient_checkpointing = False
|
271 |
|
272 |
@property
|
273 |
def downscale_factor(self):
|
274 |
return (
|
275 |
-
2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
* self.patch_size
|
277 |
)
|
278 |
|
@@ -299,7 +333,9 @@ class Encoder(nn.Module):
|
|
299 |
)
|
300 |
|
301 |
for down_block in self.down_blocks:
|
302 |
-
sample = checkpoint_fn(down_block)(
|
|
|
|
|
303 |
|
304 |
sample = checkpoint_fn(self.mid_block)(sample)
|
305 |
|
@@ -314,11 +350,15 @@ class Encoder(nn.Module):
|
|
314 |
|
315 |
if num_dims == 4:
|
316 |
# For shape (B, C, H, W)
|
317 |
-
repeated_last_channel = last_channel.repeat(
|
|
|
|
|
318 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
319 |
elif num_dims == 5:
|
320 |
# For shape (B, C, F, H, W)
|
321 |
-
repeated_last_channel = last_channel.repeat(
|
|
|
|
|
322 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
323 |
else:
|
324 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
@@ -405,7 +445,8 @@ class Decoder(nn.Module):
|
|
405 |
num_layers=self.layers_per_block + 1,
|
406 |
in_channels=prev_output_channel,
|
407 |
out_channels=output_channel,
|
408 |
-
add_upsample=not is_final_block
|
|
|
409 |
resnet_eps=1e-6,
|
410 |
resnet_groups=norm_num_groups,
|
411 |
norm_layer=norm_layer,
|
@@ -413,12 +454,16 @@ class Decoder(nn.Module):
|
|
413 |
self.up_blocks.append(up_block)
|
414 |
|
415 |
if norm_layer == "group_norm":
|
416 |
-
self.conv_norm_out = nn.GroupNorm(
|
|
|
|
|
417 |
elif norm_layer == "pixel_norm":
|
418 |
self.conv_norm_out = PixelNorm()
|
419 |
|
420 |
self.conv_act = nn.SiLU()
|
421 |
-
self.conv_out = make_conv_nd(
|
|
|
|
|
422 |
|
423 |
self.gradient_checkpointing = False
|
424 |
|
@@ -494,15 +539,24 @@ class DownEncoderBlock3D(nn.Module):
|
|
494 |
self.res_blocks = nn.ModuleList(res_blocks)
|
495 |
|
496 |
if add_downsample:
|
497 |
-
self.downsample = Downsample3D(
|
|
|
|
|
|
|
|
|
|
|
498 |
else:
|
499 |
self.downsample = Identity()
|
500 |
|
501 |
-
def forward(
|
|
|
|
|
502 |
for resnet in self.res_blocks:
|
503 |
hidden_states = resnet(hidden_states)
|
504 |
|
505 |
-
hidden_states = self.downsample(
|
|
|
|
|
506 |
|
507 |
return hidden_states
|
508 |
|
@@ -536,7 +590,9 @@ class UNetMidBlock3D(nn.Module):
|
|
536 |
norm_layer: str = "group_norm",
|
537 |
):
|
538 |
super().__init__()
|
539 |
-
resnet_groups =
|
|
|
|
|
540 |
|
541 |
self.res_blocks = nn.ModuleList(
|
542 |
[
|
@@ -595,13 +651,17 @@ class UpDecoderBlock3D(nn.Module):
|
|
595 |
self.res_blocks = nn.ModuleList(res_blocks)
|
596 |
|
597 |
if add_upsample:
|
598 |
-
self.upsample = Upsample3D(
|
|
|
|
|
599 |
else:
|
600 |
self.upsample = Identity()
|
601 |
|
602 |
self.resolution_idx = resolution_idx
|
603 |
|
604 |
-
def forward(
|
|
|
|
|
605 |
for resnet in self.res_blocks:
|
606 |
hidden_states = resnet(hidden_states)
|
607 |
|
@@ -641,25 +701,35 @@ class ResnetBlock3D(nn.Module):
|
|
641 |
self.use_conv_shortcut = conv_shortcut
|
642 |
|
643 |
if norm_layer == "group_norm":
|
644 |
-
self.norm1 = torch.nn.GroupNorm(
|
|
|
|
|
645 |
elif norm_layer == "pixel_norm":
|
646 |
self.norm1 = PixelNorm()
|
647 |
|
648 |
self.non_linearity = nn.SiLU()
|
649 |
|
650 |
-
self.conv1 = make_conv_nd(
|
|
|
|
|
651 |
|
652 |
if norm_layer == "group_norm":
|
653 |
-
self.norm2 = torch.nn.GroupNorm(
|
|
|
|
|
654 |
elif norm_layer == "pixel_norm":
|
655 |
self.norm2 = PixelNorm()
|
656 |
|
657 |
self.dropout = torch.nn.Dropout(dropout)
|
658 |
|
659 |
-
self.conv2 = make_conv_nd(
|
|
|
|
|
660 |
|
661 |
self.conv_shortcut = (
|
662 |
-
make_linear_nd(
|
|
|
|
|
663 |
if in_channels != out_channels
|
664 |
else nn.Identity()
|
665 |
)
|
@@ -692,7 +762,14 @@ class ResnetBlock3D(nn.Module):
|
|
692 |
|
693 |
|
694 |
class Downsample3D(nn.Module):
|
695 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
696 |
super().__init__()
|
697 |
stride: int = 2
|
698 |
self.padding = padding
|
@@ -735,18 +812,24 @@ class Upsample3D(nn.Module):
|
|
735 |
self.dims = dims
|
736 |
self.channels = channels
|
737 |
self.out_channels = out_channels or channels
|
738 |
-
self.conv = make_conv_nd(
|
|
|
|
|
739 |
|
740 |
def forward(self, x, upsample_in_time):
|
741 |
if self.dims == 2:
|
742 |
-
x = functional.interpolate(
|
|
|
|
|
743 |
else:
|
744 |
time_scale_factor = 2 if upsample_in_time else 1
|
745 |
# print("before:", x.shape)
|
746 |
b, c, d, h, w = x.shape
|
747 |
x = rearrange(x, "b c d h w -> (b d) c h w")
|
748 |
# height and width interpolate
|
749 |
-
x = functional.interpolate(
|
|
|
|
|
750 |
_, _, h, w = x.shape
|
751 |
|
752 |
if not upsample_in_time and self.dims == (2, 1):
|
@@ -760,7 +843,9 @@ class Upsample3D(nn.Module):
|
|
760 |
new_d = x.shape[-1] * time_scale_factor
|
761 |
x = functional.interpolate(x, (1, new_d), mode="nearest")
|
762 |
# (b h w) c 1 new_d
|
763 |
-
x = rearrange(
|
|
|
|
|
764 |
# b c d h w
|
765 |
|
766 |
# x = functional.interpolate(
|
@@ -775,13 +860,25 @@ def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
|
775 |
if patch_size_hw == 1 and patch_size_t == 1:
|
776 |
return x
|
777 |
if x.dim() == 4:
|
778 |
-
x = rearrange(
|
|
|
|
|
779 |
elif x.dim() == 5:
|
780 |
-
x = rearrange(
|
|
|
|
|
|
|
|
|
|
|
|
|
781 |
else:
|
782 |
raise ValueError(f"Invalid input shape: {x.shape}")
|
783 |
|
784 |
-
if (
|
|
|
|
|
|
|
|
|
785 |
channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
|
786 |
padding_zeros = torch.zeros(
|
787 |
x.shape[0],
|
@@ -801,14 +898,26 @@ def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
|
801 |
if patch_size_hw == 1 and patch_size_t == 1:
|
802 |
return x
|
803 |
|
804 |
-
if (
|
|
|
|
|
|
|
|
|
805 |
channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
|
806 |
x = x[:, :channels_to_keep, :, :, :]
|
807 |
|
808 |
if x.dim() == 4:
|
809 |
-
x = rearrange(
|
|
|
|
|
810 |
elif x.dim() == 5:
|
811 |
-
x = rearrange(
|
|
|
|
|
|
|
|
|
|
|
|
|
812 |
|
813 |
return x
|
814 |
|
@@ -818,11 +927,19 @@ def create_video_autoencoder_config(
|
|
818 |
):
|
819 |
config = {
|
820 |
"_class_name": "VideoAutoencoder",
|
821 |
-
"dims": (
|
|
|
|
|
|
|
822 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
823 |
"out_channels": 3, # Number of output color channels
|
824 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
825 |
-
"block_out_channels": [
|
|
|
|
|
|
|
|
|
|
|
826 |
"patch_size": 1,
|
827 |
}
|
828 |
|
@@ -834,11 +951,15 @@ def create_video_autoencoder_pathify4x4x4_config(
|
|
834 |
):
|
835 |
config = {
|
836 |
"_class_name": "VideoAutoencoder",
|
837 |
-
"dims": (
|
|
|
|
|
|
|
838 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
839 |
"out_channels": 3, # Number of output color channels
|
840 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
841 |
-
"block_out_channels": [512]
|
|
|
842 |
"patch_size": 4,
|
843 |
"latent_log_var": "uniform",
|
844 |
}
|
@@ -855,7 +976,8 @@ def create_video_autoencoder_pathify4x4_config(
|
|
855 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
856 |
"out_channels": 3, # Number of output color channels
|
857 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
858 |
-
"block_out_channels": [512]
|
|
|
859 |
"patch_size": 4,
|
860 |
"norm_layer": "pixel_norm",
|
861 |
}
|
@@ -894,7 +1016,9 @@ def demo_video_autoencoder_forward_backward():
|
|
894 |
latent = video_autoencoder.encode(input_videos).latent_dist.mode()
|
895 |
print(f"input shape={input_videos.shape}")
|
896 |
print(f"latent shape={latent.shape}")
|
897 |
-
reconstructed_videos = video_autoencoder.decode(
|
|
|
|
|
898 |
|
899 |
print(f"reconstructed shape={reconstructed_videos.shape}")
|
900 |
|
|
|
21 |
|
22 |
class VideoAutoencoder(AutoencoderKLWrapper):
|
23 |
@classmethod
|
24 |
+
def from_pretrained(
|
25 |
+
cls,
|
26 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
27 |
+
*args,
|
28 |
+
**kwargs,
|
29 |
+
):
|
30 |
config_local_path = pretrained_model_name_or_path / "config.json"
|
31 |
config = cls.load_config(config_local_path, **kwargs)
|
32 |
video_vae = cls.from_config(config)
|
|
|
36 |
ckpt_state_dict = torch.load(model_local_path)
|
37 |
video_vae.load_state_dict(ckpt_state_dict)
|
38 |
|
39 |
+
statistics_local_path = (
|
40 |
+
pretrained_model_name_or_path / "per_channel_statistics.json"
|
41 |
+
)
|
42 |
if statistics_local_path.exists():
|
43 |
with open(statistics_local_path, "r") as file:
|
44 |
data = json.load(file)
|
45 |
transposed_data = list(zip(*data["data"]))
|
46 |
+
data_dict = {
|
47 |
+
col: torch.tensor(vals)
|
48 |
+
for col, vals in zip(data["columns"], transposed_data)
|
49 |
+
}
|
50 |
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
51 |
video_vae.register_buffer(
|
52 |
+
"mean_of_means",
|
53 |
+
data_dict.get(
|
54 |
+
"mean-of-means", torch.zeros_like(data_dict["std-of-means"])
|
55 |
+
),
|
56 |
)
|
57 |
|
58 |
return video_vae
|
59 |
|
60 |
@staticmethod
|
61 |
def from_config(config):
|
62 |
+
assert (
|
63 |
+
config["_class_name"] == "VideoAutoencoder"
|
64 |
+
), "config must have _class_name=VideoAutoencoder"
|
65 |
if isinstance(config["dims"], list):
|
66 |
config["dims"] = tuple(config["dims"])
|
67 |
|
68 |
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
69 |
|
70 |
double_z = config.get("double_z", True)
|
71 |
+
latent_log_var = config.get(
|
72 |
+
"latent_log_var", "per_channel" if double_z else "none"
|
73 |
+
)
|
74 |
use_quant_conv = config.get("use_quant_conv", True)
|
75 |
|
76 |
if use_quant_conv and latent_log_var == "uniform":
|
|
|
113 |
return SimpleNamespace(
|
114 |
_class_name="VideoAutoencoder",
|
115 |
dims=self.dims,
|
116 |
+
in_channels=self.encoder.conv_in.in_channels
|
117 |
+
// (self.encoder.patch_size_t * self.encoder.patch_size**2),
|
118 |
+
out_channels=self.decoder.conv_out.out_channels
|
119 |
+
// (self.decoder.patch_size_t * self.decoder.patch_size**2),
|
120 |
latent_channels=self.decoder.conv_in.in_channels,
|
121 |
block_out_channels=[
|
122 |
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
|
|
|
162 |
key = key.replace(k, v)
|
163 |
|
164 |
if "norm" in key and key not in model_keys:
|
165 |
+
logger.info(
|
166 |
+
f"Removing key {key} from state_dict as it is not present in the model"
|
167 |
+
)
|
168 |
continue
|
169 |
|
170 |
converted_state_dict[key] = value
|
|
|
274 |
|
275 |
# out
|
276 |
if norm_layer == "group_norm":
|
277 |
+
self.conv_norm_out = nn.GroupNorm(
|
278 |
+
num_channels=block_out_channels[-1],
|
279 |
+
num_groups=norm_num_groups,
|
280 |
+
eps=1e-6,
|
281 |
+
)
|
282 |
elif norm_layer == "pixel_norm":
|
283 |
self.conv_norm_out = PixelNorm()
|
284 |
self.conv_act = nn.SiLU()
|
|
|
290 |
conv_out_channels += 1
|
291 |
elif latent_log_var != "none":
|
292 |
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
293 |
+
self.conv_out = make_conv_nd(
|
294 |
+
dims, block_out_channels[-1], conv_out_channels, 3, padding=1
|
295 |
+
)
|
296 |
|
297 |
self.gradient_checkpointing = False
|
298 |
|
299 |
@property
|
300 |
def downscale_factor(self):
|
301 |
return (
|
302 |
+
2
|
303 |
+
** len(
|
304 |
+
[
|
305 |
+
block
|
306 |
+
for block in self.down_blocks
|
307 |
+
if isinstance(block.downsample, Downsample3D)
|
308 |
+
]
|
309 |
+
)
|
310 |
* self.patch_size
|
311 |
)
|
312 |
|
|
|
333 |
)
|
334 |
|
335 |
for down_block in self.down_blocks:
|
336 |
+
sample = checkpoint_fn(down_block)(
|
337 |
+
sample, downsample_in_time=downsample_in_time
|
338 |
+
)
|
339 |
|
340 |
sample = checkpoint_fn(self.mid_block)(sample)
|
341 |
|
|
|
350 |
|
351 |
if num_dims == 4:
|
352 |
# For shape (B, C, H, W)
|
353 |
+
repeated_last_channel = last_channel.repeat(
|
354 |
+
1, sample.shape[1] - 2, 1, 1
|
355 |
+
)
|
356 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
357 |
elif num_dims == 5:
|
358 |
# For shape (B, C, F, H, W)
|
359 |
+
repeated_last_channel = last_channel.repeat(
|
360 |
+
1, sample.shape[1] - 2, 1, 1, 1
|
361 |
+
)
|
362 |
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
363 |
else:
|
364 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
|
|
445 |
num_layers=self.layers_per_block + 1,
|
446 |
in_channels=prev_output_channel,
|
447 |
out_channels=output_channel,
|
448 |
+
add_upsample=not is_final_block
|
449 |
+
and 2 ** (len(block_out_channels) - i - 1) > patch_size,
|
450 |
resnet_eps=1e-6,
|
451 |
resnet_groups=norm_num_groups,
|
452 |
norm_layer=norm_layer,
|
|
|
454 |
self.up_blocks.append(up_block)
|
455 |
|
456 |
if norm_layer == "group_norm":
|
457 |
+
self.conv_norm_out = nn.GroupNorm(
|
458 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
|
459 |
+
)
|
460 |
elif norm_layer == "pixel_norm":
|
461 |
self.conv_norm_out = PixelNorm()
|
462 |
|
463 |
self.conv_act = nn.SiLU()
|
464 |
+
self.conv_out = make_conv_nd(
|
465 |
+
dims, block_out_channels[0], out_channels, 3, padding=1
|
466 |
+
)
|
467 |
|
468 |
self.gradient_checkpointing = False
|
469 |
|
|
|
539 |
self.res_blocks = nn.ModuleList(res_blocks)
|
540 |
|
541 |
if add_downsample:
|
542 |
+
self.downsample = Downsample3D(
|
543 |
+
dims,
|
544 |
+
out_channels,
|
545 |
+
out_channels=out_channels,
|
546 |
+
padding=downsample_padding,
|
547 |
+
)
|
548 |
else:
|
549 |
self.downsample = Identity()
|
550 |
|
551 |
+
def forward(
|
552 |
+
self, hidden_states: torch.FloatTensor, downsample_in_time
|
553 |
+
) -> torch.FloatTensor:
|
554 |
for resnet in self.res_blocks:
|
555 |
hidden_states = resnet(hidden_states)
|
556 |
|
557 |
+
hidden_states = self.downsample(
|
558 |
+
hidden_states, downsample_in_time=downsample_in_time
|
559 |
+
)
|
560 |
|
561 |
return hidden_states
|
562 |
|
|
|
590 |
norm_layer: str = "group_norm",
|
591 |
):
|
592 |
super().__init__()
|
593 |
+
resnet_groups = (
|
594 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
595 |
+
)
|
596 |
|
597 |
self.res_blocks = nn.ModuleList(
|
598 |
[
|
|
|
651 |
self.res_blocks = nn.ModuleList(res_blocks)
|
652 |
|
653 |
if add_upsample:
|
654 |
+
self.upsample = Upsample3D(
|
655 |
+
dims=dims, channels=out_channels, out_channels=out_channels
|
656 |
+
)
|
657 |
else:
|
658 |
self.upsample = Identity()
|
659 |
|
660 |
self.resolution_idx = resolution_idx
|
661 |
|
662 |
+
def forward(
|
663 |
+
self, hidden_states: torch.FloatTensor, upsample_in_time=True
|
664 |
+
) -> torch.FloatTensor:
|
665 |
for resnet in self.res_blocks:
|
666 |
hidden_states = resnet(hidden_states)
|
667 |
|
|
|
701 |
self.use_conv_shortcut = conv_shortcut
|
702 |
|
703 |
if norm_layer == "group_norm":
|
704 |
+
self.norm1 = torch.nn.GroupNorm(
|
705 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
706 |
+
)
|
707 |
elif norm_layer == "pixel_norm":
|
708 |
self.norm1 = PixelNorm()
|
709 |
|
710 |
self.non_linearity = nn.SiLU()
|
711 |
|
712 |
+
self.conv1 = make_conv_nd(
|
713 |
+
dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
714 |
+
)
|
715 |
|
716 |
if norm_layer == "group_norm":
|
717 |
+
self.norm2 = torch.nn.GroupNorm(
|
718 |
+
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
719 |
+
)
|
720 |
elif norm_layer == "pixel_norm":
|
721 |
self.norm2 = PixelNorm()
|
722 |
|
723 |
self.dropout = torch.nn.Dropout(dropout)
|
724 |
|
725 |
+
self.conv2 = make_conv_nd(
|
726 |
+
dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
727 |
+
)
|
728 |
|
729 |
self.conv_shortcut = (
|
730 |
+
make_linear_nd(
|
731 |
+
dims=dims, in_channels=in_channels, out_channels=out_channels
|
732 |
+
)
|
733 |
if in_channels != out_channels
|
734 |
else nn.Identity()
|
735 |
)
|
|
|
762 |
|
763 |
|
764 |
class Downsample3D(nn.Module):
|
765 |
+
def __init__(
|
766 |
+
self,
|
767 |
+
dims,
|
768 |
+
in_channels: int,
|
769 |
+
out_channels: int,
|
770 |
+
kernel_size: int = 3,
|
771 |
+
padding: int = 1,
|
772 |
+
):
|
773 |
super().__init__()
|
774 |
stride: int = 2
|
775 |
self.padding = padding
|
|
|
812 |
self.dims = dims
|
813 |
self.channels = channels
|
814 |
self.out_channels = out_channels or channels
|
815 |
+
self.conv = make_conv_nd(
|
816 |
+
dims, channels, out_channels, kernel_size=3, padding=1, bias=True
|
817 |
+
)
|
818 |
|
819 |
def forward(self, x, upsample_in_time):
|
820 |
if self.dims == 2:
|
821 |
+
x = functional.interpolate(
|
822 |
+
x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
|
823 |
+
)
|
824 |
else:
|
825 |
time_scale_factor = 2 if upsample_in_time else 1
|
826 |
# print("before:", x.shape)
|
827 |
b, c, d, h, w = x.shape
|
828 |
x = rearrange(x, "b c d h w -> (b d) c h w")
|
829 |
# height and width interpolate
|
830 |
+
x = functional.interpolate(
|
831 |
+
x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
|
832 |
+
)
|
833 |
_, _, h, w = x.shape
|
834 |
|
835 |
if not upsample_in_time and self.dims == (2, 1):
|
|
|
843 |
new_d = x.shape[-1] * time_scale_factor
|
844 |
x = functional.interpolate(x, (1, new_d), mode="nearest")
|
845 |
# (b h w) c 1 new_d
|
846 |
+
x = rearrange(
|
847 |
+
x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
|
848 |
+
)
|
849 |
# b c d h w
|
850 |
|
851 |
# x = functional.interpolate(
|
|
|
860 |
if patch_size_hw == 1 and patch_size_t == 1:
|
861 |
return x
|
862 |
if x.dim() == 4:
|
863 |
+
x = rearrange(
|
864 |
+
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
|
865 |
+
)
|
866 |
elif x.dim() == 5:
|
867 |
+
x = rearrange(
|
868 |
+
x,
|
869 |
+
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
|
870 |
+
p=patch_size_t,
|
871 |
+
q=patch_size_hw,
|
872 |
+
r=patch_size_hw,
|
873 |
+
)
|
874 |
else:
|
875 |
raise ValueError(f"Invalid input shape: {x.shape}")
|
876 |
|
877 |
+
if (
|
878 |
+
(x.dim() == 5)
|
879 |
+
and (patch_size_hw > patch_size_t)
|
880 |
+
and (patch_size_t > 1 or add_channel_padding)
|
881 |
+
):
|
882 |
channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
|
883 |
padding_zeros = torch.zeros(
|
884 |
x.shape[0],
|
|
|
898 |
if patch_size_hw == 1 and patch_size_t == 1:
|
899 |
return x
|
900 |
|
901 |
+
if (
|
902 |
+
(x.dim() == 5)
|
903 |
+
and (patch_size_hw > patch_size_t)
|
904 |
+
and (patch_size_t > 1 or add_channel_padding)
|
905 |
+
):
|
906 |
channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
|
907 |
x = x[:, :channels_to_keep, :, :, :]
|
908 |
|
909 |
if x.dim() == 4:
|
910 |
+
x = rearrange(
|
911 |
+
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
|
912 |
+
)
|
913 |
elif x.dim() == 5:
|
914 |
+
x = rearrange(
|
915 |
+
x,
|
916 |
+
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
|
917 |
+
p=patch_size_t,
|
918 |
+
q=patch_size_hw,
|
919 |
+
r=patch_size_hw,
|
920 |
+
)
|
921 |
|
922 |
return x
|
923 |
|
|
|
927 |
):
|
928 |
config = {
|
929 |
"_class_name": "VideoAutoencoder",
|
930 |
+
"dims": (
|
931 |
+
2,
|
932 |
+
1,
|
933 |
+
), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
934 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
935 |
"out_channels": 3, # Number of output color channels
|
936 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
937 |
+
"block_out_channels": [
|
938 |
+
128,
|
939 |
+
256,
|
940 |
+
512,
|
941 |
+
512,
|
942 |
+
], # Number of output channels of each encoder / decoder inner block
|
943 |
"patch_size": 1,
|
944 |
}
|
945 |
|
|
|
951 |
):
|
952 |
config = {
|
953 |
"_class_name": "VideoAutoencoder",
|
954 |
+
"dims": (
|
955 |
+
2,
|
956 |
+
1,
|
957 |
+
), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
958 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
959 |
"out_channels": 3, # Number of output color channels
|
960 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
961 |
+
"block_out_channels": [512]
|
962 |
+
* 4, # Number of output channels of each encoder / decoder inner block
|
963 |
"patch_size": 4,
|
964 |
"latent_log_var": "uniform",
|
965 |
}
|
|
|
976 |
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
977 |
"out_channels": 3, # Number of output color channels
|
978 |
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
979 |
+
"block_out_channels": [512]
|
980 |
+
* 4, # Number of output channels of each encoder / decoder inner block
|
981 |
"patch_size": 4,
|
982 |
"norm_layer": "pixel_norm",
|
983 |
}
|
|
|
1016 |
latent = video_autoencoder.encode(input_videos).latent_dist.mode()
|
1017 |
print(f"input shape={input_videos.shape}")
|
1018 |
print(f"latent shape={latent.shape}")
|
1019 |
+
reconstructed_videos = video_autoencoder.decode(
|
1020 |
+
latent, target_shape=input_videos.shape
|
1021 |
+
).sample
|
1022 |
|
1023 |
print(f"reconstructed shape={reconstructed_videos.shape}")
|
1024 |
|
xora/models/transformers/attention.py
CHANGED
@@ -106,11 +106,15 @@ class BasicTransformerBlock(nn.Module):
|
|
106 |
assert standardization_norm in ["layer_norm", "rms_norm"]
|
107 |
assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
|
108 |
|
109 |
-
make_norm_layer =
|
|
|
|
|
110 |
|
111 |
# Define 3 blocks. Each block has its own normalization layer.
|
112 |
# 1. Self-Attn
|
113 |
-
self.norm1 = make_norm_layer(
|
|
|
|
|
114 |
|
115 |
self.attn1 = Attention(
|
116 |
query_dim=dim,
|
@@ -130,7 +134,9 @@ class BasicTransformerBlock(nn.Module):
|
|
130 |
if cross_attention_dim is not None or double_self_attention:
|
131 |
self.attn2 = Attention(
|
132 |
query_dim=dim,
|
133 |
-
cross_attention_dim=
|
|
|
|
|
134 |
heads=num_attention_heads,
|
135 |
dim_head=attention_head_dim,
|
136 |
dropout=dropout,
|
@@ -143,7 +149,9 @@ class BasicTransformerBlock(nn.Module):
|
|
143 |
) # is self-attn if encoder_hidden_states is none
|
144 |
|
145 |
if adaptive_norm == "none":
|
146 |
-
self.attn2_norm = make_norm_layer(
|
|
|
|
|
147 |
else:
|
148 |
self.attn2 = None
|
149 |
self.attn2_norm = None
|
@@ -163,7 +171,9 @@ class BasicTransformerBlock(nn.Module):
|
|
163 |
# 5. Scale-shift for PixArt-Alpha.
|
164 |
if adaptive_norm != "none":
|
165 |
num_ada_params = 4 if adaptive_norm == "single_scale" else 6
|
166 |
-
self.scale_shift_table = nn.Parameter(
|
|
|
|
|
167 |
|
168 |
# let chunk size default to None
|
169 |
self._chunk_size = None
|
@@ -198,7 +208,9 @@ class BasicTransformerBlock(nn.Module):
|
|
198 |
) -> torch.FloatTensor:
|
199 |
if cross_attention_kwargs is not None:
|
200 |
if cross_attention_kwargs.get("scale", None) is not None:
|
201 |
-
logger.warning(
|
|
|
|
|
202 |
|
203 |
# Notice that normalization is always applied before the real computation in the following blocks.
|
204 |
# 0. Self-Attention
|
@@ -214,7 +226,9 @@ class BasicTransformerBlock(nn.Module):
|
|
214 |
batch_size, timestep.shape[1], num_ada_params, -1
|
215 |
)
|
216 |
if self.adaptive_norm == "single_scale_shift":
|
217 |
-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp =
|
|
|
|
|
218 |
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
219 |
else:
|
220 |
scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
@@ -224,15 +238,21 @@ class BasicTransformerBlock(nn.Module):
|
|
224 |
else:
|
225 |
raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
|
226 |
|
227 |
-
norm_hidden_states = norm_hidden_states.squeeze(
|
|
|
|
|
228 |
|
229 |
# 1. Prepare GLIGEN inputs
|
230 |
-
cross_attention_kwargs =
|
|
|
|
|
231 |
|
232 |
attn_output = self.attn1(
|
233 |
norm_hidden_states,
|
234 |
freqs_cis=freqs_cis,
|
235 |
-
encoder_hidden_states=
|
|
|
|
|
236 |
attention_mask=attention_mask,
|
237 |
**cross_attention_kwargs,
|
238 |
)
|
@@ -271,7 +291,9 @@ class BasicTransformerBlock(nn.Module):
|
|
271 |
|
272 |
if self._chunk_size is not None:
|
273 |
# "feed_forward_chunk_size" can be used to save memory
|
274 |
-
ff_output = _chunked_feed_forward(
|
|
|
|
|
275 |
else:
|
276 |
ff_output = self.ff(norm_hidden_states)
|
277 |
if gate_mlp is not None:
|
@@ -371,7 +393,9 @@ class Attention(nn.Module):
|
|
371 |
self.query_dim = query_dim
|
372 |
self.use_bias = bias
|
373 |
self.is_cross_attention = cross_attention_dim is not None
|
374 |
-
self.cross_attention_dim =
|
|
|
|
|
375 |
self.upcast_attention = upcast_attention
|
376 |
self.upcast_softmax = upcast_softmax
|
377 |
self.rescale_output_factor = rescale_output_factor
|
@@ -416,12 +440,16 @@ class Attention(nn.Module):
|
|
416 |
)
|
417 |
|
418 |
if norm_num_groups is not None:
|
419 |
-
self.group_norm = nn.GroupNorm(
|
|
|
|
|
420 |
else:
|
421 |
self.group_norm = None
|
422 |
|
423 |
if spatial_norm_dim is not None:
|
424 |
-
self.spatial_norm = SpatialNorm(
|
|
|
|
|
425 |
else:
|
426 |
self.spatial_norm = None
|
427 |
|
@@ -441,7 +469,10 @@ class Attention(nn.Module):
|
|
441 |
norm_cross_num_channels = self.cross_attention_dim
|
442 |
|
443 |
self.norm_cross = nn.GroupNorm(
|
444 |
-
num_channels=norm_cross_num_channels,
|
|
|
|
|
|
|
445 |
)
|
446 |
else:
|
447 |
raise ValueError(
|
@@ -499,12 +530,16 @@ class Attention(nn.Module):
|
|
499 |
and isinstance(self.processor, torch.nn.Module)
|
500 |
and not isinstance(processor, torch.nn.Module)
|
501 |
):
|
502 |
-
logger.info(
|
|
|
|
|
503 |
self._modules.pop("processor")
|
504 |
|
505 |
self.processor = processor
|
506 |
|
507 |
-
def get_processor(
|
|
|
|
|
508 |
r"""
|
509 |
Get the attention processor in use.
|
510 |
|
@@ -542,12 +577,18 @@ class Attention(nn.Module):
|
|
542 |
|
543 |
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
|
544 |
non_lora_processor_cls_name = self.processor.__class__.__name__
|
545 |
-
lora_processor_cls = getattr(
|
|
|
|
|
546 |
|
547 |
hidden_size = self.inner_dim
|
548 |
|
549 |
# now create a LoRA attention processor from the LoRA layers
|
550 |
-
if lora_processor_cls in [
|
|
|
|
|
|
|
|
|
551 |
kwargs = {
|
552 |
"cross_attention_dim": self.cross_attention_dim,
|
553 |
"rank": self.to_q.lora_layer.rank,
|
@@ -569,7 +610,9 @@ class Attention(nn.Module):
|
|
569 |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
570 |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
571 |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
572 |
-
lora_processor.to_out_lora.load_state_dict(
|
|
|
|
|
573 |
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
|
574 |
lora_processor = lora_processor_cls(
|
575 |
hidden_size,
|
@@ -580,12 +623,18 @@ class Attention(nn.Module):
|
|
580 |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
581 |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
582 |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
583 |
-
lora_processor.to_out_lora.load_state_dict(
|
|
|
|
|
584 |
|
585 |
# only save if used
|
586 |
if self.add_k_proj.lora_layer is not None:
|
587 |
-
lora_processor.add_k_proj_lora.load_state_dict(
|
588 |
-
|
|
|
|
|
|
|
|
|
589 |
else:
|
590 |
lora_processor.add_k_proj_lora = None
|
591 |
lora_processor.add_v_proj_lora = None
|
@@ -622,14 +671,20 @@ class Attention(nn.Module):
|
|
622 |
# here we simply pass along all tensors to the selected processor class
|
623 |
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
624 |
|
625 |
-
attn_parameters = set(
|
626 |
-
|
|
|
|
|
|
|
|
|
627 |
if len(unused_kwargs) > 0:
|
628 |
logger.warning(
|
629 |
f"cross_attention_kwargs {unused_kwargs} are not expected by"
|
630 |
f" {self.processor.__class__.__name__} and will be ignored."
|
631 |
)
|
632 |
-
cross_attention_kwargs = {
|
|
|
|
|
633 |
|
634 |
return self.processor(
|
635 |
self,
|
@@ -654,7 +709,9 @@ class Attention(nn.Module):
|
|
654 |
head_size = self.heads
|
655 |
batch_size, seq_len, dim = tensor.shape
|
656 |
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
657 |
-
tensor = tensor.permute(0, 2, 1, 3).reshape(
|
|
|
|
|
658 |
return tensor
|
659 |
|
660 |
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
@@ -677,16 +734,23 @@ class Attention(nn.Module):
|
|
677 |
extra_dim = 1
|
678 |
else:
|
679 |
batch_size, extra_dim, seq_len, dim = tensor.shape
|
680 |
-
tensor = tensor.reshape(
|
|
|
|
|
681 |
tensor = tensor.permute(0, 2, 1, 3)
|
682 |
|
683 |
if out_dim == 3:
|
684 |
-
tensor = tensor.reshape(
|
|
|
|
|
685 |
|
686 |
return tensor
|
687 |
|
688 |
def get_attention_scores(
|
689 |
-
self,
|
|
|
|
|
|
|
690 |
) -> torch.Tensor:
|
691 |
r"""
|
692 |
Compute the attention scores.
|
@@ -706,7 +770,11 @@ class Attention(nn.Module):
|
|
706 |
|
707 |
if attention_mask is None:
|
708 |
baddbmm_input = torch.empty(
|
709 |
-
query.shape[0],
|
|
|
|
|
|
|
|
|
710 |
)
|
711 |
beta = 0
|
712 |
else:
|
@@ -733,7 +801,11 @@ class Attention(nn.Module):
|
|
733 |
return attention_probs
|
734 |
|
735 |
def prepare_attention_mask(
|
736 |
-
self,
|
|
|
|
|
|
|
|
|
737 |
) -> torch.Tensor:
|
738 |
r"""
|
739 |
Prepare the attention mask for the attention computation.
|
@@ -760,8 +832,16 @@ class Attention(nn.Module):
|
|
760 |
if attention_mask.device.type == "mps":
|
761 |
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
762 |
# Instead, we can manually construct the padding tensor.
|
763 |
-
padding_shape = (
|
764 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
765 |
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
766 |
else:
|
767 |
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
@@ -779,7 +859,9 @@ class Attention(nn.Module):
|
|
779 |
|
780 |
return attention_mask
|
781 |
|
782 |
-
def norm_encoder_hidden_states(
|
|
|
|
|
783 |
r"""
|
784 |
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
785 |
`Attention` class.
|
@@ -790,7 +872,9 @@ class Attention(nn.Module):
|
|
790 |
Returns:
|
791 |
`torch.Tensor`: The normalized encoder hidden states.
|
792 |
"""
|
793 |
-
assert
|
|
|
|
|
794 |
|
795 |
if isinstance(self.norm_cross, nn.LayerNorm):
|
796 |
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
@@ -857,27 +941,39 @@ class AttnProcessor2_0:
|
|
857 |
|
858 |
if input_ndim == 4:
|
859 |
batch_size, channel, height, width = hidden_states.shape
|
860 |
-
hidden_states = hidden_states.view(
|
|
|
|
|
861 |
|
862 |
batch_size, sequence_length, _ = (
|
863 |
-
hidden_states.shape
|
|
|
|
|
864 |
)
|
865 |
|
866 |
if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
|
867 |
-
attention_mask = attn.prepare_attention_mask(
|
|
|
|
|
868 |
# scaled_dot_product_attention expects attention_mask shape to be
|
869 |
# (batch, heads, source_length, target_length)
|
870 |
-
attention_mask = attention_mask.view(
|
|
|
|
|
871 |
|
872 |
if attn.group_norm is not None:
|
873 |
-
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
|
|
|
|
874 |
|
875 |
query = attn.to_q(hidden_states)
|
876 |
query = attn.q_norm(query)
|
877 |
|
878 |
if encoder_hidden_states is not None:
|
879 |
if attn.norm_cross:
|
880 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
|
|
|
|
881 |
key = attn.to_k(encoder_hidden_states)
|
882 |
key = attn.k_norm(key)
|
883 |
else: # if no context provided do self-attention
|
@@ -901,10 +997,14 @@ class AttnProcessor2_0:
|
|
901 |
|
902 |
if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
|
903 |
q_segment_indexes = None
|
904 |
-
if
|
|
|
|
|
905 |
# attention_mask = torch.squeeze(attention_mask).to(torch.float32)
|
906 |
attention_mask = attention_mask.to(torch.float32)
|
907 |
-
q_segment_indexes = torch.ones(
|
|
|
|
|
908 |
assert (
|
909 |
attention_mask.shape[1] == key.shape[2]
|
910 |
), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
|
@@ -927,10 +1027,17 @@ class AttnProcessor2_0:
|
|
927 |
)
|
928 |
else:
|
929 |
hidden_states = F.scaled_dot_product_attention(
|
930 |
-
query,
|
|
|
|
|
|
|
|
|
|
|
931 |
)
|
932 |
|
933 |
-
hidden_states = hidden_states.transpose(1, 2).reshape(
|
|
|
|
|
934 |
hidden_states = hidden_states.to(query.dtype)
|
935 |
|
936 |
# linear proj
|
@@ -939,7 +1046,9 @@ class AttnProcessor2_0:
|
|
939 |
hidden_states = attn.to_out[1](hidden_states)
|
940 |
|
941 |
if input_ndim == 4:
|
942 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
|
|
|
|
943 |
|
944 |
if attn.residual_connection:
|
945 |
hidden_states = hidden_states + residual
|
@@ -977,22 +1086,32 @@ class AttnProcessor:
|
|
977 |
|
978 |
if input_ndim == 4:
|
979 |
batch_size, channel, height, width = hidden_states.shape
|
980 |
-
hidden_states = hidden_states.view(
|
|
|
|
|
981 |
|
982 |
batch_size, sequence_length, _ = (
|
983 |
-
hidden_states.shape
|
|
|
|
|
|
|
|
|
|
|
984 |
)
|
985 |
-
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
986 |
|
987 |
if attn.group_norm is not None:
|
988 |
-
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
|
|
|
|
989 |
|
990 |
query = attn.to_q(hidden_states)
|
991 |
|
992 |
if encoder_hidden_states is None:
|
993 |
encoder_hidden_states = hidden_states
|
994 |
elif attn.norm_cross:
|
995 |
-
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
|
|
|
|
996 |
|
997 |
key = attn.to_k(encoder_hidden_states)
|
998 |
value = attn.to_v(encoder_hidden_states)
|
@@ -1014,7 +1133,9 @@ class AttnProcessor:
|
|
1014 |
hidden_states = attn.to_out[1](hidden_states)
|
1015 |
|
1016 |
if input_ndim == 4:
|
1017 |
-
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
|
|
|
|
1018 |
|
1019 |
if attn.residual_connection:
|
1020 |
hidden_states = hidden_states + residual
|
|
|
106 |
assert standardization_norm in ["layer_norm", "rms_norm"]
|
107 |
assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
|
108 |
|
109 |
+
make_norm_layer = (
|
110 |
+
nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
|
111 |
+
)
|
112 |
|
113 |
# Define 3 blocks. Each block has its own normalization layer.
|
114 |
# 1. Self-Attn
|
115 |
+
self.norm1 = make_norm_layer(
|
116 |
+
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
|
117 |
+
)
|
118 |
|
119 |
self.attn1 = Attention(
|
120 |
query_dim=dim,
|
|
|
134 |
if cross_attention_dim is not None or double_self_attention:
|
135 |
self.attn2 = Attention(
|
136 |
query_dim=dim,
|
137 |
+
cross_attention_dim=(
|
138 |
+
cross_attention_dim if not double_self_attention else None
|
139 |
+
),
|
140 |
heads=num_attention_heads,
|
141 |
dim_head=attention_head_dim,
|
142 |
dropout=dropout,
|
|
|
149 |
) # is self-attn if encoder_hidden_states is none
|
150 |
|
151 |
if adaptive_norm == "none":
|
152 |
+
self.attn2_norm = make_norm_layer(
|
153 |
+
dim, norm_eps, norm_elementwise_affine
|
154 |
+
)
|
155 |
else:
|
156 |
self.attn2 = None
|
157 |
self.attn2_norm = None
|
|
|
171 |
# 5. Scale-shift for PixArt-Alpha.
|
172 |
if adaptive_norm != "none":
|
173 |
num_ada_params = 4 if adaptive_norm == "single_scale" else 6
|
174 |
+
self.scale_shift_table = nn.Parameter(
|
175 |
+
torch.randn(num_ada_params, dim) / dim**0.5
|
176 |
+
)
|
177 |
|
178 |
# let chunk size default to None
|
179 |
self._chunk_size = None
|
|
|
208 |
) -> torch.FloatTensor:
|
209 |
if cross_attention_kwargs is not None:
|
210 |
if cross_attention_kwargs.get("scale", None) is not None:
|
211 |
+
logger.warning(
|
212 |
+
"Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored."
|
213 |
+
)
|
214 |
|
215 |
# Notice that normalization is always applied before the real computation in the following blocks.
|
216 |
# 0. Self-Attention
|
|
|
226 |
batch_size, timestep.shape[1], num_ada_params, -1
|
227 |
)
|
228 |
if self.adaptive_norm == "single_scale_shift":
|
229 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
230 |
+
ada_values.unbind(dim=2)
|
231 |
+
)
|
232 |
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
233 |
else:
|
234 |
scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
|
|
238 |
else:
|
239 |
raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
|
240 |
|
241 |
+
norm_hidden_states = norm_hidden_states.squeeze(
|
242 |
+
1
|
243 |
+
) # TODO: Check if this is needed
|
244 |
|
245 |
# 1. Prepare GLIGEN inputs
|
246 |
+
cross_attention_kwargs = (
|
247 |
+
cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
248 |
+
)
|
249 |
|
250 |
attn_output = self.attn1(
|
251 |
norm_hidden_states,
|
252 |
freqs_cis=freqs_cis,
|
253 |
+
encoder_hidden_states=(
|
254 |
+
encoder_hidden_states if self.only_cross_attention else None
|
255 |
+
),
|
256 |
attention_mask=attention_mask,
|
257 |
**cross_attention_kwargs,
|
258 |
)
|
|
|
291 |
|
292 |
if self._chunk_size is not None:
|
293 |
# "feed_forward_chunk_size" can be used to save memory
|
294 |
+
ff_output = _chunked_feed_forward(
|
295 |
+
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
|
296 |
+
)
|
297 |
else:
|
298 |
ff_output = self.ff(norm_hidden_states)
|
299 |
if gate_mlp is not None:
|
|
|
393 |
self.query_dim = query_dim
|
394 |
self.use_bias = bias
|
395 |
self.is_cross_attention = cross_attention_dim is not None
|
396 |
+
self.cross_attention_dim = (
|
397 |
+
cross_attention_dim if cross_attention_dim is not None else query_dim
|
398 |
+
)
|
399 |
self.upcast_attention = upcast_attention
|
400 |
self.upcast_softmax = upcast_softmax
|
401 |
self.rescale_output_factor = rescale_output_factor
|
|
|
440 |
)
|
441 |
|
442 |
if norm_num_groups is not None:
|
443 |
+
self.group_norm = nn.GroupNorm(
|
444 |
+
num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
|
445 |
+
)
|
446 |
else:
|
447 |
self.group_norm = None
|
448 |
|
449 |
if spatial_norm_dim is not None:
|
450 |
+
self.spatial_norm = SpatialNorm(
|
451 |
+
f_channels=query_dim, zq_channels=spatial_norm_dim
|
452 |
+
)
|
453 |
else:
|
454 |
self.spatial_norm = None
|
455 |
|
|
|
469 |
norm_cross_num_channels = self.cross_attention_dim
|
470 |
|
471 |
self.norm_cross = nn.GroupNorm(
|
472 |
+
num_channels=norm_cross_num_channels,
|
473 |
+
num_groups=cross_attention_norm_num_groups,
|
474 |
+
eps=1e-5,
|
475 |
+
affine=True,
|
476 |
)
|
477 |
else:
|
478 |
raise ValueError(
|
|
|
530 |
and isinstance(self.processor, torch.nn.Module)
|
531 |
and not isinstance(processor, torch.nn.Module)
|
532 |
):
|
533 |
+
logger.info(
|
534 |
+
f"You are removing possibly trained weights of {self.processor} with {processor}"
|
535 |
+
)
|
536 |
self._modules.pop("processor")
|
537 |
|
538 |
self.processor = processor
|
539 |
|
540 |
+
def get_processor(
|
541 |
+
self, return_deprecated_lora: bool = False
|
542 |
+
) -> "AttentionProcessor": # noqa: F821
|
543 |
r"""
|
544 |
Get the attention processor in use.
|
545 |
|
|
|
577 |
|
578 |
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
|
579 |
non_lora_processor_cls_name = self.processor.__class__.__name__
|
580 |
+
lora_processor_cls = getattr(
|
581 |
+
import_module(__name__), "LoRA" + non_lora_processor_cls_name
|
582 |
+
)
|
583 |
|
584 |
hidden_size = self.inner_dim
|
585 |
|
586 |
# now create a LoRA attention processor from the LoRA layers
|
587 |
+
if lora_processor_cls in [
|
588 |
+
LoRAAttnProcessor,
|
589 |
+
LoRAAttnProcessor2_0,
|
590 |
+
LoRAXFormersAttnProcessor,
|
591 |
+
]:
|
592 |
kwargs = {
|
593 |
"cross_attention_dim": self.cross_attention_dim,
|
594 |
"rank": self.to_q.lora_layer.rank,
|
|
|
610 |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
611 |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
612 |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
613 |
+
lora_processor.to_out_lora.load_state_dict(
|
614 |
+
self.to_out[0].lora_layer.state_dict()
|
615 |
+
)
|
616 |
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
|
617 |
lora_processor = lora_processor_cls(
|
618 |
hidden_size,
|
|
|
623 |
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
|
624 |
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
|
625 |
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
|
626 |
+
lora_processor.to_out_lora.load_state_dict(
|
627 |
+
self.to_out[0].lora_layer.state_dict()
|
628 |
+
)
|
629 |
|
630 |
# only save if used
|
631 |
if self.add_k_proj.lora_layer is not None:
|
632 |
+
lora_processor.add_k_proj_lora.load_state_dict(
|
633 |
+
self.add_k_proj.lora_layer.state_dict()
|
634 |
+
)
|
635 |
+
lora_processor.add_v_proj_lora.load_state_dict(
|
636 |
+
self.add_v_proj.lora_layer.state_dict()
|
637 |
+
)
|
638 |
else:
|
639 |
lora_processor.add_k_proj_lora = None
|
640 |
lora_processor.add_v_proj_lora = None
|
|
|
671 |
# here we simply pass along all tensors to the selected processor class
|
672 |
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
673 |
|
674 |
+
attn_parameters = set(
|
675 |
+
inspect.signature(self.processor.__call__).parameters.keys()
|
676 |
+
)
|
677 |
+
unused_kwargs = [
|
678 |
+
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
|
679 |
+
]
|
680 |
if len(unused_kwargs) > 0:
|
681 |
logger.warning(
|
682 |
f"cross_attention_kwargs {unused_kwargs} are not expected by"
|
683 |
f" {self.processor.__class__.__name__} and will be ignored."
|
684 |
)
|
685 |
+
cross_attention_kwargs = {
|
686 |
+
k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
|
687 |
+
}
|
688 |
|
689 |
return self.processor(
|
690 |
self,
|
|
|
709 |
head_size = self.heads
|
710 |
batch_size, seq_len, dim = tensor.shape
|
711 |
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
712 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(
|
713 |
+
batch_size // head_size, seq_len, dim * head_size
|
714 |
+
)
|
715 |
return tensor
|
716 |
|
717 |
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
|
|
734 |
extra_dim = 1
|
735 |
else:
|
736 |
batch_size, extra_dim, seq_len, dim = tensor.shape
|
737 |
+
tensor = tensor.reshape(
|
738 |
+
batch_size, seq_len * extra_dim, head_size, dim // head_size
|
739 |
+
)
|
740 |
tensor = tensor.permute(0, 2, 1, 3)
|
741 |
|
742 |
if out_dim == 3:
|
743 |
+
tensor = tensor.reshape(
|
744 |
+
batch_size * head_size, seq_len * extra_dim, dim // head_size
|
745 |
+
)
|
746 |
|
747 |
return tensor
|
748 |
|
749 |
def get_attention_scores(
|
750 |
+
self,
|
751 |
+
query: torch.Tensor,
|
752 |
+
key: torch.Tensor,
|
753 |
+
attention_mask: torch.Tensor = None,
|
754 |
) -> torch.Tensor:
|
755 |
r"""
|
756 |
Compute the attention scores.
|
|
|
770 |
|
771 |
if attention_mask is None:
|
772 |
baddbmm_input = torch.empty(
|
773 |
+
query.shape[0],
|
774 |
+
query.shape[1],
|
775 |
+
key.shape[1],
|
776 |
+
dtype=query.dtype,
|
777 |
+
device=query.device,
|
778 |
)
|
779 |
beta = 0
|
780 |
else:
|
|
|
801 |
return attention_probs
|
802 |
|
803 |
def prepare_attention_mask(
|
804 |
+
self,
|
805 |
+
attention_mask: torch.Tensor,
|
806 |
+
target_length: int,
|
807 |
+
batch_size: int,
|
808 |
+
out_dim: int = 3,
|
809 |
) -> torch.Tensor:
|
810 |
r"""
|
811 |
Prepare the attention mask for the attention computation.
|
|
|
832 |
if attention_mask.device.type == "mps":
|
833 |
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
834 |
# Instead, we can manually construct the padding tensor.
|
835 |
+
padding_shape = (
|
836 |
+
attention_mask.shape[0],
|
837 |
+
attention_mask.shape[1],
|
838 |
+
target_length,
|
839 |
+
)
|
840 |
+
padding = torch.zeros(
|
841 |
+
padding_shape,
|
842 |
+
dtype=attention_mask.dtype,
|
843 |
+
device=attention_mask.device,
|
844 |
+
)
|
845 |
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
846 |
else:
|
847 |
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
|
|
859 |
|
860 |
return attention_mask
|
861 |
|
862 |
+
def norm_encoder_hidden_states(
|
863 |
+
self, encoder_hidden_states: torch.Tensor
|
864 |
+
) -> torch.Tensor:
|
865 |
r"""
|
866 |
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
867 |
`Attention` class.
|
|
|
872 |
Returns:
|
873 |
`torch.Tensor`: The normalized encoder hidden states.
|
874 |
"""
|
875 |
+
assert (
|
876 |
+
self.norm_cross is not None
|
877 |
+
), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
878 |
|
879 |
if isinstance(self.norm_cross, nn.LayerNorm):
|
880 |
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
|
|
941 |
|
942 |
if input_ndim == 4:
|
943 |
batch_size, channel, height, width = hidden_states.shape
|
944 |
+
hidden_states = hidden_states.view(
|
945 |
+
batch_size, channel, height * width
|
946 |
+
).transpose(1, 2)
|
947 |
|
948 |
batch_size, sequence_length, _ = (
|
949 |
+
hidden_states.shape
|
950 |
+
if encoder_hidden_states is None
|
951 |
+
else encoder_hidden_states.shape
|
952 |
)
|
953 |
|
954 |
if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
|
955 |
+
attention_mask = attn.prepare_attention_mask(
|
956 |
+
attention_mask, sequence_length, batch_size
|
957 |
+
)
|
958 |
# scaled_dot_product_attention expects attention_mask shape to be
|
959 |
# (batch, heads, source_length, target_length)
|
960 |
+
attention_mask = attention_mask.view(
|
961 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
962 |
+
)
|
963 |
|
964 |
if attn.group_norm is not None:
|
965 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
966 |
+
1, 2
|
967 |
+
)
|
968 |
|
969 |
query = attn.to_q(hidden_states)
|
970 |
query = attn.q_norm(query)
|
971 |
|
972 |
if encoder_hidden_states is not None:
|
973 |
if attn.norm_cross:
|
974 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
975 |
+
encoder_hidden_states
|
976 |
+
)
|
977 |
key = attn.to_k(encoder_hidden_states)
|
978 |
key = attn.k_norm(key)
|
979 |
else: # if no context provided do self-attention
|
|
|
997 |
|
998 |
if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
|
999 |
q_segment_indexes = None
|
1000 |
+
if (
|
1001 |
+
attention_mask is not None
|
1002 |
+
): # if mask is required need to tune both segmenIds fields
|
1003 |
# attention_mask = torch.squeeze(attention_mask).to(torch.float32)
|
1004 |
attention_mask = attention_mask.to(torch.float32)
|
1005 |
+
q_segment_indexes = torch.ones(
|
1006 |
+
batch_size, query.shape[2], device=query.device, dtype=torch.float32
|
1007 |
+
)
|
1008 |
assert (
|
1009 |
attention_mask.shape[1] == key.shape[2]
|
1010 |
), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
|
|
|
1027 |
)
|
1028 |
else:
|
1029 |
hidden_states = F.scaled_dot_product_attention(
|
1030 |
+
query,
|
1031 |
+
key,
|
1032 |
+
value,
|
1033 |
+
attn_mask=attention_mask,
|
1034 |
+
dropout_p=0.0,
|
1035 |
+
is_causal=False,
|
1036 |
)
|
1037 |
|
1038 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
1039 |
+
batch_size, -1, attn.heads * head_dim
|
1040 |
+
)
|
1041 |
hidden_states = hidden_states.to(query.dtype)
|
1042 |
|
1043 |
# linear proj
|
|
|
1046 |
hidden_states = attn.to_out[1](hidden_states)
|
1047 |
|
1048 |
if input_ndim == 4:
|
1049 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
1050 |
+
batch_size, channel, height, width
|
1051 |
+
)
|
1052 |
|
1053 |
if attn.residual_connection:
|
1054 |
hidden_states = hidden_states + residual
|
|
|
1086 |
|
1087 |
if input_ndim == 4:
|
1088 |
batch_size, channel, height, width = hidden_states.shape
|
1089 |
+
hidden_states = hidden_states.view(
|
1090 |
+
batch_size, channel, height * width
|
1091 |
+
).transpose(1, 2)
|
1092 |
|
1093 |
batch_size, sequence_length, _ = (
|
1094 |
+
hidden_states.shape
|
1095 |
+
if encoder_hidden_states is None
|
1096 |
+
else encoder_hidden_states.shape
|
1097 |
+
)
|
1098 |
+
attention_mask = attn.prepare_attention_mask(
|
1099 |
+
attention_mask, sequence_length, batch_size
|
1100 |
)
|
|
|
1101 |
|
1102 |
if attn.group_norm is not None:
|
1103 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
1104 |
+
1, 2
|
1105 |
+
)
|
1106 |
|
1107 |
query = attn.to_q(hidden_states)
|
1108 |
|
1109 |
if encoder_hidden_states is None:
|
1110 |
encoder_hidden_states = hidden_states
|
1111 |
elif attn.norm_cross:
|
1112 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
1113 |
+
encoder_hidden_states
|
1114 |
+
)
|
1115 |
|
1116 |
key = attn.to_k(encoder_hidden_states)
|
1117 |
value = attn.to_v(encoder_hidden_states)
|
|
|
1133 |
hidden_states = attn.to_out[1](hidden_states)
|
1134 |
|
1135 |
if input_ndim == 4:
|
1136 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
1137 |
+
batch_size, channel, height, width
|
1138 |
+
)
|
1139 |
|
1140 |
if attn.residual_connection:
|
1141 |
hidden_states = hidden_states + residual
|
xora/models/transformers/embeddings.py
CHANGED
@@ -26,7 +26,9 @@ def get_timestep_embedding(
|
|
26 |
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
27 |
|
28 |
half_dim = embedding_dim // 2
|
29 |
-
exponent = -math.log(max_period) * torch.arange(
|
|
|
|
|
30 |
exponent = exponent / (half_dim - downscale_freq_shift)
|
31 |
|
32 |
emb = torch.exp(exponent)
|
@@ -113,7 +115,9 @@ class SinusoidalPositionalEmbedding(nn.Module):
|
|
113 |
def __init__(self, embed_dim: int, max_seq_length: int = 32):
|
114 |
super().__init__()
|
115 |
position = torch.arange(max_seq_length).unsqueeze(1)
|
116 |
-
div_term = torch.exp(
|
|
|
|
|
117 |
pe = torch.zeros(1, max_seq_length, embed_dim)
|
118 |
pe[0, :, 0::2] = torch.sin(position * div_term)
|
119 |
pe[0, :, 1::2] = torch.cos(position * div_term)
|
|
|
26 |
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
27 |
|
28 |
half_dim = embedding_dim // 2
|
29 |
+
exponent = -math.log(max_period) * torch.arange(
|
30 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
31 |
+
)
|
32 |
exponent = exponent / (half_dim - downscale_freq_shift)
|
33 |
|
34 |
emb = torch.exp(exponent)
|
|
|
115 |
def __init__(self, embed_dim: int, max_seq_length: int = 32):
|
116 |
super().__init__()
|
117 |
position = torch.arange(max_seq_length).unsqueeze(1)
|
118 |
+
div_term = torch.exp(
|
119 |
+
torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
|
120 |
+
)
|
121 |
pe = torch.zeros(1, max_seq_length, embed_dim)
|
122 |
pe[0, :, 0::2] = torch.sin(position * div_term)
|
123 |
pe[0, :, 1::2] = torch.cos(position * div_term)
|
xora/models/transformers/symmetric_patchifier.py
CHANGED
@@ -15,12 +15,19 @@ class Patchifier(ConfigMixin, ABC):
|
|
15 |
self._patch_size = (1, patch_size, patch_size)
|
16 |
|
17 |
@abstractmethod
|
18 |
-
def patchify(
|
|
|
|
|
19 |
pass
|
20 |
|
21 |
@abstractmethod
|
22 |
def unpatchify(
|
23 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
24 |
) -> Tuple[Tensor, Tensor]:
|
25 |
pass
|
26 |
|
@@ -28,7 +35,9 @@ class Patchifier(ConfigMixin, ABC):
|
|
28 |
def patch_size(self):
|
29 |
return self._patch_size
|
30 |
|
31 |
-
def get_grid(
|
|
|
|
|
32 |
f = orig_num_frames // self._patch_size[0]
|
33 |
h = orig_height // self._patch_size[1]
|
34 |
w = orig_width // self._patch_size[2]
|
@@ -64,6 +73,7 @@ def pixart_alpha_patchify(
|
|
64 |
)
|
65 |
return latents
|
66 |
|
|
|
67 |
class SymmetricPatchifier(Patchifier):
|
68 |
def patchify(
|
69 |
self,
|
@@ -72,7 +82,12 @@ class SymmetricPatchifier(Patchifier):
|
|
72 |
return pixart_alpha_patchify(latents, self._patch_size)
|
73 |
|
74 |
def unpatchify(
|
75 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
76 |
) -> Tuple[Tensor, Tensor]:
|
77 |
output_height = output_height // self._patch_size[1]
|
78 |
output_width = output_width // self._patch_size[2]
|
|
|
15 |
self._patch_size = (1, patch_size, patch_size)
|
16 |
|
17 |
@abstractmethod
|
18 |
+
def patchify(
|
19 |
+
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
|
20 |
+
) -> Tuple[Tensor, Tensor]:
|
21 |
pass
|
22 |
|
23 |
@abstractmethod
|
24 |
def unpatchify(
|
25 |
+
self,
|
26 |
+
latents: Tensor,
|
27 |
+
output_height: int,
|
28 |
+
output_width: int,
|
29 |
+
output_num_frames: int,
|
30 |
+
out_channels: int,
|
31 |
) -> Tuple[Tensor, Tensor]:
|
32 |
pass
|
33 |
|
|
|
35 |
def patch_size(self):
|
36 |
return self._patch_size
|
37 |
|
38 |
+
def get_grid(
|
39 |
+
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
|
40 |
+
):
|
41 |
f = orig_num_frames // self._patch_size[0]
|
42 |
h = orig_height // self._patch_size[1]
|
43 |
w = orig_width // self._patch_size[2]
|
|
|
73 |
)
|
74 |
return latents
|
75 |
|
76 |
+
|
77 |
class SymmetricPatchifier(Patchifier):
|
78 |
def patchify(
|
79 |
self,
|
|
|
82 |
return pixart_alpha_patchify(latents, self._patch_size)
|
83 |
|
84 |
def unpatchify(
|
85 |
+
self,
|
86 |
+
latents: Tensor,
|
87 |
+
output_height: int,
|
88 |
+
output_width: int,
|
89 |
+
output_num_frames: int,
|
90 |
+
out_channels: int,
|
91 |
) -> Tuple[Tensor, Tensor]:
|
92 |
output_height = output_height // self._patch_size[1]
|
93 |
output_width = output_width // self._patch_size[2]
|
xora/models/transformers/transformer3d.py
CHANGED
@@ -17,6 +17,7 @@ from xora.models.transformers.embeddings import get_3d_sincos_pos_embed
|
|
17 |
|
18 |
logger = logging.get_logger(__name__)
|
19 |
|
|
|
20 |
@dataclass
|
21 |
class Transformer3DModelOutput(BaseOutput):
|
22 |
"""
|
@@ -68,7 +69,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
68 |
timestep_scale_multiplier: Optional[float] = None,
|
69 |
):
|
70 |
super().__init__()
|
71 |
-
self.use_tpu_flash_attention =
|
|
|
|
|
72 |
self.use_linear_projection = use_linear_projection
|
73 |
self.num_attention_heads = num_attention_heads
|
74 |
self.attention_head_dim = attention_head_dim
|
@@ -86,7 +89,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
86 |
self.timestep_scale_multiplier = timestep_scale_multiplier
|
87 |
|
88 |
if self.positional_embedding_type == "absolute":
|
89 |
-
embed_dim_3d =
|
|
|
|
|
90 |
if self.project_to_2d_pos:
|
91 |
self.to_2d_proj = torch.nn.Linear(embed_dim_3d, inner_dim, bias=False)
|
92 |
self._init_to_2d_proj_weights(self.to_2d_proj)
|
@@ -131,18 +136,24 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
131 |
# 4. Define output layers
|
132 |
self.out_channels = in_channels if out_channels is None else out_channels
|
133 |
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
134 |
-
self.scale_shift_table = nn.Parameter(
|
|
|
|
|
135 |
self.proj_out = nn.Linear(inner_dim, self.out_channels)
|
136 |
|
137 |
# 5. PixArt-Alpha blocks.
|
138 |
-
self.adaln_single = AdaLayerNormSingle(
|
|
|
|
|
139 |
if adaptive_norm == "single_scale":
|
140 |
# Use 4 channels instead of the 6 for the PixArt-Alpha scale + shift ada norm.
|
141 |
self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
|
142 |
|
143 |
self.caption_projection = None
|
144 |
if caption_channels is not None:
|
145 |
-
self.caption_projection = PixArtAlphaTextProjection(
|
|
|
|
|
146 |
|
147 |
self.gradient_checkpointing = False
|
148 |
|
@@ -169,16 +180,32 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
169 |
self.apply(_basic_init)
|
170 |
|
171 |
# Initialize timestep embedding MLP:
|
172 |
-
nn.init.normal_(
|
173 |
-
|
|
|
|
|
|
|
|
|
174 |
nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
|
175 |
|
176 |
if hasattr(self.adaln_single.emb, "resolution_embedder"):
|
177 |
-
nn.init.normal_(
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
|
180 |
-
nn.init.normal_(
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
# Initialize caption embedding MLP:
|
184 |
nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
|
@@ -220,7 +247,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
220 |
|
221 |
def get_fractional_positions(self, indices_grid):
|
222 |
fractional_positions = torch.stack(
|
223 |
-
[
|
|
|
|
|
|
|
|
|
224 |
)
|
225 |
return fractional_positions
|
226 |
|
@@ -236,7 +267,13 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
236 |
device = fractional_positions.device
|
237 |
if spacing == "exp":
|
238 |
indices = theta ** (
|
239 |
-
torch.linspace(
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
)
|
241 |
indices = indices.to(dtype=dtype)
|
242 |
elif spacing == "exp_2":
|
@@ -245,14 +282,24 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
245 |
elif spacing == "linear":
|
246 |
indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
|
247 |
elif spacing == "sqrt":
|
248 |
-
indices = torch.linspace(
|
|
|
|
|
249 |
|
250 |
indices = indices * math.pi / 2
|
251 |
|
252 |
if spacing == "exp_2":
|
253 |
-
freqs = (
|
|
|
|
|
|
|
|
|
254 |
else:
|
255 |
-
freqs = (
|
|
|
|
|
|
|
|
|
256 |
|
257 |
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
258 |
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
@@ -336,7 +383,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
336 |
|
337 |
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
338 |
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
339 |
-
encoder_attention_mask = (
|
|
|
|
|
340 |
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
341 |
|
342 |
# 1. Input
|
@@ -346,7 +395,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
346 |
timestep = self.timestep_scale_multiplier * timestep
|
347 |
|
348 |
if self.positional_embedding_type == "absolute":
|
349 |
-
pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(
|
|
|
|
|
350 |
if self.project_to_2d_pos:
|
351 |
pos_embed = self.to_2d_proj(pos_embed_3d)
|
352 |
hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
|
@@ -363,13 +414,17 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
363 |
)
|
364 |
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
365 |
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
366 |
-
embedded_timestep = embedded_timestep.view(
|
|
|
|
|
367 |
|
368 |
# 2. Blocks
|
369 |
if self.caption_projection is not None:
|
370 |
batch_size = hidden_states.shape[0]
|
371 |
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
372 |
-
encoder_hidden_states = encoder_hidden_states.view(
|
|
|
|
|
373 |
|
374 |
for block in self.transformer_blocks:
|
375 |
if self.training and self.gradient_checkpointing:
|
@@ -383,7 +438,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
383 |
|
384 |
return custom_forward
|
385 |
|
386 |
-
ckpt_kwargs: Dict[str, Any] =
|
|
|
|
|
387 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
388 |
create_custom_forward(block),
|
389 |
hidden_states,
|
@@ -409,7 +466,9 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
409 |
)
|
410 |
|
411 |
# 3. Output
|
412 |
-
scale_shift_values =
|
|
|
|
|
413 |
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
414 |
hidden_states = self.norm_out(hidden_states)
|
415 |
# Modulation
|
@@ -422,7 +481,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
422 |
|
423 |
def get_absolute_pos_embed(self, grid):
|
424 |
grid_np = grid[0].cpu().numpy()
|
425 |
-
embed_dim_3d =
|
|
|
|
|
|
|
|
|
426 |
pos_embed = get_3d_sincos_pos_embed( # (f h w)
|
427 |
embed_dim_3d,
|
428 |
grid_np,
|
|
|
17 |
|
18 |
logger = logging.get_logger(__name__)
|
19 |
|
20 |
+
|
21 |
@dataclass
|
22 |
class Transformer3DModelOutput(BaseOutput):
|
23 |
"""
|
|
|
69 |
timestep_scale_multiplier: Optional[float] = None,
|
70 |
):
|
71 |
super().__init__()
|
72 |
+
self.use_tpu_flash_attention = (
|
73 |
+
use_tpu_flash_attention # FIXME: push config down to the attention modules
|
74 |
+
)
|
75 |
self.use_linear_projection = use_linear_projection
|
76 |
self.num_attention_heads = num_attention_heads
|
77 |
self.attention_head_dim = attention_head_dim
|
|
|
89 |
self.timestep_scale_multiplier = timestep_scale_multiplier
|
90 |
|
91 |
if self.positional_embedding_type == "absolute":
|
92 |
+
embed_dim_3d = (
|
93 |
+
math.ceil((inner_dim / 2) * 3) if project_to_2d_pos else inner_dim
|
94 |
+
)
|
95 |
if self.project_to_2d_pos:
|
96 |
self.to_2d_proj = torch.nn.Linear(embed_dim_3d, inner_dim, bias=False)
|
97 |
self._init_to_2d_proj_weights(self.to_2d_proj)
|
|
|
136 |
# 4. Define output layers
|
137 |
self.out_channels = in_channels if out_channels is None else out_channels
|
138 |
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
139 |
+
self.scale_shift_table = nn.Parameter(
|
140 |
+
torch.randn(2, inner_dim) / inner_dim**0.5
|
141 |
+
)
|
142 |
self.proj_out = nn.Linear(inner_dim, self.out_channels)
|
143 |
|
144 |
# 5. PixArt-Alpha blocks.
|
145 |
+
self.adaln_single = AdaLayerNormSingle(
|
146 |
+
inner_dim, use_additional_conditions=False
|
147 |
+
)
|
148 |
if adaptive_norm == "single_scale":
|
149 |
# Use 4 channels instead of the 6 for the PixArt-Alpha scale + shift ada norm.
|
150 |
self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
|
151 |
|
152 |
self.caption_projection = None
|
153 |
if caption_channels is not None:
|
154 |
+
self.caption_projection = PixArtAlphaTextProjection(
|
155 |
+
in_features=caption_channels, hidden_size=inner_dim
|
156 |
+
)
|
157 |
|
158 |
self.gradient_checkpointing = False
|
159 |
|
|
|
180 |
self.apply(_basic_init)
|
181 |
|
182 |
# Initialize timestep embedding MLP:
|
183 |
+
nn.init.normal_(
|
184 |
+
self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std
|
185 |
+
)
|
186 |
+
nn.init.normal_(
|
187 |
+
self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std
|
188 |
+
)
|
189 |
nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std)
|
190 |
|
191 |
if hasattr(self.adaln_single.emb, "resolution_embedder"):
|
192 |
+
nn.init.normal_(
|
193 |
+
self.adaln_single.emb.resolution_embedder.linear_1.weight,
|
194 |
+
std=embedding_std,
|
195 |
+
)
|
196 |
+
nn.init.normal_(
|
197 |
+
self.adaln_single.emb.resolution_embedder.linear_2.weight,
|
198 |
+
std=embedding_std,
|
199 |
+
)
|
200 |
if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
|
201 |
+
nn.init.normal_(
|
202 |
+
self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight,
|
203 |
+
std=embedding_std,
|
204 |
+
)
|
205 |
+
nn.init.normal_(
|
206 |
+
self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight,
|
207 |
+
std=embedding_std,
|
208 |
+
)
|
209 |
|
210 |
# Initialize caption embedding MLP:
|
211 |
nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std)
|
|
|
247 |
|
248 |
def get_fractional_positions(self, indices_grid):
|
249 |
fractional_positions = torch.stack(
|
250 |
+
[
|
251 |
+
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
252 |
+
for i in range(3)
|
253 |
+
],
|
254 |
+
dim=-1,
|
255 |
)
|
256 |
return fractional_positions
|
257 |
|
|
|
267 |
device = fractional_positions.device
|
268 |
if spacing == "exp":
|
269 |
indices = theta ** (
|
270 |
+
torch.linspace(
|
271 |
+
math.log(start, theta),
|
272 |
+
math.log(end, theta),
|
273 |
+
dim // 6,
|
274 |
+
device=device,
|
275 |
+
dtype=dtype,
|
276 |
+
)
|
277 |
)
|
278 |
indices = indices.to(dtype=dtype)
|
279 |
elif spacing == "exp_2":
|
|
|
282 |
elif spacing == "linear":
|
283 |
indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
|
284 |
elif spacing == "sqrt":
|
285 |
+
indices = torch.linspace(
|
286 |
+
start**2, end**2, dim // 6, device=device, dtype=dtype
|
287 |
+
).sqrt()
|
288 |
|
289 |
indices = indices * math.pi / 2
|
290 |
|
291 |
if spacing == "exp_2":
|
292 |
+
freqs = (
|
293 |
+
(indices * fractional_positions.unsqueeze(-1))
|
294 |
+
.transpose(-1, -2)
|
295 |
+
.flatten(2)
|
296 |
+
)
|
297 |
else:
|
298 |
+
freqs = (
|
299 |
+
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
300 |
+
.transpose(-1, -2)
|
301 |
+
.flatten(2)
|
302 |
+
)
|
303 |
|
304 |
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
305 |
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
|
|
383 |
|
384 |
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
385 |
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
386 |
+
encoder_attention_mask = (
|
387 |
+
1 - encoder_attention_mask.to(hidden_states.dtype)
|
388 |
+
) * -10000.0
|
389 |
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
390 |
|
391 |
# 1. Input
|
|
|
395 |
timestep = self.timestep_scale_multiplier * timestep
|
396 |
|
397 |
if self.positional_embedding_type == "absolute":
|
398 |
+
pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(
|
399 |
+
hidden_states.device
|
400 |
+
)
|
401 |
if self.project_to_2d_pos:
|
402 |
pos_embed = self.to_2d_proj(pos_embed_3d)
|
403 |
hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
|
|
|
414 |
)
|
415 |
# Second dimension is 1 or number of tokens (if timestep_per_token)
|
416 |
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
417 |
+
embedded_timestep = embedded_timestep.view(
|
418 |
+
batch_size, -1, embedded_timestep.shape[-1]
|
419 |
+
)
|
420 |
|
421 |
# 2. Blocks
|
422 |
if self.caption_projection is not None:
|
423 |
batch_size = hidden_states.shape[0]
|
424 |
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
425 |
+
encoder_hidden_states = encoder_hidden_states.view(
|
426 |
+
batch_size, -1, hidden_states.shape[-1]
|
427 |
+
)
|
428 |
|
429 |
for block in self.transformer_blocks:
|
430 |
if self.training and self.gradient_checkpointing:
|
|
|
438 |
|
439 |
return custom_forward
|
440 |
|
441 |
+
ckpt_kwargs: Dict[str, Any] = (
|
442 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
443 |
+
)
|
444 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
445 |
create_custom_forward(block),
|
446 |
hidden_states,
|
|
|
466 |
)
|
467 |
|
468 |
# 3. Output
|
469 |
+
scale_shift_values = (
|
470 |
+
self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
|
471 |
+
)
|
472 |
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
473 |
hidden_states = self.norm_out(hidden_states)
|
474 |
# Modulation
|
|
|
481 |
|
482 |
def get_absolute_pos_embed(self, grid):
|
483 |
grid_np = grid[0].cpu().numpy()
|
484 |
+
embed_dim_3d = (
|
485 |
+
math.ceil((self.inner_dim / 2) * 3)
|
486 |
+
if self.project_to_2d_pos
|
487 |
+
else self.inner_dim
|
488 |
+
)
|
489 |
pos_embed = get_3d_sincos_pos_embed( # (f h w)
|
490 |
embed_dim_3d,
|
491 |
grid_np,
|
xora/pipelines/pipeline_video_pixart_alpha.py
CHANGED
@@ -5,12 +5,10 @@ import math
|
|
5 |
import re
|
6 |
import urllib.parse as ul
|
7 |
from typing import Callable, Dict, List, Optional, Tuple, Union
|
8 |
-
from abc import ABC, abstractmethod
|
9 |
|
10 |
|
11 |
import torch
|
12 |
import torch.nn.functional as F
|
13 |
-
from torch import Tensor
|
14 |
from diffusers.image_processor import VaeImageProcessor
|
15 |
from diffusers.models import AutoencoderKL
|
16 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
@@ -29,7 +27,11 @@ from transformers import T5EncoderModel, T5Tokenizer
|
|
29 |
|
30 |
from xora.models.transformers.transformer3d import Transformer3DModel
|
31 |
from xora.models.transformers.symmetric_patchifier import Patchifier
|
32 |
-
from xora.models.autoencoders.vae_encode import
|
|
|
|
|
|
|
|
|
33 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
34 |
from xora.schedulers.rf import TimestepShifter
|
35 |
from xora.utils.conditioning_method import ConditioningMethod
|
@@ -161,7 +163,9 @@ def retrieve_timesteps(
|
|
161 |
second element is the number of inference steps.
|
162 |
"""
|
163 |
if timesteps is not None:
|
164 |
-
accepts_timesteps = "timesteps" in set(
|
|
|
|
|
165 |
if not accepts_timesteps:
|
166 |
raise ValueError(
|
167 |
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
@@ -238,7 +242,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
238 |
patchifier=patchifier,
|
239 |
)
|
240 |
|
241 |
-
self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
|
|
|
|
|
242 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
243 |
|
244 |
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
@@ -320,12 +326,16 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
320 |
return_tensors="pt",
|
321 |
)
|
322 |
text_input_ids = text_inputs.input_ids
|
323 |
-
untruncated_ids = self.tokenizer(
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
|
|
|
|
|
|
|
|
329 |
logger.warning(
|
330 |
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
331 |
f" {max_length} tokens: {removed_text}"
|
@@ -334,7 +344,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
334 |
prompt_attention_mask = text_inputs.attention_mask
|
335 |
prompt_attention_mask = prompt_attention_mask.to(device)
|
336 |
|
337 |
-
prompt_embeds = self.text_encoder(
|
|
|
|
|
338 |
prompt_embeds = prompt_embeds[0]
|
339 |
|
340 |
if self.text_encoder is not None:
|
@@ -349,14 +361,20 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
349 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
350 |
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
351 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
352 |
-
prompt_embeds = prompt_embeds.view(
|
|
|
|
|
353 |
prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
|
354 |
-
prompt_attention_mask = prompt_attention_mask.view(
|
|
|
|
|
355 |
|
356 |
# get unconditional embeddings for classifier free guidance
|
357 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
358 |
uncond_tokens = [negative_prompt] * batch_size
|
359 |
-
uncond_tokens = self._text_preprocessing(
|
|
|
|
|
360 |
max_length = prompt_embeds.shape[1]
|
361 |
uncond_input = self.tokenizer(
|
362 |
uncond_tokens,
|
@@ -371,7 +389,8 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
371 |
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
372 |
|
373 |
negative_prompt_embeds = self.text_encoder(
|
374 |
-
uncond_input.input_ids.to(device),
|
|
|
375 |
)
|
376 |
negative_prompt_embeds = negative_prompt_embeds[0]
|
377 |
|
@@ -379,18 +398,33 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
379 |
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
380 |
seq_len = negative_prompt_embeds.shape[1]
|
381 |
|
382 |
-
negative_prompt_embeds = negative_prompt_embeds.to(
|
|
|
|
|
383 |
|
384 |
-
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
385 |
-
|
|
|
|
|
|
|
|
|
386 |
|
387 |
-
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
|
388 |
-
|
|
|
|
|
|
|
|
|
389 |
else:
|
390 |
negative_prompt_embeds = None
|
391 |
negative_prompt_attention_mask = None
|
392 |
|
393 |
-
return
|
|
|
|
|
|
|
|
|
|
|
394 |
|
395 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
396 |
def prepare_extra_step_kwargs(self, generator, eta):
|
@@ -399,13 +433,17 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
399 |
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
400 |
# and should be between [0, 1]
|
401 |
|
402 |
-
accepts_eta = "eta" in set(
|
|
|
|
|
403 |
extra_step_kwargs = {}
|
404 |
if accepts_eta:
|
405 |
extra_step_kwargs["eta"] = eta
|
406 |
|
407 |
# check if the scheduler accepts generator
|
408 |
-
accepts_generator = "generator" in set(
|
|
|
|
|
409 |
if accepts_generator:
|
410 |
extra_step_kwargs["generator"] = generator
|
411 |
return extra_step_kwargs
|
@@ -422,7 +460,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
422 |
negative_prompt_attention_mask=None,
|
423 |
):
|
424 |
if height % 8 != 0 or width % 8 != 0:
|
425 |
-
raise ValueError(
|
|
|
|
|
426 |
|
427 |
if prompt is not None and prompt_embeds is not None:
|
428 |
raise ValueError(
|
@@ -433,8 +473,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
433 |
raise ValueError(
|
434 |
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
435 |
)
|
436 |
-
elif prompt is not None and (
|
437 |
-
|
|
|
|
|
|
|
|
|
438 |
|
439 |
if prompt is not None and negative_prompt_embeds is not None:
|
440 |
raise ValueError(
|
@@ -449,10 +493,17 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
449 |
)
|
450 |
|
451 |
if prompt_embeds is not None and prompt_attention_mask is None:
|
452 |
-
raise ValueError(
|
|
|
|
|
453 |
|
454 |
-
if
|
455 |
-
|
|
|
|
|
|
|
|
|
|
|
456 |
|
457 |
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
458 |
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
@@ -471,12 +522,16 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
471 |
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
472 |
def _text_preprocessing(self, text, clean_caption=False):
|
473 |
if clean_caption and not is_bs4_available():
|
474 |
-
logger.warn(
|
|
|
|
|
475 |
logger.warn("Setting `clean_caption` to False...")
|
476 |
clean_caption = False
|
477 |
|
478 |
if clean_caption and not is_ftfy_available():
|
479 |
-
logger.warn(
|
|
|
|
|
480 |
logger.warn("Setting `clean_caption` to False...")
|
481 |
clean_caption = False
|
482 |
|
@@ -564,13 +619,17 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
564 |
# "123456.."
|
565 |
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
566 |
# filenames:
|
567 |
-
caption = re.sub(
|
|
|
|
|
568 |
|
569 |
#
|
570 |
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
571 |
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
572 |
|
573 |
-
caption = re.sub(
|
|
|
|
|
574 |
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
575 |
|
576 |
# this-is-my-cute-cat / this_is_my_cute_cat
|
@@ -588,10 +647,14 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
588 |
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
589 |
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
590 |
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
591 |
-
caption = re.sub(
|
|
|
|
|
592 |
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
593 |
|
594 |
-
caption = re.sub(
|
|
|
|
|
595 |
|
596 |
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
597 |
|
@@ -610,7 +673,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
610 |
|
611 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
612 |
def prepare_latents(
|
613 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
614 |
):
|
615 |
shape = (
|
616 |
batch_size,
|
@@ -625,10 +696,14 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
625 |
)
|
626 |
|
627 |
if latents is None:
|
628 |
-
latents = randn_tensor(
|
|
|
|
|
629 |
elif latents_mask is not None:
|
630 |
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
631 |
-
latents = latents * latents_mask[..., None] + noise * (
|
|
|
|
|
632 |
else:
|
633 |
latents = latents.to(device)
|
634 |
|
@@ -637,7 +712,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
637 |
return latents
|
638 |
|
639 |
@staticmethod
|
640 |
-
def classify_height_width_bin(
|
|
|
|
|
641 |
"""Returns binned height and width."""
|
642 |
ar = float(height / width)
|
643 |
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
@@ -645,7 +722,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
645 |
return int(default_hw[0]), int(default_hw[1])
|
646 |
|
647 |
@staticmethod
|
648 |
-
def resize_and_crop_tensor(
|
|
|
|
|
649 |
n_frames, orig_height, orig_width = samples.shape[-3:]
|
650 |
|
651 |
# Check if resizing is needed
|
@@ -656,7 +735,12 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
656 |
|
657 |
# Resize
|
658 |
samples = rearrange(samples, "b c n h w -> (b n) c h w")
|
659 |
-
samples = F.interpolate(
|
|
|
|
|
|
|
|
|
|
|
660 |
samples = rearrange(samples, "(b n) c h w -> b c n h w", n=n_frames)
|
661 |
|
662 |
# Center Crop
|
@@ -821,14 +905,21 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
821 |
)
|
822 |
if do_classifier_free_guidance:
|
823 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
824 |
-
prompt_attention_mask = torch.cat(
|
|
|
|
|
825 |
|
826 |
# 3b. Encode and prepare conditioning data
|
827 |
self.video_scale_factor = self.video_scale_factor if is_video else 1
|
828 |
conditioning_method = kwargs.get("conditioning_method", None)
|
829 |
vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
|
830 |
init_latents, conditioning_mask = self.prepare_conditioning(
|
831 |
-
media_items,
|
|
|
|
|
|
|
|
|
|
|
832 |
)
|
833 |
|
834 |
# 4. Prepare latents.
|
@@ -851,29 +942,46 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
851 |
)
|
852 |
if conditioning_mask is not None and is_video:
|
853 |
assert num_images_per_prompt == 1
|
854 |
-
conditioning_mask =
|
|
|
|
|
|
|
|
|
855 |
|
856 |
# 5. Prepare timesteps
|
857 |
retrieve_timesteps_kwargs = {}
|
858 |
if isinstance(self.scheduler, TimestepShifter):
|
859 |
retrieve_timesteps_kwargs["samples"] = latents
|
860 |
timesteps, num_inference_steps = retrieve_timesteps(
|
861 |
-
self.scheduler,
|
|
|
|
|
|
|
|
|
862 |
)
|
863 |
|
864 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
865 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
866 |
|
867 |
# 7. Denoising loop
|
868 |
-
num_warmup_steps = max(
|
|
|
|
|
869 |
|
870 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
871 |
for i, t in enumerate(timesteps):
|
872 |
-
latent_model_input =
|
873 |
-
|
|
|
|
|
|
|
|
|
874 |
|
875 |
latent_frame_rates = (
|
876 |
-
torch.ones(
|
|
|
|
|
|
|
877 |
)
|
878 |
|
879 |
current_timestep = t
|
@@ -885,13 +993,25 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
885 |
dtype = torch.float32 if is_mps else torch.float64
|
886 |
else:
|
887 |
dtype = torch.int32 if is_mps else torch.int64
|
888 |
-
current_timestep = torch.tensor(
|
|
|
|
|
|
|
|
|
889 |
elif len(current_timestep.shape) == 0:
|
890 |
-
current_timestep = current_timestep[None].to(
|
|
|
|
|
891 |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
892 |
-
current_timestep = current_timestep.expand(
|
|
|
|
|
893 |
scale_grid = (
|
894 |
-
(
|
|
|
|
|
|
|
|
|
895 |
if self.transformer.use_rope
|
896 |
else None
|
897 |
)
|
@@ -920,11 +1040,16 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
920 |
# perform guidance
|
921 |
if do_classifier_free_guidance:
|
922 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
923 |
-
noise_pred = noise_pred_uncond + guidance_scale * (
|
|
|
|
|
924 |
current_timestep, _ = current_timestep.chunk(2)
|
925 |
|
926 |
# learned sigma
|
927 |
-
if
|
|
|
|
|
|
|
928 |
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
929 |
|
930 |
# compute previous image: x_t -> x_t-1
|
@@ -937,7 +1062,9 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
937 |
)[0]
|
938 |
|
939 |
# call the callback, if provided
|
940 |
-
if i == len(timesteps) - 1 or (
|
|
|
|
|
941 |
progress_bar.update()
|
942 |
|
943 |
if callback_on_step_end is not None:
|
@@ -948,11 +1075,15 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
948 |
output_height=latent_height,
|
949 |
output_width=latent_width,
|
950 |
output_num_frames=latent_num_frames,
|
951 |
-
out_channels=self.transformer.in_channels
|
|
|
952 |
)
|
953 |
if output_type != "latent":
|
954 |
image = vae_decode(
|
955 |
-
latents,
|
|
|
|
|
|
|
956 |
)
|
957 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
958 |
|
@@ -1005,20 +1136,31 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
|
|
1005 |
vae_per_channel_normalize=vae_per_channel_normalize,
|
1006 |
).float()
|
1007 |
|
1008 |
-
init_len, target_len =
|
|
|
|
|
|
|
1009 |
if isinstance(self.vae, CausalVideoAutoencoder):
|
1010 |
target_len += 1
|
1011 |
init_latents = init_latents[:, :, :target_len]
|
1012 |
if target_len > init_len:
|
1013 |
repeat_factor = (target_len + init_len - 1) // init_len # Ceiling division
|
1014 |
-
init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[
|
|
|
|
|
1015 |
|
1016 |
# Prepare the conditioning mask (1.0 = condition on this token)
|
1017 |
b, n, f, h, w = init_latents.shape
|
1018 |
conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
|
1019 |
-
if method in [
|
|
|
|
|
|
|
1020 |
conditioning_mask[:, :, 0] = 1.0
|
1021 |
-
if method in [
|
|
|
|
|
|
|
1022 |
conditioning_mask[:, :, -1] = 1.0
|
1023 |
|
1024 |
# Patchify the init latents and the mask
|
|
|
5 |
import re
|
6 |
import urllib.parse as ul
|
7 |
from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
|
8 |
|
9 |
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
|
|
12 |
from diffusers.image_processor import VaeImageProcessor
|
13 |
from diffusers.models import AutoencoderKL
|
14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
27 |
|
28 |
from xora.models.transformers.transformer3d import Transformer3DModel
|
29 |
from xora.models.transformers.symmetric_patchifier import Patchifier
|
30 |
+
from xora.models.autoencoders.vae_encode import (
|
31 |
+
get_vae_size_scale_factor,
|
32 |
+
vae_decode,
|
33 |
+
vae_encode,
|
34 |
+
)
|
35 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
36 |
from xora.schedulers.rf import TimestepShifter
|
37 |
from xora.utils.conditioning_method import ConditioningMethod
|
|
|
163 |
second element is the number of inference steps.
|
164 |
"""
|
165 |
if timesteps is not None:
|
166 |
+
accepts_timesteps = "timesteps" in set(
|
167 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
168 |
+
)
|
169 |
if not accepts_timesteps:
|
170 |
raise ValueError(
|
171 |
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
|
|
242 |
patchifier=patchifier,
|
243 |
)
|
244 |
|
245 |
+
self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
|
246 |
+
self.vae
|
247 |
+
)
|
248 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
249 |
|
250 |
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
|
|
326 |
return_tensors="pt",
|
327 |
)
|
328 |
text_input_ids = text_inputs.input_ids
|
329 |
+
untruncated_ids = self.tokenizer(
|
330 |
+
prompt, padding="longest", return_tensors="pt"
|
331 |
+
).input_ids
|
332 |
+
|
333 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
334 |
+
-1
|
335 |
+
] and not torch.equal(text_input_ids, untruncated_ids):
|
336 |
+
removed_text = self.tokenizer.batch_decode(
|
337 |
+
untruncated_ids[:, max_length - 1 : -1]
|
338 |
+
)
|
339 |
logger.warning(
|
340 |
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
341 |
f" {max_length} tokens: {removed_text}"
|
|
|
344 |
prompt_attention_mask = text_inputs.attention_mask
|
345 |
prompt_attention_mask = prompt_attention_mask.to(device)
|
346 |
|
347 |
+
prompt_embeds = self.text_encoder(
|
348 |
+
text_input_ids.to(device), attention_mask=prompt_attention_mask
|
349 |
+
)
|
350 |
prompt_embeds = prompt_embeds[0]
|
351 |
|
352 |
if self.text_encoder is not None:
|
|
|
361 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
362 |
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
363 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
364 |
+
prompt_embeds = prompt_embeds.view(
|
365 |
+
bs_embed * num_images_per_prompt, seq_len, -1
|
366 |
+
)
|
367 |
prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
|
368 |
+
prompt_attention_mask = prompt_attention_mask.view(
|
369 |
+
bs_embed * num_images_per_prompt, -1
|
370 |
+
)
|
371 |
|
372 |
# get unconditional embeddings for classifier free guidance
|
373 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
374 |
uncond_tokens = [negative_prompt] * batch_size
|
375 |
+
uncond_tokens = self._text_preprocessing(
|
376 |
+
uncond_tokens, clean_caption=clean_caption
|
377 |
+
)
|
378 |
max_length = prompt_embeds.shape[1]
|
379 |
uncond_input = self.tokenizer(
|
380 |
uncond_tokens,
|
|
|
389 |
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
390 |
|
391 |
negative_prompt_embeds = self.text_encoder(
|
392 |
+
uncond_input.input_ids.to(device),
|
393 |
+
attention_mask=negative_prompt_attention_mask,
|
394 |
)
|
395 |
negative_prompt_embeds = negative_prompt_embeds[0]
|
396 |
|
|
|
398 |
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
399 |
seq_len = negative_prompt_embeds.shape[1]
|
400 |
|
401 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
402 |
+
dtype=dtype, device=device
|
403 |
+
)
|
404 |
|
405 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
406 |
+
1, num_images_per_prompt, 1
|
407 |
+
)
|
408 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
409 |
+
batch_size * num_images_per_prompt, seq_len, -1
|
410 |
+
)
|
411 |
|
412 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
|
413 |
+
1, num_images_per_prompt
|
414 |
+
)
|
415 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(
|
416 |
+
bs_embed * num_images_per_prompt, -1
|
417 |
+
)
|
418 |
else:
|
419 |
negative_prompt_embeds = None
|
420 |
negative_prompt_attention_mask = None
|
421 |
|
422 |
+
return (
|
423 |
+
prompt_embeds,
|
424 |
+
prompt_attention_mask,
|
425 |
+
negative_prompt_embeds,
|
426 |
+
negative_prompt_attention_mask,
|
427 |
+
)
|
428 |
|
429 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
430 |
def prepare_extra_step_kwargs(self, generator, eta):
|
|
|
433 |
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
434 |
# and should be between [0, 1]
|
435 |
|
436 |
+
accepts_eta = "eta" in set(
|
437 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
438 |
+
)
|
439 |
extra_step_kwargs = {}
|
440 |
if accepts_eta:
|
441 |
extra_step_kwargs["eta"] = eta
|
442 |
|
443 |
# check if the scheduler accepts generator
|
444 |
+
accepts_generator = "generator" in set(
|
445 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
446 |
+
)
|
447 |
if accepts_generator:
|
448 |
extra_step_kwargs["generator"] = generator
|
449 |
return extra_step_kwargs
|
|
|
460 |
negative_prompt_attention_mask=None,
|
461 |
):
|
462 |
if height % 8 != 0 or width % 8 != 0:
|
463 |
+
raise ValueError(
|
464 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
465 |
+
)
|
466 |
|
467 |
if prompt is not None and prompt_embeds is not None:
|
468 |
raise ValueError(
|
|
|
473 |
raise ValueError(
|
474 |
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
475 |
)
|
476 |
+
elif prompt is not None and (
|
477 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
478 |
+
):
|
479 |
+
raise ValueError(
|
480 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
481 |
+
)
|
482 |
|
483 |
if prompt is not None and negative_prompt_embeds is not None:
|
484 |
raise ValueError(
|
|
|
493 |
)
|
494 |
|
495 |
if prompt_embeds is not None and prompt_attention_mask is None:
|
496 |
+
raise ValueError(
|
497 |
+
"Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
|
498 |
+
)
|
499 |
|
500 |
+
if (
|
501 |
+
negative_prompt_embeds is not None
|
502 |
+
and negative_prompt_attention_mask is None
|
503 |
+
):
|
504 |
+
raise ValueError(
|
505 |
+
"Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
|
506 |
+
)
|
507 |
|
508 |
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
509 |
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
|
|
522 |
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
523 |
def _text_preprocessing(self, text, clean_caption=False):
|
524 |
if clean_caption and not is_bs4_available():
|
525 |
+
logger.warn(
|
526 |
+
BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")
|
527 |
+
)
|
528 |
logger.warn("Setting `clean_caption` to False...")
|
529 |
clean_caption = False
|
530 |
|
531 |
if clean_caption and not is_ftfy_available():
|
532 |
+
logger.warn(
|
533 |
+
BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")
|
534 |
+
)
|
535 |
logger.warn("Setting `clean_caption` to False...")
|
536 |
clean_caption = False
|
537 |
|
|
|
619 |
# "123456.."
|
620 |
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
621 |
# filenames:
|
622 |
+
caption = re.sub(
|
623 |
+
r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
|
624 |
+
)
|
625 |
|
626 |
#
|
627 |
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
628 |
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
629 |
|
630 |
+
caption = re.sub(
|
631 |
+
self.bad_punct_regex, r" ", caption
|
632 |
+
) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
633 |
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
634 |
|
635 |
# this-is-my-cute-cat / this_is_my_cute_cat
|
|
|
647 |
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
648 |
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
649 |
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
650 |
+
caption = re.sub(
|
651 |
+
r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
|
652 |
+
)
|
653 |
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
654 |
|
655 |
+
caption = re.sub(
|
656 |
+
r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
|
657 |
+
) # j2d1a2a...
|
658 |
|
659 |
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
660 |
|
|
|
673 |
|
674 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
675 |
def prepare_latents(
|
676 |
+
self,
|
677 |
+
batch_size,
|
678 |
+
num_latent_channels,
|
679 |
+
num_patches,
|
680 |
+
dtype,
|
681 |
+
device,
|
682 |
+
generator,
|
683 |
+
latents=None,
|
684 |
+
latents_mask=None,
|
685 |
):
|
686 |
shape = (
|
687 |
batch_size,
|
|
|
696 |
)
|
697 |
|
698 |
if latents is None:
|
699 |
+
latents = randn_tensor(
|
700 |
+
shape, generator=generator, device=device, dtype=dtype
|
701 |
+
)
|
702 |
elif latents_mask is not None:
|
703 |
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
704 |
+
latents = latents * latents_mask[..., None] + noise * (
|
705 |
+
1 - latents_mask[..., None]
|
706 |
+
)
|
707 |
else:
|
708 |
latents = latents.to(device)
|
709 |
|
|
|
712 |
return latents
|
713 |
|
714 |
@staticmethod
|
715 |
+
def classify_height_width_bin(
|
716 |
+
height: int, width: int, ratios: dict
|
717 |
+
) -> Tuple[int, int]:
|
718 |
"""Returns binned height and width."""
|
719 |
ar = float(height / width)
|
720 |
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
|
|
722 |
return int(default_hw[0]), int(default_hw[1])
|
723 |
|
724 |
@staticmethod
|
725 |
+
def resize_and_crop_tensor(
|
726 |
+
samples: torch.Tensor, new_width: int, new_height: int
|
727 |
+
) -> torch.Tensor:
|
728 |
n_frames, orig_height, orig_width = samples.shape[-3:]
|
729 |
|
730 |
# Check if resizing is needed
|
|
|
735 |
|
736 |
# Resize
|
737 |
samples = rearrange(samples, "b c n h w -> (b n) c h w")
|
738 |
+
samples = F.interpolate(
|
739 |
+
samples,
|
740 |
+
size=(resized_height, resized_width),
|
741 |
+
mode="bilinear",
|
742 |
+
align_corners=False,
|
743 |
+
)
|
744 |
samples = rearrange(samples, "(b n) c h w -> b c n h w", n=n_frames)
|
745 |
|
746 |
# Center Crop
|
|
|
905 |
)
|
906 |
if do_classifier_free_guidance:
|
907 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
908 |
+
prompt_attention_mask = torch.cat(
|
909 |
+
[negative_prompt_attention_mask, prompt_attention_mask], dim=0
|
910 |
+
)
|
911 |
|
912 |
# 3b. Encode and prepare conditioning data
|
913 |
self.video_scale_factor = self.video_scale_factor if is_video else 1
|
914 |
conditioning_method = kwargs.get("conditioning_method", None)
|
915 |
vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
|
916 |
init_latents, conditioning_mask = self.prepare_conditioning(
|
917 |
+
media_items,
|
918 |
+
num_frames,
|
919 |
+
height,
|
920 |
+
width,
|
921 |
+
conditioning_method,
|
922 |
+
vae_per_channel_normalize,
|
923 |
)
|
924 |
|
925 |
# 4. Prepare latents.
|
|
|
942 |
)
|
943 |
if conditioning_mask is not None and is_video:
|
944 |
assert num_images_per_prompt == 1
|
945 |
+
conditioning_mask = (
|
946 |
+
torch.cat([conditioning_mask] * 2)
|
947 |
+
if do_classifier_free_guidance
|
948 |
+
else conditioning_mask
|
949 |
+
)
|
950 |
|
951 |
# 5. Prepare timesteps
|
952 |
retrieve_timesteps_kwargs = {}
|
953 |
if isinstance(self.scheduler, TimestepShifter):
|
954 |
retrieve_timesteps_kwargs["samples"] = latents
|
955 |
timesteps, num_inference_steps = retrieve_timesteps(
|
956 |
+
self.scheduler,
|
957 |
+
num_inference_steps,
|
958 |
+
device,
|
959 |
+
timesteps,
|
960 |
+
**retrieve_timesteps_kwargs,
|
961 |
)
|
962 |
|
963 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
964 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
965 |
|
966 |
# 7. Denoising loop
|
967 |
+
num_warmup_steps = max(
|
968 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
969 |
+
)
|
970 |
|
971 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
972 |
for i, t in enumerate(timesteps):
|
973 |
+
latent_model_input = (
|
974 |
+
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
975 |
+
)
|
976 |
+
latent_model_input = self.scheduler.scale_model_input(
|
977 |
+
latent_model_input, t
|
978 |
+
)
|
979 |
|
980 |
latent_frame_rates = (
|
981 |
+
torch.ones(
|
982 |
+
latent_model_input.shape[0], 1, device=latent_model_input.device
|
983 |
+
)
|
984 |
+
* latent_frame_rate
|
985 |
)
|
986 |
|
987 |
current_timestep = t
|
|
|
993 |
dtype = torch.float32 if is_mps else torch.float64
|
994 |
else:
|
995 |
dtype = torch.int32 if is_mps else torch.int64
|
996 |
+
current_timestep = torch.tensor(
|
997 |
+
[current_timestep],
|
998 |
+
dtype=dtype,
|
999 |
+
device=latent_model_input.device,
|
1000 |
+
)
|
1001 |
elif len(current_timestep.shape) == 0:
|
1002 |
+
current_timestep = current_timestep[None].to(
|
1003 |
+
latent_model_input.device
|
1004 |
+
)
|
1005 |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1006 |
+
current_timestep = current_timestep.expand(
|
1007 |
+
latent_model_input.shape[0]
|
1008 |
+
).unsqueeze(-1)
|
1009 |
scale_grid = (
|
1010 |
+
(
|
1011 |
+
1 / latent_frame_rates,
|
1012 |
+
self.vae_scale_factor,
|
1013 |
+
self.vae_scale_factor,
|
1014 |
+
)
|
1015 |
if self.transformer.use_rope
|
1016 |
else None
|
1017 |
)
|
|
|
1040 |
# perform guidance
|
1041 |
if do_classifier_free_guidance:
|
1042 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1043 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
1044 |
+
noise_pred_text - noise_pred_uncond
|
1045 |
+
)
|
1046 |
current_timestep, _ = current_timestep.chunk(2)
|
1047 |
|
1048 |
# learned sigma
|
1049 |
+
if (
|
1050 |
+
self.transformer.config.out_channels // 2
|
1051 |
+
== self.transformer.config.in_channels
|
1052 |
+
):
|
1053 |
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
1054 |
|
1055 |
# compute previous image: x_t -> x_t-1
|
|
|
1062 |
)[0]
|
1063 |
|
1064 |
# call the callback, if provided
|
1065 |
+
if i == len(timesteps) - 1 or (
|
1066 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
1067 |
+
):
|
1068 |
progress_bar.update()
|
1069 |
|
1070 |
if callback_on_step_end is not None:
|
|
|
1075 |
output_height=latent_height,
|
1076 |
output_width=latent_width,
|
1077 |
output_num_frames=latent_num_frames,
|
1078 |
+
out_channels=self.transformer.in_channels
|
1079 |
+
// math.prod(self.patchifier.patch_size),
|
1080 |
)
|
1081 |
if output_type != "latent":
|
1082 |
image = vae_decode(
|
1083 |
+
latents,
|
1084 |
+
self.vae,
|
1085 |
+
is_video,
|
1086 |
+
vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
|
1087 |
)
|
1088 |
image = self.image_processor.postprocess(image, output_type=output_type)
|
1089 |
|
|
|
1136 |
vae_per_channel_normalize=vae_per_channel_normalize,
|
1137 |
).float()
|
1138 |
|
1139 |
+
init_len, target_len = (
|
1140 |
+
init_latents.shape[2],
|
1141 |
+
num_frames // self.video_scale_factor,
|
1142 |
+
)
|
1143 |
if isinstance(self.vae, CausalVideoAutoencoder):
|
1144 |
target_len += 1
|
1145 |
init_latents = init_latents[:, :, :target_len]
|
1146 |
if target_len > init_len:
|
1147 |
repeat_factor = (target_len + init_len - 1) // init_len # Ceiling division
|
1148 |
+
init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[
|
1149 |
+
:, :, :target_len
|
1150 |
+
]
|
1151 |
|
1152 |
# Prepare the conditioning mask (1.0 = condition on this token)
|
1153 |
b, n, f, h, w = init_latents.shape
|
1154 |
conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device)
|
1155 |
+
if method in [
|
1156 |
+
ConditioningMethod.FIRST_FRAME,
|
1157 |
+
ConditioningMethod.FIRST_AND_LAST_FRAME,
|
1158 |
+
]:
|
1159 |
conditioning_mask[:, :, 0] = 1.0
|
1160 |
+
if method in [
|
1161 |
+
ConditioningMethod.LAST_FRAME,
|
1162 |
+
ConditioningMethod.FIRST_AND_LAST_FRAME,
|
1163 |
+
]:
|
1164 |
conditioning_mask[:, :, -1] = 1.0
|
1165 |
|
1166 |
# Patchify the init latents and the mask
|
xora/schedulers/rf.py
CHANGED
@@ -22,7 +22,9 @@ def simple_diffusion_resolution_dependent_timestep_shift(
|
|
22 |
elif len(samples.shape) in [4, 5]:
|
23 |
m = math.prod(samples.shape[2:])
|
24 |
else:
|
25 |
-
raise ValueError(
|
|
|
|
|
26 |
snr = (timesteps / (1 - timesteps)) ** 2
|
27 |
shift_snr = torch.log(snr) + 2 * math.log(m / n)
|
28 |
shifted_timesteps = torch.sigmoid(0.5 * shift_snr)
|
@@ -46,7 +48,9 @@ def get_normal_shift(
|
|
46 |
return m * n_tokens + b
|
47 |
|
48 |
|
49 |
-
def sd3_resolution_dependent_timestep_shift(
|
|
|
|
|
50 |
"""
|
51 |
Shifts the timestep schedule as a function of the generated resolution.
|
52 |
|
@@ -70,7 +74,9 @@ def sd3_resolution_dependent_timestep_shift(samples: Tensor, timesteps: Tensor)
|
|
70 |
elif len(samples.shape) in [4, 5]:
|
71 |
m = math.prod(samples.shape[2:])
|
72 |
else:
|
73 |
-
raise ValueError(
|
|
|
|
|
74 |
|
75 |
shift = get_normal_shift(m)
|
76 |
return time_shift(shift, 1, timesteps)
|
@@ -104,12 +110,21 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
|
|
104 |
order = 1
|
105 |
|
106 |
@register_to_config
|
107 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
108 |
super().__init__()
|
109 |
self.init_noise_sigma = 1.0
|
110 |
self.num_inference_steps = None
|
111 |
-
self.timesteps = self.sigmas = torch.linspace(
|
112 |
-
|
|
|
|
|
|
|
|
|
113 |
self.shifting = shifting
|
114 |
self.base_resolution = base_resolution
|
115 |
|
@@ -117,10 +132,17 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
|
|
117 |
if self.shifting == "SD3":
|
118 |
return sd3_resolution_dependent_timestep_shift(samples, timesteps)
|
119 |
elif self.shifting == "SimpleDiffusion":
|
120 |
-
return simple_diffusion_resolution_dependent_timestep_shift(
|
|
|
|
|
121 |
return timesteps
|
122 |
|
123 |
-
def set_timesteps(
|
|
|
|
|
|
|
|
|
|
|
124 |
"""
|
125 |
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
126 |
|
@@ -130,13 +152,19 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
|
|
130 |
device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
|
131 |
"""
|
132 |
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
|
133 |
-
timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps).to(
|
|
|
|
|
134 |
self.timesteps = self.shift_timesteps(samples, timesteps)
|
135 |
-
self.delta_timesteps = self.timesteps - torch.cat(
|
|
|
|
|
136 |
self.num_inference_steps = num_inference_steps
|
137 |
self.sigmas = self.timesteps
|
138 |
|
139 |
-
def scale_model_input(
|
|
|
|
|
140 |
# pylint: disable=unused-argument
|
141 |
"""
|
142 |
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
@@ -206,7 +234,9 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
|
|
206 |
else:
|
207 |
# Timestep per token
|
208 |
assert timestep.ndim == 2
|
209 |
-
current_index = (
|
|
|
|
|
210 |
dt = self.delta_timesteps[current_index]
|
211 |
# Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
|
212 |
dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]
|
@@ -228,4 +258,4 @@ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
|
|
228 |
sigmas = append_dims(sigmas, original_samples.ndim)
|
229 |
alphas = 1 - sigmas
|
230 |
noisy_samples = alphas * original_samples + sigmas * noise
|
231 |
-
return noisy_samples
|
|
|
22 |
elif len(samples.shape) in [4, 5]:
|
23 |
m = math.prod(samples.shape[2:])
|
24 |
else:
|
25 |
+
raise ValueError(
|
26 |
+
"Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
|
27 |
+
)
|
28 |
snr = (timesteps / (1 - timesteps)) ** 2
|
29 |
shift_snr = torch.log(snr) + 2 * math.log(m / n)
|
30 |
shifted_timesteps = torch.sigmoid(0.5 * shift_snr)
|
|
|
48 |
return m * n_tokens + b
|
49 |
|
50 |
|
51 |
+
def sd3_resolution_dependent_timestep_shift(
|
52 |
+
samples: Tensor, timesteps: Tensor
|
53 |
+
) -> Tensor:
|
54 |
"""
|
55 |
Shifts the timestep schedule as a function of the generated resolution.
|
56 |
|
|
|
74 |
elif len(samples.shape) in [4, 5]:
|
75 |
m = math.prod(samples.shape[2:])
|
76 |
else:
|
77 |
+
raise ValueError(
|
78 |
+
"Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
|
79 |
+
)
|
80 |
|
81 |
shift = get_normal_shift(m)
|
82 |
return time_shift(shift, 1, timesteps)
|
|
|
110 |
order = 1
|
111 |
|
112 |
@register_to_config
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
num_train_timesteps=1000,
|
116 |
+
shifting: Optional[str] = None,
|
117 |
+
base_resolution: int = 32**2,
|
118 |
+
):
|
119 |
super().__init__()
|
120 |
self.init_noise_sigma = 1.0
|
121 |
self.num_inference_steps = None
|
122 |
+
self.timesteps = self.sigmas = torch.linspace(
|
123 |
+
1, 1 / num_train_timesteps, num_train_timesteps
|
124 |
+
)
|
125 |
+
self.delta_timesteps = self.timesteps - torch.cat(
|
126 |
+
[self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
|
127 |
+
)
|
128 |
self.shifting = shifting
|
129 |
self.base_resolution = base_resolution
|
130 |
|
|
|
132 |
if self.shifting == "SD3":
|
133 |
return sd3_resolution_dependent_timestep_shift(samples, timesteps)
|
134 |
elif self.shifting == "SimpleDiffusion":
|
135 |
+
return simple_diffusion_resolution_dependent_timestep_shift(
|
136 |
+
samples, timesteps, self.base_resolution
|
137 |
+
)
|
138 |
return timesteps
|
139 |
|
140 |
+
def set_timesteps(
|
141 |
+
self,
|
142 |
+
num_inference_steps: int,
|
143 |
+
samples: Tensor,
|
144 |
+
device: Union[str, torch.device] = None,
|
145 |
+
):
|
146 |
"""
|
147 |
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
148 |
|
|
|
152 |
device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
|
153 |
"""
|
154 |
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
|
155 |
+
timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps).to(
|
156 |
+
device
|
157 |
+
)
|
158 |
self.timesteps = self.shift_timesteps(samples, timesteps)
|
159 |
+
self.delta_timesteps = self.timesteps - torch.cat(
|
160 |
+
[self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])]
|
161 |
+
)
|
162 |
self.num_inference_steps = num_inference_steps
|
163 |
self.sigmas = self.timesteps
|
164 |
|
165 |
+
def scale_model_input(
|
166 |
+
self, sample: torch.FloatTensor, timestep: Optional[int] = None
|
167 |
+
) -> torch.FloatTensor:
|
168 |
# pylint: disable=unused-argument
|
169 |
"""
|
170 |
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
|
|
234 |
else:
|
235 |
# Timestep per token
|
236 |
assert timestep.ndim == 2
|
237 |
+
current_index = (
|
238 |
+
(self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0)
|
239 |
+
)
|
240 |
dt = self.delta_timesteps[current_index]
|
241 |
# Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample
|
242 |
dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None]
|
|
|
258 |
sigmas = append_dims(sigmas, original_samples.ndim)
|
259 |
alphas = 1 - sigmas
|
260 |
noisy_samples = alphas * original_samples + sigmas * noise
|
261 |
+
return noisy_samples
|
xora/utils/conditioning_method.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
from enum import Enum
|
2 |
|
|
|
3 |
class ConditioningMethod(Enum):
|
4 |
UNCONDITIONAL = "unconditional"
|
5 |
FIRST_FRAME = "first_frame"
|
6 |
LAST_FRAME = "last_frame"
|
7 |
-
FIRST_AND_LAST_FRAME = "first_and_last_frame"
|
|
|
1 |
from enum import Enum
|
2 |
|
3 |
+
|
4 |
class ConditioningMethod(Enum):
|
5 |
UNCONDITIONAL = "unconditional"
|
6 |
FIRST_FRAME = "first_frame"
|
7 |
LAST_FRAME = "last_frame"
|
8 |
+
FIRST_AND_LAST_FRAME = "first_and_last_frame"
|
xora/utils/torch_utils.py
CHANGED
@@ -1,15 +1,19 @@
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
|
|
|
4 |
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
|
5 |
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
6 |
dims_to_append = target_dims - x.ndim
|
7 |
if dims_to_append < 0:
|
8 |
-
raise ValueError(
|
|
|
|
|
9 |
elif dims_to_append == 0:
|
10 |
return x
|
11 |
return x[(...,) + (None,) * dims_to_append]
|
12 |
|
|
|
13 |
class Identity(nn.Module):
|
14 |
"""A placeholder identity operator that is argument-insensitive."""
|
15 |
|
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
|
4 |
+
|
5 |
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
|
6 |
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
7 |
dims_to_append = target_dims - x.ndim
|
8 |
if dims_to_append < 0:
|
9 |
+
raise ValueError(
|
10 |
+
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
11 |
+
)
|
12 |
elif dims_to_append == 0:
|
13 |
return x
|
14 |
return x[(...,) + (None,) * dims_to_append]
|
15 |
|
16 |
+
|
17 |
class Identity(nn.Module):
|
18 |
"""A placeholder identity operator that is argument-insensitive."""
|
19 |
|