from typing import Union, List

import PIL
from PIL import Image
import numpy as np
from tqdm.auto import tqdm

import torch
import torchvision
from torchvision.transforms import ToPILImage
from einops import repeat

from diffusers import AutoencoderKLCogVideoX
from diffusers import CogVideoXDDIMScheduler

from .model.dit import DiffusionTransformer3D
from .model.text_embedders import T5TextEmbedder


@torch.no_grad()
def predict_x_0(noise_scheduler, model_output, timesteps, sample, device):
    # v-prediction parameterization:
    # x_0 = sqrt(alpha_cumprod_t) * x_t - sqrt(1 - alpha_cumprod_t) * v
    init_alpha_device = noise_scheduler.alphas_cumprod.device
    alphas = noise_scheduler.alphas_cumprod.to(device)

    alpha_prod_t = alphas[timesteps][:, None, None, None]
    beta_prod_t = 1 - alpha_prod_t
    pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output

    noise_scheduler.alphas_cumprod.to(init_alpha_device)
    return pred_original_sample


@torch.no_grad()
def get_velocity(
    model,
    x,
    t,
    text_embed,
    visual_cu_seqlens,
    text_cu_seqlens,
    num_groups=(1, 1, 1),
    scale_factor=(1., 1., 1.),
):
    pred_velocity = model(x, text_embed, t, visual_cu_seqlens, text_cu_seqlens, num_groups, scale_factor)
    return pred_velocity


@torch.no_grad()
def diffusion_generate_renoise(
    model,
    noise_scheduler,
    shape,
    device,
    num_steps,
    text_embed,
    visual_cu_seqlens,
    text_cu_seqlens,
    num_groups=(1, 1, 1),
    scale_factor=(1., 1., 1.),
    progress=False,
    seed=6554,
):
    generator = torch.Generator()
    if seed is not None:
        generator.manual_seed(seed)
    img = torch.randn(*shape, generator=generator).to(torch.bfloat16).to(device)

    noise_scheduler.set_timesteps(num_steps, device=device)
    timesteps = noise_scheduler.timesteps
    if progress:
        timesteps = tqdm(timesteps)

    for time in timesteps:
        model_time = time.unsqueeze(0).repeat(visual_cu_seqlens.shape[0] - 1)

        # Re-noise the current x_0 estimate to the current timestep,
        # predict the velocity, then recover a new x_0 estimate.
        noise = torch.randn(img.shape, generator=generator).to(torch.bfloat16).to(device)
        img = noise_scheduler.add_noise(img, noise, time)

        pred_velocity = get_velocity(
            model,
            img.to(torch.bfloat16),
            model_time,
            text_embed.to(torch.bfloat16),
            visual_cu_seqlens,
            text_cu_seqlens,
            num_groups,
            scale_factor,
        )
        img = predict_x_0(
            noise_scheduler=noise_scheduler,
            model_output=pred_velocity.to(device),
            timesteps=model_time.to(device),
            sample=img.to(device),
            device=device,
        )
    return img


class Kandinsky4T2VPipeline:
    def __init__(
        self,
        device_map: Union[str, torch.device, dict],  # e.g. {"dit": "cuda:0", "vae": "cuda:1", "text_embedder": "cuda:1"}
        dit: DiffusionTransformer3D,
        text_embedder: T5TextEmbedder,
        vae: AutoencoderKLCogVideoX,
        noise_scheduler: CogVideoXDDIMScheduler,  # TODO: base class
        resolution: int = 512,
        local_dit_rank=0,
        world_size=1,
    ):
        if resolution not in [512]:
            raise ValueError("Resolution can only be 512")

        self.dit = dit
        self.noise_scheduler = noise_scheduler
        self.text_embedder = text_embedder
        self.vae = vae

        self.resolution = resolution

        self.device_map = device_map
        self.local_dit_rank = local_dit_rank
        self.world_size = world_size

        self.RESOLUTIONS = {
            512: [(512, 512), (352, 736), (736, 352), (384, 672), (672, 384), (480, 544), (544, 480)],
        }

    def __call__(
        self,
        text: str,
        save_path: str = "./test.mp4",
        bs: int = 1,
        time_length: int = 12,  # length in seconds; 0 to generate a single image
        width: int = 512,
        height: int = 512,
        seed: int = None,
        return_frames: bool = False,
    ):
        num_steps = 4

        # SEED: rank 0 draws a seed and broadcasts it so all ranks sample identical noise
        if seed is None:
            if self.local_dit_rank == 0:
                seed = torch.randint(2 ** 63 - 1, (1,)).to(self.local_dit_rank)
            else:
                seed = torch.empty((1,), dtype=torch.int64).to(self.local_dit_rank)

            if self.world_size > 1:
                torch.distributed.broadcast(seed, 0)

            seed = seed.item()

        assert bs == 1
        if self.resolution != 512:
            raise NotImplementedError("Only 512 resolution is available for now")
        if (height, width) not in self.RESOLUTIONS[self.resolution]:
            raise ValueError(f"Unsupported (height, width) pair. Available (height, width) pairs are: {self.RESOLUTIONS[self.resolution]}")
        if num_steps != 4:
            raise NotImplementedError("In the distilled version the number of steps has to be exactly 4")

        # PREPARATION
        # Number of latent frames: time_length seconds at 8 fps, compressed 4x temporally by the VAE, plus one leading frame.
        num_frames = 1 if time_length == 0 else time_length * 8 // 4 + 1
        num_groups = (1, 1, 1) if self.resolution == 512 else (1, 2, 2)
        scale_factor = (1., 1., 1.) if self.resolution == 512 else (1., 2., 2.)

        # TEXT EMBEDDER: rank 0 encodes the prompt and broadcasts the embedding (224 tokens x 4096 channels)
        if self.local_dit_rank == 0:
            with torch.no_grad():
                text_embed = self.text_embedder(text).squeeze(0).to(self.local_dit_rank, dtype=torch.bfloat16)
        else:
            text_embed = torch.empty(224, 4096, dtype=torch.bfloat16).to(self.local_dit_rank)

        if self.world_size > 1:
            torch.distributed.broadcast(text_embed, 0)

        torch.cuda.empty_cache()

        visual_cu_seqlens = num_frames * torch.arange(bs + 1, dtype=torch.int32, device=self.device_map["dit"])
        text_cu_seqlens = text_embed.shape[0] * torch.arange(bs + 1, dtype=torch.int32, device=self.device_map["dit"])

        bs_text_embed = text_embed.repeat(bs, 1).to(self.device_map["dit"])
        shape = (bs * num_frames, height // 8, width // 8, 16)  # 8x spatial downsampling, 16 latent channels

        # DIT: denoise the video latents with the distilled diffusion transformer
        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                images = diffusion_generate_renoise(
                    self.dit,
                    self.noise_scheduler,
                    shape,
                    self.device_map["dit"],
                    num_steps,
                    bs_text_embed,
                    visual_cu_seqlens,
                    text_cu_seqlens,
                    num_groups,
                    scale_factor,
                    progress=True,
                    seed=seed,
                )

        torch.cuda.empty_cache()

        # VAE: decode the latents to pixel space on rank 0 only
        if self.local_dit_rank == 0:
            self.vae.num_latent_frames_batch_size = 1 if time_length == 0 else 2
            with torch.no_grad():
                images = 1 / self.vae.config.scaling_factor * images.to(device=self.device_map["vae"], dtype=torch.bfloat16)
                images = images.permute(0, 3, 1, 2) if time_length == 0 else images.permute(3, 0, 1, 2)
                images = self.vae.decode(images.unsqueeze(2 if time_length == 0 else 0)).sample.float()
                images = torch.clip((images + 1.) / 2., 0., 1.)

        torch.cuda.empty_cache()

        if self.local_dit_rank == 0:
            # RESULTS
            if time_length == 0:
                return_images = []
                for i, image in enumerate(images.squeeze(2).cpu()):
                    return_images.append(ToPILImage()(image))
                return return_images
            else:
                if return_frames:
                    return_images = []
                    for i, image in enumerate(images.squeeze(0).float().permute(1, 0, 2, 3).cpu()):
                        return_images.append(ToPILImage()(image))
                    return return_images
                else:
                    torchvision.io.write_video(
                        save_path,
                        255. * images.squeeze(0).float().permute(1, 2, 3, 0).cpu().numpy(),
                        fps=8,
                        options={"crf": "5"},
                    )
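

# Usage sketch (illustrative addition, not part of the original pipeline code).
# It shows how the pipeline above might be wired together and called. The
# component construction lines are placeholders: the actual weight loading is
# done by the repository's own helpers, whose signatures are not shown in this
# file, and this module's relative imports mean such a driver normally lives in
# a separate script that imports the package. Only the Kandinsky4T2VPipeline
# constructor and __call__ arguments are taken from the code above; the prompt,
# checkpoint handling, and device layout are assumptions.
if __name__ == "__main__":
    device_map = {"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"}

    # Placeholders -- replace with the repository's actual loading code.
    dit: DiffusionTransformer3D = ...
    text_embedder: T5TextEmbedder = ...
    vae: AutoencoderKLCogVideoX = ...
    noise_scheduler: CogVideoXDDIMScheduler = ...

    pipe = Kandinsky4T2VPipeline(
        device_map=device_map,
        dit=dit,
        text_embedder=text_embedder,
        vae=vae,
        noise_scheduler=noise_scheduler,
        resolution=512,
    )

    # Generates a 12-second 512x512 clip and writes it to save_path at 8 fps;
    # time_length=0 would instead return a list of PIL images.
    pipe(
        text="A red panda walking through a snowy forest",  # example prompt
        save_path="./test.mp4",
        time_length=12,
        width=512,
        height=512,
        seed=6554,
    )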