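"""Example inference script: loads the Xora video VAE, Transformer3D denoiser, and
rectified-flow scheduler from separate subdirectories, assembles the XoraVideoPipeline
together with the PixArt-alpha T5 text encoder and tokenizer, and runs a short
text-to-video sampling pass on the GPU."""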
import torch
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
from pathlib import Path
from transformers import T5EncoderModel, T5Tokenizer
import safetensors.torch
import json
import argparse


def load_vae(vae_dir):
    """Load the causal video VAE from its config and safetensors checkpoint."""
    vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
    vae_config_path = vae_dir / "config.json"
    with open(vae_config_path, "r") as f:
        vae_config = json.load(f)
    vae = CausalVideoAutoencoder.from_config(vae_config)
    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
    vae.load_state_dict(vae_state_dict)
    return vae.cuda().to(torch.bfloat16)


def load_unet(unet_dir):
    """Load the Transformer3D denoiser (referred to as the "unet") from its config and checkpoint."""
    unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
    unet_config_path = unet_dir / "config.json"
    transformer_config = Transformer3DModel.load_config(unet_config_path)
    transformer = Transformer3DModel.from_config(transformer_config)
    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
    transformer.load_state_dict(unet_state_dict, strict=True)
    return transformer.cuda()


def load_scheduler(scheduler_dir):
    """Load the rectified-flow scheduler from its config file."""
    scheduler_config_path = scheduler_dir / "scheduler_config.json"
    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
    return RectifiedFlowScheduler.from_config(scheduler_config)


def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Load models from separate directories"
    )
    parser.add_argument(
        "--separate_dir",
        type=str,
        required=True,
        help="Path to the directory containing unet, vae, and scheduler subdirectories",
    )
    args = parser.parse_args()

    # Paths for the separate mode directories
    separate_dir = Path(args.separate_dir)
    unet_dir = separate_dir / "unet"
    vae_dir = separate_dir / "vae"
    scheduler_dir = separate_dir / "scheduler"

    # Load models
    vae = load_vae(vae_dir)
    unet = load_unet(unet_dir)
    scheduler = load_scheduler(scheduler_dir)

    # Patchifier is constructed directly (no checkpoint to load)
    patchifier = SymmetricPatchifier(patch_size=1)

    text_encoder = T5EncoderModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
    ).to("cuda")
    tokenizer = T5Tokenizer.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
    )

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": unet,  # using unet for transformer
        "patchifier": patchifier,
        "scheduler": scheduler,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "vae": vae,
    }

    pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")

    # Sample input
    num_inference_steps = 20
    num_images_per_prompt = 2
    guidance_scale = 3
    height = 512
    width = 768
    num_frames = 57
    frame_rate = 25
    sample = {
        "prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
        "The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
        "prompt_attention_mask": None,  # Adjust attention masks as needed
        "negative_prompt": "Ugly deformed",
        "negative_prompt_attention_mask": None,
    }

    # Generate images (video frames)
    _ = pipeline(
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images_per_prompt,
        guidance_scale=guidance_scale,
        generator=None,
        output_type="pt",
        callback_on_step_end=None,
        height=height,
        width=width,
        num_frames=num_frames,
        frame_rate=frame_rate,
        **sample,
        is_video=True,
        vae_per_channel_normalize=True,
    ).images

    print("Generated images (video frames).")


if __name__ == "__main__":
    main()
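
# Example invocation (hypothetical paths; point --separate_dir at the directory
# that contains the unet/, vae/, and scheduler/ subdirectories):
#
#   python this_script.py --separate_dir /path/to/separate_checkpoints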