# Source: Control-A-Video/model/video_diffusion/pipelines/pipeline_stable_diffusion_controlnet3d.py
# Provenance (Hugging Face file-view residue): uploaded by imjunaidafzal,
# duplicated from wf-genius/Control-A-Video, commit 0e83ec5, 22.2 kB.
# Copyright 2023 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pipeline_st_stable_diffusion import SpatioTemporalStableDiffusionPipeline
from typing import Callable, List, Optional, Union
from diffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from transformers import DPTForDepthEstimation
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler
import torch
from einops import rearrange, repeat
import decord
import cv2
import random
import numpy as np
from ..models.unet_3d_condition import UNetPseudo3DConditionModel
from ..models.controlnet3d import ControlNet3DModel
class Controlnet3DStableDiffusionPipeline(SpatioTemporalStableDiffusionPipeline):
    """Spatio-temporal Stable Diffusion pipeline driven by a 3D ControlNet.

    Extends ``SpatioTemporalStableDiffusionPipeline`` with a ``controlnet``
    whose residuals are injected into the UNet at every denoising step, and
    with an optional ``annotator_model`` used by the ``get_*_map`` helpers to
    turn frames into control hints (depth / HED / pose).
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNetPseudo3DConditionModel,
        controlnet: ControlNet3DModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        annotator_model=None,
    ):
        """Store the sub-models and register them with the pipeline.

        ``annotator_model`` is kept as a plain attribute; it is not passed to
        ``register_modules``.
        """
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler)

        self.annotator_model = annotator_model
        self.controlnet = controlnet
        self.unet = unet
        self.vae = vae
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.scheduler = scheduler

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
        )

    @staticmethod
    def get_frames_preprocess(data_path, num_frames=24, sampling_rate=1, begin_indice=0, return_np=False):
        """Read ``num_frames`` frames from a video file with decord.

        Frames are taken every ``sampling_rate`` frames starting at
        ``begin_indice``. If the last requested index falls outside the video,
        the sampling rate is lowered one step at a time until the clip fits.

        Args:
            data_path: path handed to ``decord.VideoReader``.
            num_frames: number of frames to read.
            sampling_rate: initial stride between consecutive frames.
            begin_indice: index of the first frame to read.
            return_np: when True, return the decoded frames as a numpy array
                without any rescaling.

        Returns:
            ``(frames, fps)`` where ``fps`` is the rounded average frame rate.
            ``frames`` is the raw decoded batch when ``return_np`` is True,
            otherwise a tensor computed as ``x / 255 * 2 - 1`` and rearranged
            from ``f h w c`` to ``1 c f h w``.
            ``(None, None)`` when the clip does not fit even at stride 1.
        """
        vr = decord.VideoReader(data_path)
        n_images = len(vr)
        fps_vid = round(vr.get_avg_fps())

        def _indices(rate):
            # Always keep the requested start offset; the previous code
            # dropped it as soon as the sampling rate had to be lowered.
            return [begin_indice + i * rate for i in range(num_frames)]

        frame_indices = _indices(sampling_rate)
        while n_images <= frame_indices[-1]:
            # The clip runs past the end of the video: shrink the stride
            # until it fits.
            sampling_rate -= 1
            if sampling_rate <= 0:
                # Boundary check: the video is too short for this clip.
                return None, None
            frame_indices = _indices(sampling_rate)

        frames = vr.get_batch(frame_indices).asnumpy()
        if return_np:
            return frames, fps_vid

        frames = torch.from_numpy(frames).div(255) * 2 - 1
        frames = rearrange(frames, "f h w c -> c f h w").unsqueeze(0)
        return frames, fps_vid

    @torch.no_grad()
    def get_canny_edge_map(self, frames):
        """Compute a Canny edge map for every frame.

        Args:
            frames: tensor laid out as ``f c h w`` (flatten ``b f`` first if
                needed) with values in [-1, 1].

        Returns:
            Tensor of shape ``f 1 h w`` computed as ``edge / 255 * 2 - 1``,
            moved to the ControlNet's dtype and device.
        """
        # Tensor -> numpy, channels last for OpenCV.
        inputs = frames.cpu().numpy()
        inputs = rearrange(inputs, 'f c h w -> f h w c')
        # [-1, 1] -> [0, 255]; clip first so values slightly outside the
        # expected range do not wrap around in the uint8 cast.
        inputs = np.clip((inputs + 1) * 127.5, 0, 255).astype(np.uint8)

        lower_threshold = 100
        higher_threshold = 200
        edge_images = np.stack(
            [cv2.Canny(inp, lower_threshold, higher_threshold) for inp in inputs]
        )
        # numpy -> tensor, re-adding the channel axis: f, 1, h, w
        edge_images = torch.from_numpy(edge_images).unsqueeze(1)
        edge_images = edge_images.div(255) * 2 - 1
        return edge_images.to(dtype=self.controlnet.dtype, device=self.controlnet.device)

    @torch.no_grad()
    def get_depth_map(self, frames, height, width, return_standard_norm=False):
        """Predict a depth map per frame with ``self.annotator_model``.

        Args:
            frames: tensor laid out as ``f c h w``; turn ``b f c h w`` into
                ``(b f) c h w`` first.
            height: output height of the depth maps.
            width: output width of the depth maps.
            return_standard_norm: when True, each frame is min-max normalised
                independently to [-1, 1]. Otherwise the whole batch is shifted
                by its global minimum and divided by the resulting global
                maximum (range [0, 1]).

        Returns:
            Tensor of shape ``f 1 height width`` with the normalised depth.
        """
        h, w = height, width
        # The annotator is fed a fixed 384x384 input.
        inputs = torch.nn.functional.interpolate(
            frames,
            size=(384, 384),
            mode="bicubic",
            antialias=True,
        )
        # Cast to the annotator's dtype and device.
        inputs = inputs.to(dtype=self.annotator_model.dtype, device=self.annotator_model.device)

        outputs = self.annotator_model(inputs)
        predicted_depths = outputs.predicted_depth

        # Interpolate back to the requested size.
        predictions = torch.nn.functional.interpolate(
            predicted_depths.unsqueeze(1),
            size=(h, w),
            mode="bicubic",
        )

        # Smallest positive normal number: only changes the result when the
        # denominator would otherwise be (sub)zero, i.e. a constant depth map,
        # which previously produced NaN.
        tiny = torch.finfo(predictions.dtype).tiny

        # Normalise the output.
        if return_standard_norm:
            depth_min = torch.amin(predictions, dim=[1, 2, 3], keepdim=True)
            depth_max = torch.amax(predictions, dim=[1, 2, 3], keepdim=True)
            depth_range = (depth_max - depth_min).clamp_min(tiny)
            predictions = 2.0 * (predictions - depth_min) / depth_range - 1.0
        else:
            predictions -= torch.min(predictions)
            predictions /= torch.max(predictions).clamp_min(tiny)

        return predictions

    @torch.no_grad()
    def get_hed_map(self, frames):
        """Run the annotator as an HED edge detector.

        Args:
            frames: either a ``b c h w`` RGB tensor in [-1, 1], or a single
                ``h w c`` numpy image (scaled by 1/255 before inference, so
                presumably 8-bit RGB).

        Returns:
            For tensor input, the raw annotator output (the original author
            notes it is in [0, 1]). For numpy input, a uint8 array taken from
            the first entry of the annotator output, scaled by 255.
        """
        if isinstance(frames, torch.Tensor):
            # Input is a b c h w tensor in [-1, 1]; map it to [0, 1].
            frames = (frames + 1) / 2
            # RGB -> BGR: swap channels 0 and 2.
            bgr_frames = frames.clone()
            bgr_frames[:, 0, :, :] = frames[:, 2, :, :]
            bgr_frames[:, 2, :, :] = frames[:, 0, :, :]
            edge = self.annotator_model(bgr_frames)  # also in [0, 1] per the original note
            return edge

        assert frames.ndim == 3
        # Reverse the channel axis (RGB -> BGR) and make the array contiguous.
        frames = frames[:, :, ::-1].copy()
        # Follow the dtype/device of the annotator's first parameter.
        ref_param = next(self.annotator_model.parameters())
        image_hed = torch.from_numpy(frames).to(ref_param.device, dtype=ref_param.dtype)
        image_hed = image_hed / 255.0
        image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
        edge = self.annotator_model(image_hed)[0]
        edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
        return edge[0]

    @torch.no_grad()
    def get_pose_map(self, frames):
        """Run the annotator as a pose detector.

        Args:
            frames: either an ``f c h w`` tensor in [-1, 1], which is converted
                to uint8 ``f h w c`` frames and annotated one by one, or any
                other input, which is forwarded to the annotator unchanged.

        Returns:
            The stacked per-frame annotator outputs for tensor input, else the
            annotator's return value.
        """
        if isinstance(frames, torch.Tensor):
            # Input is a tensor in [-1, 1]; map it to [0, 1], then to uint8.
            frames = (frames + 1) / 2
            np_frames = frames.cpu().numpy() * 255
            np_frames = np.array(np_frames, dtype=np.uint8)
            np_frames = rearrange(np_frames, 'f c h w-> f h w c')
            poses = np.stack([self.annotator_model(inp) for inp in np_frames])
        else:
            poses = self.annotator_model(frames)
        return poses

    def get_timesteps(self, num_inference_steps, strength):
        """Select the tail of the scheduler's timesteps for a given strength.

        Args:
            num_inference_steps: number of steps the scheduler was set up with.
            strength: fraction of the schedule to keep; ``int(steps * strength)``
                steps are kept, capped at ``num_inference_steps``.

        Returns:
            ``(timesteps, n)``: the slice ``self.scheduler.timesteps[t_start:]``
            and ``num_inference_steps - t_start``.
        """
        # Get the original timestep using init_timestep.
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]
        return timesteps, num_inference_steps - t_start

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        controlnet_hint=None,
        fps_labels=None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        clip_length: int = 8,  # NOTE: clip_length matches the number of frames of `images`.
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs=None,
        video_scale: float = 0.0,
        controlnet_conditioning_scale: float = 1.0,
        fix_first_frame=True,
        first_frame_output=None,  # a hand-picked image may also be passed in
        first_frame_output_latent=None,
        first_frame_control_hint=None,  # keeps the first frame
        add_first_frame_by_concat=False,
        controlhint_in_uncond=False,
        init_same_noise_per_frame=False,
        init_noise_by_residual_thres=0.0,
        images=None,
        in_domain=False,  # whether the video model is used to generate the first image
        residual_control_steps=1,
        first_frame_ddim_strength=1.0,
        return_last_latent=False,
    ):
        """Generate frames from a prompt, optionally guided by control hints.

        Args:
            prompt: prompt or list of prompts; its length sets the batch size.
            controlnet_hint: control maps. A 5-D tensor (``b c f h w``) sets
                the clip length to its frame count; any other rank sets the
                clip length to 0. When None, the UNet runs without ControlNet
                residuals and ``clip_length`` is used.
            fps_labels: not used by this method.
            height, width: output size; default to
                ``unet.config.sample_size * vae_scale_factor``.
            num_inference_steps: number of scheduler steps.
            clip_length: number of frames when no hint is given.
            guidance_scale: classifier-free guidance weight; guidance is
                enabled when it is greater than 1.
            negative_prompt, num_images_per_prompt: forwarded to
                ``_encode_prompt``.
            eta, generator: forwarded to ``prepare_extra_step_kwargs``
                (``generator`` also to ``prepare_latents``).
            latents: optional initial latents forwarded to ``prepare_latents``.
            output_type: ``"pil"`` converts the decoded frames with
                ``numpy_to_pil``; anything else returns them as decoded.
            return_dict: when False, return ``(images, None)``.
            callback, callback_steps: ``callback(i, t, latents)`` is invoked
                on progress updates where ``i % callback_steps == 0``.
            cross_attention_kwargs: forwarded to the UNet on the ControlNet path.
            video_scale: when > 0 (and a hint is given), an additional
                per-frame prediction is mixed into the guidance formula.
            controlnet_conditioning_scale: multiplier on ControlNet residuals.
            fix_first_frame: overwrite frame 0 of the latents with the first
                frame latent at every step. If neither ``first_frame_output``
                nor ``first_frame_output_latent`` is given, the first frame is
                generated by a recursive 20-step call on the first hint frame.
            first_frame_output: image (or list whose first item is used) that
                is VAE-encoded into the first frame latent.
            first_frame_output_latent: precomputed first frame latent.
            first_frame_control_hint: hint prepended for the concatenated first
                frame; defaults to the first frame of ``controlnet_hint``.
            add_first_frame_by_concat: prepend the first frame latent as an
                extra frame; it is dropped again from the decoded output.
            controlhint_in_uncond: duplicate the hint on the batch axis when
                classifier-free guidance is active.
            init_same_noise_per_frame: start all frames from frame 0's noise.
            init_noise_by_residual_thres: when > 0 and ``images`` is given,
                build a mask from normalised frame differences above this
                threshold; during the first ``residual_control_steps`` steps
                each frame keeps its own noise only where the mask is set and
                copies the previous frame elsewhere.
            images: source frames (``b c f h w``) for the residual mask.
            in_domain: keep the frame axis (one frame) on the hint passed to
                the recursive first-frame call.
            residual_control_steps: number of steps the residual mask is applied.
            first_frame_ddim_strength: when < 1, at step 0 the first frame
                latent is noised with ``scheduler.add_noise`` at the timestep
                selected by ``get_timesteps``.
            return_last_latent: also return ``latents[:, :, -1]``.

        Returns:
            ``(images, None)`` when ``return_dict`` is False; otherwise a
            ``StableDiffusionPipelineOutput``, or the tuple
            ``(StableDiffusionPipelineOutput, last_latent)`` when
            ``return_last_latent`` is True.

        Raises:
            ValueError: if the first frame has to be generated but no
                ``controlnet_hint`` was provided.
        """
        if fix_first_frame and first_frame_output is None and first_frame_output_latent is None:
            if controlnet_hint is None:
                raise ValueError(
                    "`controlnet_hint` is required to generate the first frame; pass "
                    "`first_frame_output`, `first_frame_output_latent` or `fix_first_frame=False`."
                )
            # Generate the first frame from the first hint frame (b c f h w).
            first_frame_hint = (
                controlnet_hint[:, :, 0, :, :] if not in_domain else controlnet_hint[:, :, 0:1, :, :]
            )
            first_frame_output = self.__call__(
                prompt=prompt,
                controlnet_hint=first_frame_hint,
                num_inference_steps=20,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                generator=generator,
                fix_first_frame=False,
                controlhint_in_uncond=controlhint_in_uncond,
            ).images[0]

        if first_frame_output is not None:
            if isinstance(first_frame_output, list):
                first_frame_output = first_frame_output[0]
            first_frame_output = torch.from_numpy(np.array(first_frame_output)).div(255) * 2 - 1
            # FIXME: multiple batches are not supported yet; batch is fixed to 1.
            first_frame_output = rearrange(first_frame_output, "h w c -> c h w").unsqueeze(0)
            first_frame_output = first_frame_output.to(dtype=self.vae.dtype, device=self.vae.device)

            first_frame_output_latent = self.vae.encode(first_frame_output).latent_dist.sample()
            first_frame_output_latent = first_frame_output_latent * 0.18215

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        if controlnet_hint is not None:
            if len(controlnet_hint.shape) == 5:
                clip_length = controlnet_hint.shape[2]
            else:
                # Hint without a frame axis: request latents with clip length 0.
                clip_length = 0

        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            clip_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents_dtype = latents.dtype

        if len(latents.shape) == 5 and init_same_noise_per_frame:
            # Every frame starts from the noise of frame 0.
            latents[:, :, 1:, :, :] = latents[:, :, 0:1, :, :]

        if len(latents.shape) == 5 and init_noise_by_residual_thres > 0.0 and images is not None:
            images = images.to(device=device, dtype=latents_dtype)  # b c f h w
            # Absolute difference between consecutive frames, normalised by its maximum.
            image_residual = torch.abs(images[:, :, 1:, :, :] - images[:, :, :-1, :, :])
            image_residual = image_residual / torch.max(image_residual)

            # Resize to the latent resolution and average over channels.
            image_residual = rearrange(image_residual, "b c f h w -> (b f) c h w")
            image_residual = torch.nn.functional.interpolate(
                image_residual,
                size=(latents.shape[-2], latents.shape[-1]),
                mode='bilinear',
            )
            image_residual = torch.mean(image_residual, dim=1)

            # 1 where the frame changed more than the threshold, else 0.
            image_residual_mask = (image_residual > init_noise_by_residual_thres).float()
            image_residual_mask = repeat(image_residual_mask, '(b f) h w -> b f h w', b=batch_size)
            image_residual_mask = repeat(image_residual_mask, 'b f h w -> b c f h w', c=latents.shape[1])

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            if fix_first_frame:
                if add_first_frame_by_concat:
                    # Prepend the first frame latent as an additional frame ...
                    if len(first_frame_output_latent.shape) == 4:
                        latents = torch.cat([first_frame_output_latent.unsqueeze(2), latents], dim=2)
                    else:
                        latents = torch.cat([first_frame_output_latent, latents], dim=2)
                    # ... and give it a matching control hint.
                    if first_frame_control_hint is not None:
                        controlnet_hint = torch.cat([first_frame_control_hint, controlnet_hint], dim=2)
                    else:
                        controlnet_hint = torch.cat([controlnet_hint[:, :, 0:1, :, :], controlnet_hint], dim=2)

            if controlhint_in_uncond:
                controlnet_hint = torch.cat([controlnet_hint] * 2) if do_classifier_free_guidance else controlnet_hint

            for i, t in enumerate(timesteps):
                if (
                    i < residual_control_steps
                    and len(latents.shape) == 5
                    and init_noise_by_residual_thres > 0.0
                    and images is not None
                ):
                    if first_frame_ddim_strength < 1.0 and i == 0:
                        # NOTE: noise the first frame latent to obtain the initial noise of frame 0.
                        first_frame_output_latent_DDIM = first_frame_output_latent.clone()
                        full_noise_timestep, _ = self.get_timesteps(
                            num_inference_steps, strength=first_frame_ddim_strength
                        )
                        latent_timestep = full_noise_timestep[:1].repeat(batch_size * num_images_per_prompt)
                        first_frame_output_latent_DDIM = self.scheduler.add_noise(
                            first_frame_output_latent_DDIM, latents[:, :, 0, :, :], latent_timestep
                        )
                        latents[:, :, 0, :, :] = first_frame_output_latent_DDIM

                    # Where the mask is 0 a frame copies the previous frame's
                    # latent; where it is 1 the frame keeps its own latent.
                    begin_frame = 1
                    for n_frame in range(begin_frame, latents.shape[2]):
                        previous = latents[:, :, n_frame - 1, :, :]
                        latents[:, :, n_frame, :, :] = (
                            (latents[:, :, n_frame, :, :] - previous)
                            * image_residual_mask[:, :, n_frame - 1, :, :]
                            + previous
                        )

                if fix_first_frame:
                    latents[:, :, 0, :, :] = first_frame_output_latent

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if controlnet_hint is not None:
                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        controlnet_cond=controlnet_hint,
                        return_dict=False,
                    )
                    down_block_res_samples = [
                        down_block_res_sample * controlnet_conditioning_scale
                        for down_block_res_sample in down_block_res_samples
                    ]
                    mid_block_res_sample *= controlnet_conditioning_scale

                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    ).sample.to(dtype=latents_dtype)
                else:
                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                    ).sample.to(dtype=latents_dtype)

                if video_scale > 0 and controlnet_hint is not None:
                    f = latents.shape[2]
                    # Per-frame prediction: fold the frame axis into the batch axis.
                    latent_model_input_single_frame = rearrange(latent_model_input, 'b c f h w -> (b f) c h w')
                    text_embeddings_single_frame = torch.cat([text_embeddings] * f, dim=0)
                    control_maps_single_frame = rearrange(controlnet_hint, 'b c f h w -> (b f) c h w')
                    # Keep only the first half of the (doubled) batch.
                    latent_model_input_single_frame = latent_model_input_single_frame.chunk(2, dim=0)[0]
                    # NOTE(review): if `_encode_prompt` returns [uncond, cond] stacked on the
                    # batch axis, the first half of this tiled tensor interleaves both rather
                    # than holding only unconditional embeddings - confirm this is intended.
                    text_embeddings_single_frame = text_embeddings_single_frame.chunk(2, dim=0)[0]
                    if controlhint_in_uncond:
                        control_maps_single_frame = control_maps_single_frame.chunk(2, dim=0)[0]

                    down_block_res_samples_single_frame, mid_block_res_sample_single_frame = self.controlnet(
                        latent_model_input_single_frame,
                        t,
                        encoder_hidden_states=text_embeddings_single_frame,
                        controlnet_cond=control_maps_single_frame,
                        return_dict=False,
                    )
                    down_block_res_samples_single_frame = [
                        down_block_res_sample_single_frame * controlnet_conditioning_scale
                        for down_block_res_sample_single_frame in down_block_res_samples_single_frame
                    ]
                    mid_block_res_sample_single_frame *= controlnet_conditioning_scale

                    noise_pred_single_frame_uncond = self.unet(
                        latent_model_input_single_frame,
                        t,
                        encoder_hidden_states=text_embeddings_single_frame,
                        down_block_additional_residuals=down_block_res_samples_single_frame,
                        mid_block_additional_residual=mid_block_res_sample_single_frame,
                    ).sample
                    noise_pred_single_frame_uncond = rearrange(
                        noise_pred_single_frame_uncond, '(b f) c h w -> b c f h w', f=f
                    )

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    if video_scale > 0 and controlnet_hint is not None:
                        # Start from the per-frame prediction, move towards the video
                        # prediction by `video_scale`, then apply text guidance.
                        noise_pred = (
                            noise_pred_single_frame_uncond
                            + video_scale * (noise_pred_uncond - noise_pred_single_frame_uncond)
                            + guidance_scale * (noise_pred_text - noise_pred_uncond)
                        )
                    else:
                        noise_pred = noise_pred_uncond + guidance_scale * (
                            noise_pred_text - noise_pred_uncond
                        )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)
        if add_first_frame_by_concat:
            # Drop the prepended first frame (assumes frames sit on axis 1 of the decoded output).
            image = image[:, 1:, :, :, :]

        # 9. Run safety checker (not performed here)
        has_nsfw_concept = None

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
        if return_last_latent:
            last_latent = latents[:, :, -1, :, :]
            return output, last_latent
        return output