README: added inference + installation guidelines, inference clearer.
Browse files- README.md +70 -1
- xora/examples/image_to_video.py → inference.py +123 -41
- scripts/to_safetensors.py +2 -2
- xora/examples/text_to_video.py +0 -138
README.md
CHANGED
@@ -1 +1,70 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
|
3 |
+
# Xora️
|
4 |
+
</div>
|
5 |
+
|
6 |
+
This is the official repository for Xora.
|
7 |
+
|
8 |
+
## Table of Contents
|
9 |
+
|
10 |
+
* [Introduction](#introduction)
|
11 |
+
* [Installation](#installation)
|
12 |
+
* [Inference](#inference)
|
13 |
+
* [Inference Code](#inference-code)
|
14 |
+
* [Acknowledgement](#acknowledgement)
|
15 |
+
|
16 |
+
## Introduction
|
17 |
+
|
18 |
+
The performance of Diffusion Transformers is heavily influenced by the number of generated latent pixels (or tokens). In video generation, the token count becomes substantial as the number of frames increases. To address this, we designed a carefully optimized VAE that compresses videos into a smaller number of tokens while utilizing a deeper latent space. This approach enables our model to generate high-quality 768x512 videos at 24 FPS, achieving near real-time speeds.
|
19 |
+
|
20 |
+
## Installation
|
21 |
+
|
22 |
+
# Setup
|
23 |
+
The codebase currently uses Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2.
|
24 |
+
|
25 |
+
|
26 |
+
```bash
|
27 |
+
git clone https://github.com/LightricksResearch/xora-core.git
|
28 |
+
cd xora-core
|
29 |
+
|
30 |
+
# create env
|
31 |
+
python -m venv env
|
32 |
+
source env/bin/activate
|
33 |
+
python -m pip install -e .\[inference-script\]
|
34 |
+
```
|
35 |
+
|
36 |
+
Then, download the model from [Hugging Face](https://huggingface.co/Lightricks/Xora)
|
37 |
+
|
38 |
+
```python
|
39 |
+
from huggingface_hub import snapshot_download
|
40 |
+
|
41 |
+
model_path = 'PATH' # The local directory to save downloaded checkpoint
|
42 |
+
snapshot_download("Lightricks/Orah", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
|
43 |
+
```
|
44 |
+
|
45 |
+
## Inference
|
46 |
+
|
47 |
+
### Inference Code
|
48 |
+
|
49 |
+
To use our model, please follow the inference code in `inference.py` at [https://github.com/LightricksResearch/xora-core/blob/main/inference.py]():
|
50 |
+
|
51 |
+
For text-to-video generation:
|
52 |
+
|
53 |
+
```bash
|
54 |
+
python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --height HEIGHT --width WIDTH
|
55 |
+
```
|
56 |
+
|
57 |
+
For image-to-video generation:
|
58 |
+
|
59 |
+
```python
|
60 |
+
python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --input_image_path IMAGE_PATH --height HEIGHT --width WIDTH
|
61 |
+
|
62 |
+
```
|
63 |
+
|
64 |
+
## Acknowledgement
|
65 |
+
|
66 |
+
We are grateful for the following awesome projects when implementing Xora:
|
67 |
+
* [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation.
|
68 |
+
|
69 |
+
|
70 |
+
[//]: # (## Citation)
|
xora/examples/image_to_video.py → inference.py
RENAMED
@@ -16,9 +16,39 @@ import cv2
|
|
16 |
from PIL import Image
|
17 |
import random
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def load_vae(vae_dir):
|
21 |
-
vae_ckpt_path = vae_dir / "
|
22 |
vae_config_path = vae_dir / "config.json"
|
23 |
with open(vae_config_path, "r") as f:
|
24 |
vae_config = json.load(f)
|
@@ -29,7 +59,7 @@ def load_vae(vae_dir):
|
|
29 |
|
30 |
|
31 |
def load_unet(unet_dir):
|
32 |
-
unet_ckpt_path = unet_dir / "
|
33 |
unet_config_path = unet_dir / "config.json"
|
34 |
transformer_config = Transformer3DModel.load_config(unet_config_path)
|
35 |
transformer = Transformer3DModel.from_config(transformer_config)
|
@@ -60,7 +90,7 @@ def center_crop_and_resize(frame, target_height, target_width):
|
|
60 |
return frame_resized
|
61 |
|
62 |
|
63 |
-
def load_video_to_tensor_with_resize(video_path, target_height
|
64 |
cap = cv2.VideoCapture(video_path)
|
65 |
frames = []
|
66 |
while True:
|
@@ -68,7 +98,12 @@ def load_video_to_tensor_with_resize(video_path, target_height=512, target_width
|
|
68 |
if not ret:
|
69 |
break
|
70 |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
72 |
frames.append(frame_resized)
|
73 |
cap.release()
|
74 |
video_np = (np.array(frames) / 127.5) - 1.0
|
@@ -99,9 +134,19 @@ def main():
|
|
99 |
help="Path to the directory containing unet, vae, and scheduler subdirectories",
|
100 |
)
|
101 |
parser.add_argument(
|
102 |
-
"--
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
)
|
104 |
-
parser.add_argument("--image_path", type=str, help="Path to the input image file")
|
105 |
parser.add_argument("--seed", type=int, default="171198")
|
106 |
|
107 |
# Pipeline parameters
|
@@ -121,10 +166,16 @@ def main():
|
|
121 |
help="Guidance scale for the pipeline",
|
122 |
)
|
123 |
parser.add_argument(
|
124 |
-
"--height",
|
|
|
|
|
|
|
125 |
)
|
126 |
parser.add_argument(
|
127 |
-
"--width",
|
|
|
|
|
|
|
128 |
)
|
129 |
parser.add_argument(
|
130 |
"--num_frames",
|
@@ -136,12 +187,6 @@ def main():
|
|
136 |
"--frame_rate", type=int, default=25, help="Frame rate for the output video"
|
137 |
)
|
138 |
|
139 |
-
parser.add_argument(
|
140 |
-
"--mixed_precision",
|
141 |
-
action="store_true",
|
142 |
-
help="Mixed precision in float32 and bfloat16",
|
143 |
-
)
|
144 |
-
|
145 |
parser.add_argument(
|
146 |
"--bfloat16",
|
147 |
action="store_true",
|
@@ -152,7 +197,6 @@ def main():
|
|
152 |
parser.add_argument(
|
153 |
"--prompt",
|
154 |
type=str,
|
155 |
-
default='A man wearing a black leather jacket and blue jeans is riding a Harley Davidson motorcycle down a paved road. The man has short brown hair and is wearing a black helmet. The motorcycle is a dark red color with a large front fairing. The road is surrounded by green grass and trees. There is a gas station on the left side of the road with a red and white sign that says "Oil" and "Diner".',
|
156 |
help="Text prompt to guide generation",
|
157 |
)
|
158 |
parser.add_argument(
|
@@ -161,9 +205,42 @@ def main():
|
|
161 |
default="worst quality, inconsistent motion, blurry, jittery, distorted",
|
162 |
help="Negative prompt for undesired features",
|
163 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
args = parser.parse_args()
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
# Paths for the separate mode directories
|
168 |
ckpt_dir = Path(args.ckpt_dir)
|
169 |
unet_dir = ckpt_dir / "unet"
|
@@ -197,18 +274,6 @@ def main():
|
|
197 |
|
198 |
pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
|
199 |
|
200 |
-
# Load media (video or image)
|
201 |
-
if args.video_path:
|
202 |
-
media_items = load_video_to_tensor_with_resize(
|
203 |
-
args.video_path, args.height, args.width
|
204 |
-
).unsqueeze(0)
|
205 |
-
elif args.image_path:
|
206 |
-
media_items = load_image_to_tensor_with_resize(
|
207 |
-
args.image_path, args.height, args.width
|
208 |
-
)
|
209 |
-
else:
|
210 |
-
raise ValueError("Either --video_path or --image_path must be provided.")
|
211 |
-
|
212 |
# Prepare input for the pipeline
|
213 |
sample = {
|
214 |
"prompt": args.prompt,
|
@@ -231,15 +296,19 @@ def main():
|
|
231 |
generator=generator,
|
232 |
output_type="pt",
|
233 |
callback_on_step_end=None,
|
234 |
-
height=
|
235 |
-
width=
|
236 |
num_frames=args.num_frames,
|
237 |
frame_rate=args.frame_rate,
|
238 |
**sample,
|
239 |
is_video=True,
|
240 |
vae_per_channel_normalize=True,
|
241 |
-
conditioning_method=
|
242 |
-
|
|
|
|
|
|
|
|
|
243 |
).images
|
244 |
|
245 |
# Save output video
|
@@ -257,16 +326,29 @@ def main():
|
|
257 |
video_np = (video_np * 255).astype(np.uint8)
|
258 |
fps = args.frame_rate
|
259 |
height, width = video_np.shape[1:3]
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
|
272 |
if __name__ == "__main__":
|
|
|
16 |
from PIL import Image
|
17 |
import random
|
18 |
|
19 |
+
RECOMMENDED_RESOLUTIONS = [
|
20 |
+
(704, 1216, 41),
|
21 |
+
(704, 1088, 49),
|
22 |
+
(640, 1056, 57),
|
23 |
+
(608, 992, 65),
|
24 |
+
(608, 896, 73),
|
25 |
+
(544, 896, 81),
|
26 |
+
(544, 832, 89),
|
27 |
+
(512, 800, 97),
|
28 |
+
(512, 768, 97),
|
29 |
+
(480, 800, 105),
|
30 |
+
(480, 736, 113),
|
31 |
+
(480, 704, 121),
|
32 |
+
(448, 704, 129),
|
33 |
+
(448, 672, 137),
|
34 |
+
(416, 640, 153),
|
35 |
+
(384, 672, 161),
|
36 |
+
(384, 640, 169),
|
37 |
+
(384, 608, 177),
|
38 |
+
(384, 576, 185),
|
39 |
+
(352, 608, 193),
|
40 |
+
(352, 576, 201),
|
41 |
+
(352, 544, 209),
|
42 |
+
(352, 512, 225),
|
43 |
+
(352, 512, 233),
|
44 |
+
(320, 544, 241),
|
45 |
+
(320, 512, 249),
|
46 |
+
(320, 512, 257),
|
47 |
+
]
|
48 |
+
|
49 |
|
50 |
def load_vae(vae_dir):
|
51 |
+
vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
|
52 |
vae_config_path = vae_dir / "config.json"
|
53 |
with open(vae_config_path, "r") as f:
|
54 |
vae_config = json.load(f)
|
|
|
59 |
|
60 |
|
61 |
def load_unet(unet_dir):
|
62 |
+
unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
|
63 |
unet_config_path = unet_dir / "config.json"
|
64 |
transformer_config = Transformer3DModel.load_config(unet_config_path)
|
65 |
transformer = Transformer3DModel.from_config(transformer_config)
|
|
|
90 |
return frame_resized
|
91 |
|
92 |
|
93 |
+
def load_video_to_tensor_with_resize(video_path, target_height, target_width):
|
94 |
cap = cv2.VideoCapture(video_path)
|
95 |
frames = []
|
96 |
while True:
|
|
|
98 |
if not ret:
|
99 |
break
|
100 |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
101 |
+
if target_height is not None:
|
102 |
+
frame_resized = center_crop_and_resize(
|
103 |
+
frame_rgb, target_height, target_width
|
104 |
+
)
|
105 |
+
else:
|
106 |
+
frame_resized = frame_rgb
|
107 |
frames.append(frame_resized)
|
108 |
cap.release()
|
109 |
video_np = (np.array(frames) / 127.5) - 1.0
|
|
|
134 |
help="Path to the directory containing unet, vae, and scheduler subdirectories",
|
135 |
)
|
136 |
parser.add_argument(
|
137 |
+
"--input_video_path",
|
138 |
+
type=str,
|
139 |
+
help="Path to the input video file (first frame used)",
|
140 |
+
)
|
141 |
+
parser.add_argument(
|
142 |
+
"--input_image_path", type=str, help="Path to the input image file"
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--output_path",
|
146 |
+
type=str,
|
147 |
+
default=None,
|
148 |
+
help="Path to save output video, if None will save in working directory.",
|
149 |
)
|
|
|
150 |
parser.add_argument("--seed", type=int, default="171198")
|
151 |
|
152 |
# Pipeline parameters
|
|
|
166 |
help="Guidance scale for the pipeline",
|
167 |
)
|
168 |
parser.add_argument(
|
169 |
+
"--height",
|
170 |
+
type=int,
|
171 |
+
default=None,
|
172 |
+
help="Height of the output video frames. Optional if an input image provided.",
|
173 |
)
|
174 |
parser.add_argument(
|
175 |
+
"--width",
|
176 |
+
type=int,
|
177 |
+
default=None,
|
178 |
+
help="Width of the output video frames. If None will infer from input image.",
|
179 |
)
|
180 |
parser.add_argument(
|
181 |
"--num_frames",
|
|
|
187 |
"--frame_rate", type=int, default=25, help="Frame rate for the output video"
|
188 |
)
|
189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
parser.add_argument(
|
191 |
"--bfloat16",
|
192 |
action="store_true",
|
|
|
197 |
parser.add_argument(
|
198 |
"--prompt",
|
199 |
type=str,
|
|
|
200 |
help="Text prompt to guide generation",
|
201 |
)
|
202 |
parser.add_argument(
|
|
|
205 |
default="worst quality, inconsistent motion, blurry, jittery, distorted",
|
206 |
help="Negative prompt for undesired features",
|
207 |
)
|
208 |
+
parser.add_argument(
|
209 |
+
"--custom_resolution",
|
210 |
+
action="store_true",
|
211 |
+
default=False,
|
212 |
+
help="Enable custom resolution (not in recommneded resolutions) if specified (default: False)",
|
213 |
+
)
|
214 |
|
215 |
args = parser.parse_args()
|
216 |
|
217 |
+
if args.input_image_path is None and args.input_video_path is None:
|
218 |
+
assert (
|
219 |
+
args.height is not None and args.width is not None
|
220 |
+
), "Must enter height and width for text to image generation."
|
221 |
+
|
222 |
+
# Load media (video or image)
|
223 |
+
if args.input_video_path:
|
224 |
+
media_items = load_video_to_tensor_with_resize(
|
225 |
+
args.input_video_path, args.height, args.width
|
226 |
+
).unsqueeze(0)
|
227 |
+
elif args.input_image_path:
|
228 |
+
media_items = load_image_to_tensor_with_resize(
|
229 |
+
args.input_image_path, args.height, args.width
|
230 |
+
)
|
231 |
+
else:
|
232 |
+
media_items = None
|
233 |
+
|
234 |
+
height = args.height if args.height else media_items.shape[-2]
|
235 |
+
width = args.width if args.width else media_items.shape[-1]
|
236 |
+
assert height % 32 == 0, f"Height ({height}) should be divisible by 32."
|
237 |
+
assert width % 32 == 0, f"Width ({width}) should be divisible by 32."
|
238 |
+
assert (
|
239 |
+
height,
|
240 |
+
width,
|
241 |
+
args.num_frames,
|
242 |
+
) in RECOMMENDED_RESOLUTIONS or args.custom_resolution, f"The selected resolution + num frames combination is not supported, results would be suboptimal. Supported (h,w,f) are: {RECOMMENDED_RESOLUTIONS}. Use --custom_resolution to enable working with this resolution."
|
243 |
+
|
244 |
# Paths for the separate mode directories
|
245 |
ckpt_dir = Path(args.ckpt_dir)
|
246 |
unet_dir = ckpt_dir / "unet"
|
|
|
274 |
|
275 |
pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
|
276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
# Prepare input for the pipeline
|
278 |
sample = {
|
279 |
"prompt": args.prompt,
|
|
|
296 |
generator=generator,
|
297 |
output_type="pt",
|
298 |
callback_on_step_end=None,
|
299 |
+
height=height,
|
300 |
+
width=width,
|
301 |
num_frames=args.num_frames,
|
302 |
frame_rate=args.frame_rate,
|
303 |
**sample,
|
304 |
is_video=True,
|
305 |
vae_per_channel_normalize=True,
|
306 |
+
conditioning_method=(
|
307 |
+
ConditioningMethod.FIRST_FRAME
|
308 |
+
if media_items is not None
|
309 |
+
else ConditioningMethod.UNCONDITIONAL
|
310 |
+
),
|
311 |
+
mixed_precision=not args.bfloat16,
|
312 |
).images
|
313 |
|
314 |
# Save output video
|
|
|
326 |
video_np = (video_np * 255).astype(np.uint8)
|
327 |
fps = args.frame_rate
|
328 |
height, width = video_np.shape[1:3]
|
329 |
+
if video_np.shape[0] == 1:
|
330 |
+
output_filename = (
|
331 |
+
args.output_path
|
332 |
+
if args.output_path is not None
|
333 |
+
else get_unique_filename(f"image_output_{i}", ".png", ".")
|
334 |
+
)
|
335 |
+
cv2.imwrite(
|
336 |
+
output_filename, video_np[0][..., ::-1]
|
337 |
+
) # Save single frame as image
|
338 |
+
else:
|
339 |
+
output_filename = (
|
340 |
+
args.output_path
|
341 |
+
if args.output_path is not None
|
342 |
+
else get_unique_filename(f"video_output_{i}", ".mp4", ".")
|
343 |
+
)
|
344 |
+
|
345 |
+
out = cv2.VideoWriter(
|
346 |
+
output_filename, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
|
347 |
+
)
|
348 |
+
|
349 |
+
for frame in video_np[..., ::-1]:
|
350 |
+
out.write(frame)
|
351 |
+
out.release()
|
352 |
|
353 |
|
354 |
if __name__ == "__main__":
|
scripts/to_safetensors.py
CHANGED
@@ -100,10 +100,10 @@ def main(
|
|
100 |
|
101 |
# Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
|
102 |
safetensors.torch.save_file(
|
103 |
-
unet, unet_dir / "
|
104 |
)
|
105 |
safetensors.torch.save_file(
|
106 |
-
vae, vae_dir / "
|
107 |
)
|
108 |
|
109 |
# Save config files for unet, vae, and scheduler
|
|
|
100 |
|
101 |
# Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
|
102 |
safetensors.torch.save_file(
|
103 |
+
unet, unet_dir / "unet_diffusion_pytorch_model.safetensors"
|
104 |
)
|
105 |
safetensors.torch.save_file(
|
106 |
+
vae, vae_dir / "vae_diffusion_pytorch_model.safetensors"
|
107 |
)
|
108 |
|
109 |
# Save config files for unet, vae, and scheduler
|
xora/examples/text_to_video.py
DELETED
@@ -1,138 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
3 |
-
from xora.models.transformers.transformer3d import Transformer3DModel
|
4 |
-
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
|
5 |
-
from xora.schedulers.rf import RectifiedFlowScheduler
|
6 |
-
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
|
7 |
-
from pathlib import Path
|
8 |
-
from transformers import T5EncoderModel, T5Tokenizer
|
9 |
-
import safetensors.torch
|
10 |
-
import json
|
11 |
-
import argparse
|
12 |
-
|
13 |
-
|
14 |
-
def load_vae(vae_dir):
|
15 |
-
vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
16 |
-
vae_config_path = vae_dir / "config.json"
|
17 |
-
with open(vae_config_path, "r") as f:
|
18 |
-
vae_config = json.load(f)
|
19 |
-
vae = CausalVideoAutoencoder.from_config(vae_config)
|
20 |
-
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
21 |
-
vae.load_state_dict(vae_state_dict)
|
22 |
-
return vae.cuda().to(torch.bfloat16)
|
23 |
-
|
24 |
-
|
25 |
-
def load_unet(unet_dir):
|
26 |
-
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
|
27 |
-
unet_config_path = unet_dir / "config.json"
|
28 |
-
transformer_config = Transformer3DModel.load_config(unet_config_path)
|
29 |
-
transformer = Transformer3DModel.from_config(transformer_config)
|
30 |
-
unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
|
31 |
-
transformer.load_state_dict(unet_state_dict, strict=True)
|
32 |
-
return transformer.cuda()
|
33 |
-
|
34 |
-
|
35 |
-
def load_scheduler(scheduler_dir):
|
36 |
-
scheduler_config_path = scheduler_dir / "scheduler_config.json"
|
37 |
-
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
|
38 |
-
return RectifiedFlowScheduler.from_config(scheduler_config)
|
39 |
-
|
40 |
-
|
41 |
-
def main():
|
42 |
-
# Parse command line arguments
|
43 |
-
parser = argparse.ArgumentParser(
|
44 |
-
description="Load models from separate directories"
|
45 |
-
)
|
46 |
-
parser.add_argument(
|
47 |
-
"--separate_dir",
|
48 |
-
type=str,
|
49 |
-
required=True,
|
50 |
-
help="Path to the directory containing unet, vae, and scheduler subdirectories",
|
51 |
-
)
|
52 |
-
parser.add_argument(
|
53 |
-
"--mixed_precision",
|
54 |
-
action="store_true",
|
55 |
-
help="Mixed precision in float32 and bfloat16",
|
56 |
-
)
|
57 |
-
parser.add_argument(
|
58 |
-
"--bfloat16",
|
59 |
-
action="store_true",
|
60 |
-
help="Denoise in bfloat16",
|
61 |
-
)
|
62 |
-
args = parser.parse_args()
|
63 |
-
|
64 |
-
# Paths for the separate mode directories
|
65 |
-
separate_dir = Path(args.separate_dir)
|
66 |
-
unet_dir = separate_dir / "unet"
|
67 |
-
vae_dir = separate_dir / "vae"
|
68 |
-
scheduler_dir = separate_dir / "scheduler"
|
69 |
-
|
70 |
-
# Load models
|
71 |
-
vae = load_vae(vae_dir)
|
72 |
-
unet = load_unet(unet_dir)
|
73 |
-
scheduler = load_scheduler(scheduler_dir)
|
74 |
-
|
75 |
-
# Patchifier (remains the same)
|
76 |
-
patchifier = SymmetricPatchifier(patch_size=1)
|
77 |
-
|
78 |
-
text_encoder = T5EncoderModel.from_pretrained(
|
79 |
-
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
|
80 |
-
).to("cuda")
|
81 |
-
tokenizer = T5Tokenizer.from_pretrained(
|
82 |
-
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
|
83 |
-
)
|
84 |
-
|
85 |
-
if args.bfloat16 and unet.dtype != torch.bfloat16:
|
86 |
-
unet = unet.to(torch.bfloat16)
|
87 |
-
|
88 |
-
# Use submodels for the pipeline
|
89 |
-
submodel_dict = {
|
90 |
-
"transformer": unet, # using unet for transformer
|
91 |
-
"patchifier": patchifier,
|
92 |
-
"scheduler": scheduler,
|
93 |
-
"text_encoder": text_encoder,
|
94 |
-
"tokenizer": tokenizer,
|
95 |
-
"vae": vae,
|
96 |
-
}
|
97 |
-
|
98 |
-
pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
|
99 |
-
|
100 |
-
# Sample input
|
101 |
-
num_inference_steps = 20
|
102 |
-
num_images_per_prompt = 2
|
103 |
-
guidance_scale = 3
|
104 |
-
height = 512
|
105 |
-
width = 768
|
106 |
-
num_frames = 57
|
107 |
-
frame_rate = 25
|
108 |
-
sample = {
|
109 |
-
"prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
|
110 |
-
"The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
|
111 |
-
"prompt_attention_mask": None, # Adjust attention masks as needed
|
112 |
-
"negative_prompt": "Ugly deformed",
|
113 |
-
"negative_prompt_attention_mask": None,
|
114 |
-
}
|
115 |
-
|
116 |
-
# Generate images (video frames)
|
117 |
-
_ = pipeline(
|
118 |
-
num_inference_steps=num_inference_steps,
|
119 |
-
num_images_per_prompt=num_images_per_prompt,
|
120 |
-
guidance_scale=guidance_scale,
|
121 |
-
generator=None,
|
122 |
-
output_type="pt",
|
123 |
-
callback_on_step_end=None,
|
124 |
-
height=height,
|
125 |
-
width=width,
|
126 |
-
num_frames=num_frames,
|
127 |
-
frame_rate=frame_rate,
|
128 |
-
**sample,
|
129 |
-
is_video=True,
|
130 |
-
vae_per_channel_normalize=True,
|
131 |
-
mixed_precision=args.mixed_precision,
|
132 |
-
).images
|
133 |
-
|
134 |
-
print("Generated images (video frames).")
|
135 |
-
|
136 |
-
|
137 |
-
if __name__ == "__main__":
|
138 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|