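"""Load the Xora video pipeline from separately saved model directories.

Builds the causal video VAE, the Transformer3D denoiser, and the
rectified-flow scheduler from a directory containing unet/, vae/ and
scheduler/ subfolders, then runs one text-to-video generation as a smoke
test.

Example invocation (script and checkpoint paths are illustrative):
    python this_script.py --separate_dir /path/to/checkpoints
"""
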
import argparse
import json
from pathlib import Path

import safetensors.torch
import torch
from transformers import T5EncoderModel, T5Tokenizer

from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
from xora.schedulers.rf import RectifiedFlowScheduler


def load_vae(vae_dir):
    """Instantiate the causal video VAE from its config and load its weights."""
    vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
    vae_config_path = vae_dir / "config.json"
    with open(vae_config_path, "r") as f:
        vae_config = json.load(f)
    vae = CausalVideoAutoencoder.from_config(vae_config)
    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
    vae.load_state_dict(vae_state_dict)
    # Run the VAE on GPU in bfloat16 to reduce memory use.
    return vae.cuda().to(torch.bfloat16)


def load_unet(unet_dir):
    """Instantiate the Transformer3D denoiser ("unet" in the directory layout)."""
    unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
    unet_config_path = unet_dir / "config.json"
    transformer_config = Transformer3DModel.load_config(unet_config_path)
    transformer = Transformer3DModel.from_config(transformer_config)
    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
    transformer.load_state_dict(unet_state_dict, strict=True)
    return transformer.cuda()


def load_scheduler(scheduler_dir):
    scheduler_config_path = scheduler_dir / "scheduler_config.json"
    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
    return RectifiedFlowScheduler.from_config(scheduler_config)


def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Load models from separate directories"
    )
    parser.add_argument(
        "--separate_dir",
        type=str,
        required=True,
        help="Path to the directory containing unet, vae, and scheduler subdirectories",
    )
    args = parser.parse_args()

    # Paths for the separate mode directories
    separate_dir = Path(args.separate_dir)
    unet_dir = separate_dir / "unet"
    vae_dir = separate_dir / "vae"
    scheduler_dir = separate_dir / "scheduler"

    # Load models
    vae = load_vae(vae_dir)
    unet = load_unet(unet_dir)
    scheduler = load_scheduler(scheduler_dir)

    # Patchifier (remains the same)
    patchifier = SymmetricPatchifier(patch_size=1)

    # Reuse the T5 text encoder and tokenizer from the PixArt-alpha checkpoint.
    text_encoder = T5EncoderModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
    ).to("cuda")
    tokenizer = T5Tokenizer.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
    )

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": unet,  # using unet for transformer
        "patchifier": patchifier,
        "scheduler": scheduler,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "vae": vae,
    }
    pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")

    # Sample input
    num_inference_steps = 20
    num_images_per_prompt = 2
    guidance_scale = 3
    height = 512
    width = 768
    num_frames = 57
    frame_rate = 25
    sample = {
        "prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
        "The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
        "prompt_attention_mask": None,  # Adjust attention masks as needed
        "negative_prompt": "Ugly deformed",
        "negative_prompt_attention_mask": None,
    }

    # Generate video frames; the result is discarded here, so this run acts
    # as a smoke test of the separately loaded components.
    _ = pipeline(
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images_per_prompt,
        guidance_scale=guidance_scale,
        generator=None,
        output_type="pt",
        callback_on_step_end=None,
        height=height,
        width=width,
        num_frames=num_frames,
        frame_rate=frame_rate,
        **sample,
        is_video=True,
        vae_per_channel_normalize=True,
    ).images

    print("Generated images (video frames).")


if __name__ == "__main__":
    main()
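# A possible follow-up, sketched as comments only: bind the pipeline output
# instead of discarding it and write the frames to disk. This assumes the
# "pt" output is a float tensor in [0, 1] shaped
# (batch, channels, frames, height, width) -- an assumption, not verified here:
#
#   import imageio
#   images = pipeline(...).images
#   video = (images[0].permute(1, 2, 3, 0).clamp(0, 1) * 255).byte().cpu().numpy()
#   imageio.mimwrite("sample_video.mp4", list(video), fps=frame_rate)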