from typing import Optional, Union

import torch
import torchvision
from torchvision.transforms import ToPILImage
from tqdm.auto import tqdm

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler

from .model.dit import DiffusionTransformer3D
from .model.text_embedders import T5TextEmbedder


@torch.no_grad()
def predict_x_0(noise_scheduler, model_output, timesteps, sample, device):
    # Convert the model's predicted velocity into an x_0 estimate under the
    # v-prediction parameterization used by the CogVideoX DDIM scheduler.
    # Tensor.to() is not in-place, so the scheduler's own alphas_cumprod
    # buffer never leaves its original device; no restore step is needed.
    alphas = noise_scheduler.alphas_cumprod.to(device)
    alpha_prod_t = alphas[timesteps][:, None, None, None]
    beta_prod_t = 1 - alpha_prod_t
    pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
    return pred_original_sample
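
# Why the x_0 formula above holds: with the forward process
#     x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps,    a_t := alphas_cumprod[t],
# and the velocity target
#     v = sqrt(a_t) * eps - sqrt(1 - a_t) * x_0,
# substituting eliminates the noise term:
#     sqrt(a_t) * x_t - sqrt(1 - a_t) * v = a_t * x_0 + (1 - a_t) * x_0 = x_0.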


@torch.no_grad()
def get_velocity(
    model, x, t, text_embed, visual_cu_seqlens, text_cu_seqlens,
    num_groups=(1, 1, 1), scale_factor=(1., 1., 1.)
):
    # Single DiT forward pass: predict the velocity for the current noisy
    # latents, conditioned on the text embeddings.
    pred_velocity = model(x, text_embed, t, visual_cu_seqlens, text_cu_seqlens, num_groups, scale_factor)
    return pred_velocity


@torch.no_grad()
def diffusion_generate_renoise(
    model, noise_scheduler, shape, device, num_steps, text_embed, visual_cu_seqlens, text_cu_seqlens,
    num_groups=(1, 1, 1), scale_factor=(1., 1., 1.), progress=False, seed=6554
):
    # Distilled few-step sampler: at each timestep the current x_0 estimate
    # is re-noised to the scheduler's noise level, the DiT predicts the
    # velocity, and the prediction is converted back into an x_0 estimate.
    generator = torch.Generator()
    if seed is not None:
        generator.manual_seed(seed)
    img = torch.randn(*shape, generator=generator).to(torch.bfloat16).to(device)

    noise_scheduler.set_timesteps(num_steps, device=device)
    timesteps = noise_scheduler.timesteps
    if progress:
        timesteps = tqdm(timesteps)

    for time in timesteps:
        # One timestep entry per sample in the packed batch.
        model_time = time.unsqueeze(0).repeat(visual_cu_seqlens.shape[0] - 1)
        noise = torch.randn(img.shape, generator=generator).to(torch.bfloat16).to(device)
        img = noise_scheduler.add_noise(img, noise, time)
        pred_velocity = get_velocity(
            model, img.to(torch.bfloat16), model_time,
            text_embed.to(torch.bfloat16), visual_cu_seqlens,
            text_cu_seqlens, num_groups, scale_factor
        )
        img = predict_x_0(
            noise_scheduler=noise_scheduler, model_output=pred_velocity.to(device),
            timesteps=model_time.to(device), sample=img.to(device), device=device
        )
    return img
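
# A minimal sketch of how the sampler is driven (the tensors below are
# illustrative assumptions mirroring Kandinsky4T2VPipeline.__call__, not a
# second entry point): a 512x512 clip gives latents of spatial size
# 512 // 8 = 64 with 16 channels, and the cu_seqlens tensors mark per-sample
# boundaries for batch size 1.
#
#   shape = (1 * num_frames, 512 // 8, 512 // 8, 16)
#   visual_cu_seqlens = num_frames * torch.arange(2, dtype=torch.int32, device="cuda")
#   text_cu_seqlens = text_embed.shape[0] * torch.arange(2, dtype=torch.int32, device="cuda")
#   latents = diffusion_generate_renoise(
#       dit, scheduler, shape, "cuda", num_steps=4,
#       text_embed=text_embed, visual_cu_seqlens=visual_cu_seqlens,
#       text_cu_seqlens=text_cu_seqlens, progress=True, seed=42,
#   )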


class Kandinsky4T2VPipeline:
    def __init__(
        self,
        device_map: Union[str, torch.device, dict],  # e.g. {"dit": "cuda:0", "vae": "cuda:1", "text_embedder": "cuda:1"}
        dit: DiffusionTransformer3D,
        text_embedder: T5TextEmbedder,
        vae: AutoencoderKLCogVideoX,
        noise_scheduler: CogVideoXDDIMScheduler,  # TODO: accept a base scheduler class
        resolution: int = 512,
        local_dit_rank: int = 0,
        world_size: int = 1,
    ):
        if resolution not in [512]:
            raise ValueError("Resolution can only be 512")
        self.dit = dit
        self.noise_scheduler = noise_scheduler
        self.text_embedder = text_embedder
        self.vae = vae
        self.resolution = resolution
        self.device_map = device_map
        self.local_dit_rank = local_dit_rank
        self.world_size = world_size

        # Supported (height, width) pairs for each base resolution.
        self.RESOLUTIONS = {
            512: [(512, 512), (352, 736), (736, 352), (384, 672), (672, 384), (480, 544), (544, 480)],
        }

    def __call__(
        self,
        text: str,
        save_path: str = "./test.mp4",
        bs: int = 1,
        time_length: int = 12,  # duration in seconds; 0 generates a single image
        width: int = 512,
        height: int = 512,
        seed: Optional[int] = None,
        return_frames: bool = False,
    ):
        num_steps = 4

        # SEED: rank 0 draws the seed and broadcasts it so that every rank
        # samples identical noise.
        if seed is None:
            if self.local_dit_rank == 0:
                seed = torch.randint(2 ** 63 - 1, (1,)).to(self.local_dit_rank)
            else:
                seed = torch.empty((1,), dtype=torch.int64).to(self.local_dit_rank)
            if self.world_size > 1:
                torch.distributed.broadcast(seed, 0)
            seed = seed.item()

        assert bs == 1
        if self.resolution != 512:
            raise NotImplementedError("Only 512 resolution is available for now")
        if (height, width) not in self.RESOLUTIONS[self.resolution]:
            raise ValueError(f"Wrong (height, width) pair. Available (height, width) options are: {self.RESOLUTIONS[self.resolution]}")
        if num_steps != 4:
            raise NotImplementedError("In the distilled version the number of steps has to be exactly 4")

        # PREPARATION: latent frame count. Videos are written at 8 fps and the
        # VAE compresses time 4x, plus one initial frame; time_length == 0
        # generates a single frame.
        num_frames = 1 if time_length == 0 else time_length * 8 // 4 + 1
        num_groups = (1, 1, 1) if self.resolution == 512 else (1, 2, 2)
        scale_factor = (1., 1., 1.) if self.resolution == 512 else (1., 2., 2.)
        # TEXT EMBEDDER: rank 0 computes the T5 embeddings; the other ranks
        # allocate a placeholder with the broadcast shape (224 tokens x 4096
        # channels) and receive the result.
        if self.local_dit_rank == 0:
            with torch.no_grad():
                text_embed = self.text_embedder(text).squeeze(0).to(self.local_dit_rank, dtype=torch.bfloat16)
        else:
            text_embed = torch.empty(224, 4096, dtype=torch.bfloat16).to(self.local_dit_rank)

        if self.world_size > 1:
            torch.distributed.broadcast(text_embed, 0)

        torch.cuda.empty_cache()

        # Cumulative sequence lengths mark per-sample boundaries for the
        # packed visual and text tokens.
        visual_cu_seqlens = num_frames * torch.arange(bs + 1, dtype=torch.int32, device=self.device_map["dit"])
        text_cu_seqlens = text_embed.shape[0] * torch.arange(bs + 1, dtype=torch.int32, device=self.device_map["dit"])
        bs_text_embed = text_embed.repeat(bs, 1).to(self.device_map["dit"])

        # Latent layout: (frames, latent height, latent width, channels).
        shape = (bs * num_frames, height // 8, width // 8, 16)
        # DIT: run the 4-step distilled diffusion sampler in bfloat16.
        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                images = diffusion_generate_renoise(
                    self.dit, self.noise_scheduler, shape, self.device_map["dit"],
                    num_steps, bs_text_embed, visual_cu_seqlens, text_cu_seqlens,
                    num_groups, scale_factor, progress=True, seed=seed,
                )
        torch.cuda.empty_cache()
        # VAE: decode the latents back to pixel space on rank 0 only.
        if self.local_dit_rank == 0:
            self.vae.num_latent_frames_batch_size = 1 if time_length == 0 else 2
            with torch.no_grad():
                images = 1 / self.vae.config.scaling_factor * images.to(device=self.device_map["vae"], dtype=torch.bfloat16)
                # Image: (B, C, H, W) with a singleton time dim inserted at
                # position 2; video: (C, T, H, W) with a batch dim at position 0.
                images = images.permute(0, 3, 1, 2) if time_length == 0 else images.permute(3, 0, 1, 2)
                images = self.vae.decode(images.unsqueeze(2 if time_length == 0 else 0)).sample.float()
                images = torch.clip((images + 1.) / 2., 0., 1.)

        torch.cuda.empty_cache()
        if self.local_dit_rank == 0:
            # RESULTS: return PIL images, return raw frames, or write an mp4.
            if time_length == 0:
                return [ToPILImage()(image) for image in images.squeeze(2).cpu()]
            if return_frames:
                return [ToPILImage()(image) for image in images.squeeze(0).float().permute(1, 0, 2, 3).cpu()]
            torchvision.io.write_video(
                save_path, 255. * images.squeeze(0).float().permute(1, 2, 3, 0).cpu().numpy(),
                fps=8, options={"crf": "5"},
            )
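

# Example usage (a minimal sketch; the component variables below are
# assumptions, since in practice they are constructed elsewhere from the
# released checkpoints and passed to the constructor):
#
#   pipe = Kandinsky4T2VPipeline(
#       device_map={"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"},
#       dit=dit, text_embedder=text_embedder, vae=vae, noise_scheduler=scheduler,
#   )
#   # 12-second 672x384 clip written to disk at 8 fps:
#   pipe("a red hot air balloon drifting over snowy mountains",
#        save_path="./balloon.mp4", time_length=12, width=672, height=384)
#   # Single 512x512 image (returns a list of PIL images):
#   images = pipe("a watercolor fox", time_length=0)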