import os
from typing import Optional, Union

import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from huggingface_hub import hf_hub_download, snapshot_download
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from .model.dit import get_dit, parallelize
from .model.text_embedders import get_text_embedder
from .t2v_pipeline import Kandinsky4T2VPipeline

def get_T2V_pipeline(
    device_map: Union[str, torch.device, dict],
    resolution: int = 512,
    cache_dir: str = './weights/',
    dit_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
    vae_path: Optional[str] = None,
    scheduler_path: Optional[str] = None,
    conf_path: Optional[str] = None,
) -> Kandinsky4T2VPipeline:
    assert resolution in [512], "only the 512 checkpoint is currently supported"
    # A single device spec is broadcast to all three submodules.
    if not isinstance(device_map, dict):
        device_map = {
            'dit': device_map,
            'vae': device_map,
            'text_embedder': device_map,
        }

    # Pick up the torchrun environment if present; otherwise run single-process.
    try:
        local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
    except KeyError:
        local_rank, world_size = 0, 1

    if world_size > 1:
        # Tensor-parallel mesh across all ranks; each rank pins its DiT shard
        # to its local CUDA device.
        device_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tensor_parallel",))
        device_map["dit"] = torch.device(f'cuda:{local_rank}')

    os.makedirs(cache_dir, exist_ok=True)
    # Download any missing weights into cache_dir.
    if dit_path is None:
        dit_path = hf_hub_download(
            repo_id="ai-forever/kandinsky4", filename=f"kandinsky4_distil_{resolution}.pt", local_dir=cache_dir
        )
    if vae_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns='vae/*', local_dir=cache_dir
        )
        vae_path = os.path.join(cache_dir, "vae/")
    if scheduler_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns='scheduler/*', local_dir=cache_dir
        )
        scheduler_path = os.path.join(cache_dir, "scheduler/")
    if text_encoder_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns='text_encoder/*', local_dir=cache_dir
        )
        text_encoder_path = os.path.join(cache_dir, "text_encoder/")
    if tokenizer_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns='tokenizer/*', local_dir=cache_dir
        )
        tokenizer_path = os.path.join(cache_dir, "tokenizer/")
    if conf_path is None:
        conf = get_default_conf(vae_path, text_encoder_path, tokenizer_path, scheduler_path, dit_path)
    else:
        conf = OmegaConf.load(conf_path)

    # Diffusion transformer (DiT) runs in bfloat16 on its assigned device.
    dit = get_dit(conf.dit)
    dit = dit.to(dtype=torch.bfloat16, device=device_map["dit"])
    noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(conf.dit.scheduler)

    if world_size > 1:
        dit = parallelize(dit, device_mesh["tensor_parallel"])

    # Text encoder and VAE are moved to GPU only on rank 0.
    text_embedder = get_text_embedder(conf)
    text_embedder = text_embedder.freeze()
    if local_rank == 0:
        text_embedder = text_embedder.to(device=device_map["text_embedder"], dtype=torch.bfloat16)

    vae = AutoencoderKLCogVideoX.from_pretrained(conf.vae.checkpoint_path)
    vae = vae.eval()
    if local_rank == 0:
        vae = vae.to(device_map["vae"], dtype=torch.bfloat16)
    return Kandinsky4T2VPipeline(
        device_map=device_map,
        dit=dit,
        text_embedder=text_embedder,
        vae=vae,
        noise_scheduler=noise_scheduler,
        resolution=resolution,
        local_dit_rank=local_rank,
        world_size=world_size,
    )

def get_default_conf(
    vae_path,
    text_encoder_path,
    tokenizer_path,
    scheduler_path,
    dit_path,
) -> DictConfig:
    dit_params = {
        'in_visual_dim': 16,
        'in_text_dim': 4096,
        'out_visual_dim': 16,
        'time_dim': 512,
        'patch_size': [1, 2, 2],
        'model_dim': 3072,
        'ff_dim': 12288,
        'num_blocks': 21,
        'axes_dims': [16, 24, 24],
    }
    conf = {
        'vae': {
            'checkpoint_path': vae_path,
        },
        'text_embedder': {
            'emb_size': 4096,
            # 'tokens_lenght' [sic] is preserved as-is; renaming it would need a
            # matching change wherever this config key is read.
            'tokens_lenght': 224,
            'params': {
                'checkpoint_path': text_encoder_path,
                'tokenizer_path': tokenizer_path,
            },
        },
        'dit': {
            'scheduler': scheduler_path,
            'checkpoint_path': dit_path,
            'params': dit_params,
        },
        'resolution': 512,
    }
    return DictConfig(conf)
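

# ---------------------------------------------------------------------------
# Usage sketch (not part of the pipeline construction above). The import path
# and the generation-call arguments are assumptions: they depend on the
# surrounding package and on Kandinsky4T2VPipeline, so they are indicative
# only. Only get_T2V_pipeline's own signature from this module is relied on.
#
#     from kandinsky import get_T2V_pipeline   # import path is an assumption
#
#     pipeline = get_T2V_pipeline(device_map="cuda:0", resolution=512)
#     # video = pipeline(...)  # see Kandinsky4T2VPipeline for call arguments
# ---------------------------------------------------------------------------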