import os
from typing import Optional, Union
import torch
from omegaconf import OmegaConf
from .model.dit import get_dit, parallelize
from .model.text_embedders import get_text_embedder
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from omegaconf.dictconfig import DictConfig
from huggingface_hub import hf_hub_download, snapshot_download
from .t2v_pipeline import Kandinsky4T2VPipeline
from torch.distributed.device_mesh import init_device_mesh


def get_T2V_pipeline(
    device_map: Union[str, torch.device, dict],
    resolution: int = 512,
    cache_dir: str = './weights/',
    dit_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
    vae_path: Optional[str] = None,
    scheduler_path: Optional[str] = None,
    conf_path: Optional[str] = None,
) -> Kandinsky4T2VPipeline:
    assert resolution in [512], "Only the 512 resolution checkpoint is currently supported"

    # Accept a single device for convenience and expand it to a per-component map.
    if not isinstance(device_map, dict):
        device_map = {
            'dit': device_map,
            'vae': device_map,
            'text_embedder': device_map,
        }
    # Pick up the distributed context when launched via torchrun; otherwise run single-process.
    try:
        local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
    except KeyError:
        local_rank, world_size = 0, 1

    if world_size > 1:
        device_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tensor_parallel",))
        device_map["dit"] = torch.device(f'cuda:{local_rank}')
    os.makedirs(cache_dir, exist_ok=True)

    # Download any weights that were not supplied explicitly.
    if dit_path is None:
        dit_path = hf_hub_download(
            repo_id="ai-forever/kandinsky4", filename=f"kandinsky4_distil_{resolution}.pt", local_dir=cache_dir
        )
    if vae_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns='vae/*', local_dir=cache_dir
        )
        vae_path = os.path.join(cache_dir, "vae/")
    if scheduler_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns='scheduler/*', local_dir=cache_dir
        )
        scheduler_path = os.path.join(cache_dir, "scheduler/")
    if text_encoder_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns='text_encoder/*', local_dir=cache_dir
        )
        text_encoder_path = os.path.join(cache_dir, "text_encoder/")
    if tokenizer_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns='tokenizer/*', local_dir=cache_dir
        )
        tokenizer_path = os.path.join(cache_dir, "tokenizer/")
    if conf_path is None:
        conf = get_default_conf(vae_path, text_encoder_path, tokenizer_path, scheduler_path, dit_path)
    else:
        conf = OmegaConf.load(conf_path)

    dit = get_dit(conf.dit)
    dit = dit.to(dtype=torch.bfloat16, device=device_map["dit"])
    noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(conf.dit.scheduler)
    if world_size > 1:
        dit = parallelize(dit, device_mesh["tensor_parallel"])

    text_embedder = get_text_embedder(conf)
    text_embedder = text_embedder.freeze()
    if local_rank == 0:
        text_embedder = text_embedder.to(device=device_map["text_embedder"], dtype=torch.bfloat16)

    vae = AutoencoderKLCogVideoX.from_pretrained(conf.vae.checkpoint_path)
    vae = vae.eval()
    if local_rank == 0:
        vae = vae.to(device_map["vae"], dtype=torch.bfloat16)

    return Kandinsky4T2VPipeline(
        device_map=device_map,
        dit=dit,
        text_embedder=text_embedder,
        vae=vae,
        noise_scheduler=noise_scheduler,
        resolution=resolution,
        local_dit_rank=local_rank,
        world_size=world_size,
    )


def get_default_conf(
    vae_path,
    text_encoder_path,
    tokenizer_path,
    scheduler_path,
    dit_path,
) -> DictConfig:
    dit_params = {
        'in_visual_dim': 16,
        'in_text_dim': 4096,
        'out_visual_dim': 16,
        'time_dim': 512,
        'patch_size': [1, 2, 2],
        'model_dim': 3072,
        'ff_dim': 12288,
        'num_blocks': 21,
        'axes_dims': [16, 24, 24],
    }
    conf = {
        'vae': {
            'checkpoint_path': vae_path,
        },
        'text_embedder': {
            'emb_size': 4096,
            # NB: key spelling is kept as-is because the text embedder config is read with this name.
            'tokens_lenght': 224,
            'params': {
                'checkpoint_path': text_encoder_path,
                'tokenizer_path': tokenizer_path,
            },
        },
        'dit': {
            'scheduler': scheduler_path,
            'checkpoint_path': dit_path,
            'params': dit_params,
        },
        'resolution': 512,
    }
    return DictConfig(conf)
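

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Constructing the
# pipeline on a single GPU uses the get_T2V_pipeline signature above; the
# generation call below is an assumption about Kandinsky4T2VPipeline.__call__,
# so check t2v_pipeline.py for the actual keyword arguments. For multi-GPU
# tensor-parallel inference, launch the same script with
# `torchrun --nproc_per_node=<num_gpus> your_script.py`, which sets the
# LOCAL_RANK / WORLD_SIZE environment variables read above.
#
# if __name__ == "__main__":
#     pipe = get_T2V_pipeline(device_map="cuda:0", resolution=512)
#     video = pipe(text="A red sports car drifting through a mountain pass at sunset")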