# hpoghos's picture
# Update t2v_enhanced/model/video_ldm.py
# 870500e verified
from pathlib import Path
from typing import Any, Optional, Union, Callable
import pytorch_lightning as pl
import torch
from diffusers import DDPMScheduler, DiffusionPipeline, AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from transformers import CLIPTextModel, CLIPTokenizer
from t2v_enhanced.utils.video_utils import ResultProcessor, save_videos_grid, video_naming
from t2v_enhanced.model import pl_module_params_controlnet
from t2v_enhanced.model.diffusers_conditional.models.controlnet.controlnet import ControlNetModel
from t2v_enhanced.model.diffusers_conditional.models.controlnet.unet_3d_condition import UNet3DConditionModel
from t2v_enhanced.model.diffusers_conditional.models.controlnet.pipeline_text_to_video_w_controlnet_synth import TextToVideoSDPipeline
from t2v_enhanced.model.diffusers_conditional.models.controlnet.processor import set_use_memory_efficient_attention_xformers
from t2v_enhanced.model.diffusers_conditional.models.controlnet.mask_generator import MaskGenerator
import warnings
# from warnings import warn
from t2v_enhanced.utils.iimage import IImage
from t2v_enhanced.utils.object_loader import instantiate_object
from t2v_enhanced.utils.object_loader import get_class
class VideoLDM(pl.LightningModule):
    """Lightning module pairing a base 3D UNet with a ControlNet for chunk-wise video prediction.

    The constructor loads the scheduler, tokenizer, text encoder, VAE and base
    UNet from ``unet_params.pipeline_repo``, derives a ``ControlNetModel`` from
    the base UNet, and wires everything into a ``TextToVideoSDPipeline``.
    :meth:`predict_step` generates a video chunk by chunk, conditioning each
    new chunk on frames produced so far, and :meth:`forward` runs one pipeline
    call.
    """

    def __init__(self,
                 inference_params: pl_module_params_controlnet.InferenceParams,
                 opt_params: pl_module_params_controlnet.OptimizerParams = None,
                 unet_params: pl_module_params_controlnet.UNetParams = None,
                 ):
        """Load all sub-models and assemble the inference pipeline.

        Args:
            inference_params: Inference settings; ``scheduler_cls``,
                ``frame_rate`` and ``video_length`` are read here.
            opt_params: Optimizer/loading settings. Although the default is
                ``None``, its attributes are read unconditionally, so a real
                object must be supplied.
            unet_params: UNet/ControlNet settings. Same remark as
                ``opt_params``: effectively required.
        """
        super().__init__()
        self.inference_generator = torch.Generator(device=self.device)
        self.opt_params = opt_params
        self.unet_params = unet_params

        print(f"Base pipeline from: {unet_params.pipeline_repo}")
        print(f"Pipeline class {unet_params.pipeline_class}")

        # Optional state dicts, filled from a trained checkpoint when configured.
        state_dict_control_model = None
        state_dict_fusion = None
        # The two below are never populated in this constructor; the loading
        # code further down is kept so that it works once they are.
        state_dict_base_model = None
        state_dict_proj = None

        if opt_params.load_trained_controlnet_from_ckpt:
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from trusted sources.
            checkpoint = torch.load(
                opt_params.load_trained_controlnet_from_ckpt,
                map_location=torch.device("cpu"),
            )
            ckpt_state = checkpoint["state_dict"]
            # maxsplit=1 strips only the first marker, so keys that contain
            # the marker again further down are not truncated.
            state_dict_control_model = {
                key.split("unet.", 1)[1]: value
                for key, value in ckpt_state.items()
                if key.startswith("unet")
            }
            state_dict_fusion = {
                key.split("base_model.", 1)[1]: value
                for key, value in ckpt_state.items()
                if "cross_attention_merger" in key
            }
            del checkpoint, ckpt_state

        if getattr(unet_params, "use_resampler", False):
            num_queries = unet_params.num_frames if unet_params.num_frames > 1 else None
            if unet_params.use_image_tokens_ctrl:
                num_queries = unet_params.num_control_input_frames
                assert unet_params.frame_expansion == "none"
            image_encoder = self.unet_params.image_encoder
            embedding_dim = image_encoder.embedding_dim
            self.resampler = instantiate_object(
                self.unet_params.resampler_cls,
                video_length=num_queries,
                embedding_dim=embedding_dim,
                input_tokens=image_encoder.num_tokens,
                num_layers=self.unet_params.resampler_merging_layers,
                aggregation=self.unet_params.aggregation,
            )
            self.image_encoder = image_encoder

        pipeline_repo = self.unet_params.pipeline_repo
        noise_scheduler = DDPMScheduler.from_pretrained(pipeline_repo, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(pipeline_repo, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(pipeline_repo, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(pipeline_repo, subfolder="vae")
        base_model = UNet3DConditionModel.from_pretrained(
            pipeline_repo,
            subfolder="unet",
            low_cpu_mem_usage=False,
            device_map=None,
            merging_mode=self.unet_params.merging_mode_base,
            use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_main,
            use_fps_conditioning=self.opt_params.use_fps_conditioning,
            unet_params=unet_params,
        )

        if state_dict_base_model is not None:
            miss, unex = base_model.load_state_dict(state_dict_base_model, strict=False)
            assert len(unex) == 0
            if len(miss) > 0:
                warnings.warn(f"Missing keys when loading base_mode:{miss}")
            del state_dict_base_model
        if state_dict_fusion is not None:
            miss, unex = base_model.load_state_dict(state_dict_fusion, strict=False)
            assert len(unex) == 0
            del state_dict_fusion
        print("PIPE LOADING DONE")

        self.noise_scheduler = noise_scheduler
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.vae = vae

        # With frame expansion or a controlnet mask the ControlNet is built for
        # the full frame count; otherwise only for the control input frames.
        if unet_params.frame_expansion != "none" or self.unet_params.use_controlnet_mask:
            controlnet_num_frames = unet_params.num_frames
        else:
            controlnet_num_frames = unet_params.num_control_input_frames
        self.unet = ControlNetModel.from_unet(
            unet=base_model,
            conditioning_embedding_out_channels=unet_params.conditioning_embedding_out_channels,
            downsample_controlnet_cond=unet_params.downsample_controlnet_cond,
            num_frames=controlnet_num_frames,
            num_frame_conditioning=unet_params.num_control_input_frames,
            frame_expansion=unet_params.frame_expansion,
            pre_transformer_in_cond=unet_params.pre_transformer_in_cond,
            num_tranformers=unet_params.num_tranformers,
            # The ControlNet receives its own VAE instance, separate from self.vae.
            vae=AutoencoderKL.from_pretrained(pipeline_repo, subfolder="vae"),
            zero_conv_mode=unet_params.zero_conv_mode,
            merging_mode=unet_params.merging_mode,
            condition_encoder=unet_params.condition_encoder,
            use_controlnet_mask=unet_params.use_controlnet_mask,
            use_image_embedding=unet_params.use_resampler and unet_params.use_image_tokens_ctrl,
            unet_params=unet_params,
            use_image_encoder_normalization=unet_params.use_image_encoder_normalization,
        )
        if state_dict_control_model is not None:
            miss, unex = self.unet.load_state_dict(state_dict_control_model, strict=False)
            if len(miss) > 0:
                print("WARNING: Loading checkpoint for controlnet misses states")
                print(miss)

        if unet_params.frame_expansion == "none":
            attention_params = self.unet_params.attention_mask_params
            assert not attention_params.temporal_self_attention_only_on_conditioning and not attention_params.spatial_attend_on_condition_frames and not attention_params.temp_attend_on_neighborhood_of_condition_frames

        self.mask_generator = MaskGenerator(
            self.unet_params.attention_mask_params,
            num_frame_conditioning=self.unet_params.num_control_input_frames,
            num_frames=self.unet_params.num_frames,
        )
        self.mask_generator_base = MaskGenerator(
            self.unet_params.attention_mask_params_base,
            num_frame_conditioning=self.unet_params.num_control_input_frames,
            num_frames=self.unet_params.num_frames,
        )

        if state_dict_proj is not None:
            # Projection weights go to whichever model consumes image tokens.
            if unet_params.use_image_tokens_main:
                proj_target = base_model
            elif unet_params.use_image_tokens_ctrl:
                proj_target = self.unet
            else:
                proj_target = None
            if proj_target is not None:
                missing, unexpected = proj_target.load_state_dict(state_dict_proj, strict=False)
                assert len(unexpected) == 0, f"Unexpected entries {unexpected}"
                print(f"Missing keys state proj = {missing}")
            del state_dict_proj

        # Freeze everything first; layers_config then decides what is trainable.
        base_model.requires_grad_(False)
        self.base_model = base_model
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)
        layers_config = opt_params.layers_config
        layers_config.set_requires_grad(self)

        print("CUSTOM XFORMERS ATTENTION USED.")
        if is_xformers_available():
            set_use_memory_efficient_attention_xformers(
                self.unet,
                num_frame_conditioning=self.unet_params.num_control_input_frames,
                num_frames=self.unet_params.num_frames,
                attention_mask_params=self.unet_params.attention_mask_params,
            )
            set_use_memory_efficient_attention_xformers(
                self.base_model,
                num_frame_conditioning=self.unet_params.num_control_input_frames,
                num_frames=self.unet_params.num_frames,
                attention_mask_params=self.unet_params.attention_mask_params_base,
            )

        # Inference scheduler: configurable by class path, DDIM by default.
        if inference_params.scheduler_cls:
            inf_scheduler_class = get_class(inference_params.scheduler_cls)
        else:
            inf_scheduler_class = DDIMScheduler
        inf_scheduler = inf_scheduler_class.from_pretrained(pipeline_repo, subfolder="scheduler")

        inference_pipeline = TextToVideoSDPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.base_model,
            controlnet=self.unet,
            scheduler=inf_scheduler,
        )
        inference_pipeline.set_noise_generator(self.opt_params.noise_generator)
        inference_pipeline.enable_vae_slicing()
        inference_pipeline.set_progress_bar_config(disable=True)

        self.inference_params = inference_params
        self.inference_pipeline = inference_pipeline
        self.result_processor = ResultProcessor(
            fps=self.inference_params.frame_rate,
            n_frames=self.inference_params.video_length,
        )

    def on_start(self):
        """Check that datasets and model share a pipeline repo, then attach the logger.

        Raises:
            AssertionError: If a dataset exposing ``model_id`` disagrees with
                ``unet_params.pipeline_repo``.
        """
        # NOTE(review): reaches into private Trainer attributes; confirm they
        # exist in the pinned pytorch_lightning version.
        datamodule = self.trainer._data_connector._datahook_selector.datamodule
        pipe_id_model = self.unet_params.pipeline_repo
        for dataset_key in ("video_dataset", "image_dataset", "predict_dataset"):
            dataset = getattr(datamodule, dataset_key, None)
            if dataset is None or not hasattr(dataset, "model_id"):
                continue
            pipe_id_data = dataset.model_id
            # The hint points at the model's id: that is the value the dataset
            # option must take for the check to pass.
            assert pipe_id_model == pipe_id_data, f"Model and Dataloader need the same pipeline path. Found '{pipe_id_model}' and '{dataset_key}.model_id={pipe_id_data}'. Consider setting '--data.{dataset_key}.model_id={pipe_id_model}'"
        self.result_processor.set_logger(self.logger)

    def on_predict_start(self) -> None:
        """Lightning hook: run the shared start-up checks before prediction."""
        self.on_start()

    def _select_conditioning_frames(self, chunks_conditional, condition_input):
        """Return the conditioning frames for the next chunk as a ``1 F C W H`` tensor.

        Args:
            chunks_conditional: Chunks generated so far (values in ``[0, 1]``,
                judging by the ``2*x-1`` rescaling applied here).
            condition_input: The previous chunk, already rescaled and batched
                by the caller.

        Raises:
            NotImplementedError: For an unknown ``conditioning_type``.
        """
        conditioning_type = self.inference_params.conditioning_type
        if conditioning_type == "fixed":
            # Always condition on the leading frames of the very first chunk.
            # NOTE(review): this branch and "last_chunk" read
            # `num_frame_conditioning`, while the rest of the class uses
            # `num_control_input_frames` — confirm both exist and agree.
            anchor = chunks_conditional[0][:self.unet_params.num_frame_conditioning]
            frames = (2 * anchor - 1).detach().clone()
            return rearrange(frames, "F C W H -> 1 F C W H")
        if conditioning_type == "last_chunk":
            return condition_input[:, -self.unet_params.num_frame_conditioning:].detach().clone()
        if conditioning_type == "past":
            # Concatenate the leading frames of every chunk generated so far.
            num_cond = self.unet_params.num_control_input_frames
            context = [2 * chunk[:num_cond] - 1 for chunk in chunks_conditional]
            frames = torch.cat(context).detach().clone()
            return rearrange(frames, "F C W H -> 1 F C W H")
        raise NotImplementedError(f"Unsupported conditioning_type: {conditioning_type!r}")

    def _merge_chunks(self, chunks_conditional):
        """Join all chunks into one ``IImage``; returns ``None`` for no chunks.

        The first chunk is kept whole; from every later chunk the first
        ``num_control_input_frames`` frames are dropped before joining.
        """
        num_cond = self.unet_params.num_control_input_frames
        merged_video = None
        for chunk_idx, video in enumerate(chunks_conditional):
            frames = video if chunk_idx == 0 else video[num_cond:]
            current_video = IImage(frames, vmin=0, vmax=1)
            if merged_video is None:
                merged_video = current_video
            else:
                merged_video &= current_video
        return merged_video

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Generate a video chunk by chunk and save the concatenated result.

        Inputs are taken from ``self.trainer.predict_cfg`` (keys
        ``result_file_stem``, ``predict_dir``, ``prompt``,
        ``n_autoregressive_generations``, ``video``); ``batch`` is not used.
        The first "chunk" is the provided ``cfg["video"]``; every later chunk
        is produced by :meth:`forward`.

        Returns:
            None. Output is written through ``self.result_processor`` when
            ``inference_params.concat_video`` is set.
        """
        cfg = self.trainer.predict_cfg
        result_file_stem = cfg["result_file_stem"]
        storage_fol = Path(cfg['predict_dir'])
        n_autoregressive_generations = cfg["n_autoregressive_generations"]
        # The same prompt is reused for every generated chunk.
        prompts = n_autoregressive_generations * [cfg["prompt"]]

        num_cond = self.unet_params.num_control_input_frames
        video_length = self.inference_params.video_length
        self.inference_generator.manual_seed(self.inference_params.seed)
        assert num_cond == video_length // 2, (
            "currently we assume to have an equal size for the first and second half "
            "of the frame interval, e.g. 16 frames, and we condition on 8. "
            f"Current setup: {num_cond} and {video_length}"
        )

        chunks_conditional = []
        batch_size = 1
        unet_config = self.inference_pipeline.unet.config
        shape = (batch_size, unet_config.in_channels, video_length,
                 unet_config.sample_size, unet_config.sample_size)

        sample = None
        condition_input = None
        for idx, prompt in enumerate(prompts):
            content_latent = None
            if idx > 0:
                # Encode the previous chunk and keep the latents after the
                # control frames as content for the noise generator.
                content = sample * 2 - 1
                content_latent = self.vae.encode(content).latent_dist.sample() * self.vae.config.scaling_factor
                content_latent = rearrange(content_latent, "F C W H -> 1 C F W H")
                content_latent = content_latent[:, :, num_cond:].detach().clone()

            if hasattr(self.inference_pipeline, "noise_generator"):
                latents = self.inference_pipeline.noise_generator.sample_noise(
                    shape=shape,
                    device=self.device,
                    dtype=self.dtype,
                    generator=self.inference_generator,
                    content=content_latent,
                )
            else:
                latents = None

            if idx == 0:
                # The first chunk is the input video itself.
                sample = cfg["video"].to(self.device)
            else:
                input_frames_conditioning = self._select_conditioning_frames(
                    chunks_conditional, condition_input)
                input_frames = condition_input[:, num_cond:].detach().clone()
                sample = self(prompt, input_frames=input_frames,
                              input_frames_conditioning=input_frames_conditioning,
                              latents=latents)

            # NOTE(review): assumed to run on every iteration (loop level);
            # confirm against the upstream file.
            if hasattr(self.inference_pipeline, "reset_noise_generator_state"):
                self.inference_pipeline.reset_noise_generator_state()

            condition_input = rearrange(sample, "F C W H -> 1 F C W H")
            condition_input = (2 * condition_input) - 1  # range: [-1,1]
            # store first 16 frames, then always last 8 of a chunk
            chunks_conditional.append(sample)

        # Nothing to write when concatenation is off or no chunk was produced.
        if not self.inference_params.concat_video or not chunks_conditional:
            return

        # The merged video does not depend on the output format: build it once.
        merged_video = self._merge_chunks(chunks_conditional)
        for result_format in self.inference_params.result_formats:
            save_format = result_format.replace("eval_", "")
            filename = video_naming(prompts[0], save_format, batch_idx, 0)
            named_path = (storage_fol / filename).absolute()
            # Keep the generated suffix but use the configured file stem.
            result_file_video = (named_path.parent / (result_file_stem + named_path.suffix)).as_posix()
            self.result_processor.save_to_file(
                video=merged_video.torch(vmin=0, vmax=1),
                prompt=prompts[0],
                video_filename=result_file_video,
                prompt_on_vid=False,
            )

    def forward(self, prompt, input_frames=None, input_frames_conditioning=None, latents=None):
        """Run one call of the inference pipeline and return its output.

        Args:
            prompt: Text prompt forwarded to the pipeline.
            input_frames: Passed to the pipeline as ``image``.
            input_frames_conditioning: Conditioning frames, passed through.
            latents: Optional initial latents; only forwarded when not ``None``.

        Returns:
            Whatever ``self.inference_pipeline`` returns for
            ``return_dict=False`` and ``output_type="pt_t2v"``.
        """
        use_resampler = self.unet_params.use_resampler
        call_params = self.inference_params.to_dict()
        print(f"INFERENCE PARAMS = {call_params}")
        call_params.update({
            "prompt": prompt,
            "image": input_frames,
            "num_frames": self.inference_params.video_length,
            "return_dict": False,
            "output_type": "pt_t2v",
            "mask_generator": self.mask_generator,
            # str() keeps this working when the trainer reports precision as an int.
            "precision": "16" if str(self.trainer.precision).startswith("16") else "32",
            "no_text_condition_control": self.opt_params.no_text_condition_control,
            "weight_control_sample": self.unet_params.weight_control_sample,
            "use_controlnet_mask": self.unet_params.use_controlnet_mask,
            "skip_controlnet_branch": self.opt_params.skip_controlnet_branch,
            "img_cond_resampler": self.resampler if use_resampler else None,
            "img_cond_encoder": self.image_encoder if use_resampler else None,
            "input_frames_conditioning": input_frames_conditioning,
            "cfg_text_image": self.unet_params.cfg_text_image,
            "use_of": self.unet_params.use_of,
        })
        if latents is not None:
            call_params["latents"] = latents
        sample = self.inference_pipeline(generator=self.inference_generator, **call_params)
        return sample