File size: 4,909 Bytes
9d3c2b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
from typing import Optional, Union

import torch
from omegaconf import OmegaConf
from .model.dit import get_dit, parallelize
from .model.text_embedders import get_text_embedder
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from omegaconf.dictconfig import DictConfig
from huggingface_hub import hf_hub_download, snapshot_download

from .t2v_pipeline import Kandinsky4T2VPipeline

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


def _download_cogvideox_subfolder(subfolder: str, cache_dir: str) -> str:
    """Download one sub-folder of ``THUDM/CogVideoX-5b`` into ``cache_dir``.

    Args:
        subfolder: Name of the repository sub-folder to fetch (e.g. ``"vae"``).
        cache_dir: Local directory the files are downloaded into.

    Returns:
        ``cache_dir`` joined with ``"<subfolder>/"``.
    """
    snapshot_download(
        repo_id="THUDM/CogVideoX-5b",
        allow_patterns=f"{subfolder}/*",
        local_dir=cache_dir,
    )
    return os.path.join(cache_dir, f"{subfolder}/")


def get_T2V_pipeline(
        device_map: Union[str, torch.device, dict],
        resolution: int = 512,
        cache_dir: str = './weights/',
        dit_path: Optional[str] = None,
        text_encoder_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        vae_path: Optional[str] = None,
        scheduler_path: Optional[str] = None,
        conf_path: Optional[str] = None,
) -> Kandinsky4T2VPipeline:
    """Build a ``Kandinsky4T2VPipeline`` with its DiT, text embedder, VAE and scheduler.

    Any component path left as ``None`` is downloaded into ``cache_dir``: the
    DiT checkpoint from ``ai-forever/kandinsky4`` and the VAE, scheduler, text
    encoder and tokenizer from ``THUDM/CogVideoX-5b``.

    Args:
        device_map: Either a single device (string or ``torch.device``) used
            for every component, or a dict with the keys ``'dit'``, ``'vae'``
            and ``'text_embedder'``. A dict passed by the caller is copied and
            never modified.
        resolution: Target resolution; only ``512`` is accepted.
        cache_dir: Directory for downloaded weights; created if missing.
        dit_path: Path to the DiT checkpoint, or ``None`` to download it.
        text_encoder_path: Path to the text encoder, or ``None`` to download it.
        tokenizer_path: Path to the tokenizer, or ``None`` to download it.
        vae_path: Path to the VAE, or ``None`` to download it.
        scheduler_path: Path to the scheduler config, or ``None`` to download it.
        conf_path: Optional config file loaded with ``OmegaConf.load``. When
            given, the loaded config is used instead of one built from the
            paths above (missing paths are still downloaded first).

    Returns:
        The assembled ``Kandinsky4T2VPipeline``.

    Raises:
        AssertionError: If ``resolution`` is not ``512``.
    """
    # Kept as an assert so callers observing AssertionError are unaffected.
    assert resolution in [512], f"unsupported resolution: {resolution!r} (expected 512)"

    if isinstance(device_map, dict):
        # Work on a copy: the 'dit' entry may be overwritten below and the
        # caller's dict must not change as a side effect.
        device_map = dict(device_map)
    else:
        device_map = {
            'dit': device_map,
            'vae': device_map,
            'text_embedder': device_map,
        }

    # LOCAL_RANK / WORLD_SIZE are read from the environment; fall back to a
    # single-process setup when they are absent or not integers.
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    except (KeyError, ValueError):
        local_rank, world_size = 0, 1

    device_mesh = None
    if world_size > 1:
        device_mesh = init_device_mesh(
            "cuda", (world_size,), mesh_dim_names=("tensor_parallel",)
        )
        # In the multi-process case each process puts the DiT on its own GPU.
        device_map["dit"] = torch.device(f'cuda:{local_rank}')

    os.makedirs(cache_dir, exist_ok=True)

    if dit_path is None:
        dit_path = hf_hub_download(
            repo_id="ai-forever/kandinsky4",
            filename=f"kandinsky4_distil_{resolution}.pt",
            local_dir=cache_dir,
        )

    if vae_path is None:
        vae_path = _download_cogvideox_subfolder("vae", cache_dir)

    if scheduler_path is None:
        scheduler_path = _download_cogvideox_subfolder("scheduler", cache_dir)

    if text_encoder_path is None:
        text_encoder_path = _download_cogvideox_subfolder("text_encoder", cache_dir)

    if tokenizer_path is None:
        tokenizer_path = _download_cogvideox_subfolder("tokenizer", cache_dir)

    if conf_path is None:
        conf = get_default_conf(
            vae_path, text_encoder_path, tokenizer_path, scheduler_path, dit_path
        )
    else:
        conf = OmegaConf.load(conf_path)

    dit = get_dit(conf.dit)
    dit = dit.to(dtype=torch.bfloat16, device=device_map["dit"])

    noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(conf.dit.scheduler)

    if world_size > 1:
        dit = parallelize(dit, device_mesh["tensor_parallel"])

    # The text embedder and VAE are built in every process but only moved to
    # their target device (in bfloat16) on local rank 0.
    text_embedder = get_text_embedder(conf)
    text_embedder = text_embedder.freeze()
    if local_rank == 0:
        text_embedder = text_embedder.to(
            device=device_map["text_embedder"], dtype=torch.bfloat16
        )

    vae = AutoencoderKLCogVideoX.from_pretrained(conf.vae.checkpoint_path)
    vae = vae.eval()
    if local_rank == 0:
        vae = vae.to(device_map["vae"], dtype=torch.bfloat16)

    return Kandinsky4T2VPipeline(
        device_map=device_map,
        dit=dit,
        text_embedder=text_embedder,
        vae=vae,
        noise_scheduler=noise_scheduler,
        resolution=resolution,
        local_dit_rank=local_rank,
        world_size=world_size,
    )


def get_default_conf(
    vae_path,
    text_encoder_path,
    tokenizer_path,
    scheduler_path,
    dit_path,
) -> DictConfig:
    """Assemble the built-in pipeline configuration around the given paths.

    Args:
        vae_path: Stored under ``vae.checkpoint_path``.
        text_encoder_path: Stored under ``text_embedder.params.checkpoint_path``.
        tokenizer_path: Stored under ``text_embedder.params.tokenizer_path``.
        scheduler_path: Stored under ``dit.scheduler``.
        dit_path: Stored under ``dit.checkpoint_path``.

    Returns:
        A ``DictConfig`` with the sections ``vae``, ``text_embedder``, ``dit``
        and a top-level ``resolution`` of 512. All numeric hyper-parameters
        are fixed constants.
    """
    vae_section = {"checkpoint_path": vae_path}

    # NOTE(review): 'tokens_lenght' is misspelled, but it is a runtime key that
    # is presumably read elsewhere under this exact name, so it stays as is.
    text_embedder_section = {
        "emb_size": 4096,
        "tokens_lenght": 224,
        "params": {
            "checkpoint_path": text_encoder_path,
            "tokenizer_path": tokenizer_path,
        },
    }

    dit_section = {
        "scheduler": scheduler_path,
        "checkpoint_path": dit_path,
        "params": {
            "in_visual_dim": 16,
            "in_text_dim": 4096,
            "out_visual_dim": 16,
            "time_dim": 512,
            "patch_size": [1, 2, 2],
            "model_dim": 3072,
            "ff_dim": 12288,
            "num_blocks": 21,
            "axes_dims": [16, 24, 24],
        },
    }

    return DictConfig(
        {
            "vae": vae_section,
            "text_embedder": text_embedder_section,
            "dit": dit_section,
            "resolution": 512,
        }
    )