# Copyright 2023 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional, Union

import cv2
import numpy as np
import random
import torch
import decord
from einops import rearrange, repeat

from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)

from ..models.controlnet3d import ControlNet3DModel
from ..models.unet_3d_condition import UNetPseudo3DConditionModel
from .pipeline_st_stable_diffusion import SpatioTemporalStableDiffusionPipeline


class Controlnet3DStableDiffusionPipeline(SpatioTemporalStableDiffusionPipeline):
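    """Video generation pipeline: a spatio-temporal Stable Diffusion UNet
    (UNetPseudo3DConditionModel) guided by a 3D ControlNet (ControlNet3DModel).

    Per-frame control maps (depth / Canny / HED / pose, see the helpers below)
    are fed to the ControlNet, whose residuals condition the UNet. The first
    frame can be generated (or passed in) separately and kept fixed while the
    remaining frames are denoised.
    """
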
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNetPseudo3DConditionModel,
        controlnet: ControlNet3DModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        annotator_model=None,
    ):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
        self.annotator_model = annotator_model
        self.controlnet = controlnet
        self.unet = unet
        self.vae = vae
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.scheduler = scheduler
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
        )
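
    # Video loading helper: reads `num_frames` frames from `data_path` with the given
    # temporal stride, normalizes them to [-1, 1] and returns a (1, c, f, h, w) tensor
    # together with the source video's rounded fps (or raw uint8 frames if return_np).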
    @staticmethod
    def get_frames_preprocess(data_path, num_frames=24, sampling_rate=1, begin_indice=0, return_np=False):
        vr = decord.VideoReader(data_path)
        n_images = len(vr)
        fps_vid = round(vr.get_avg_fps())
        # sample num_frames frames starting at begin_indice with the given stride
        frame_indices = [begin_indice + i * sampling_rate for i in range(num_frames)]
        while n_images <= frame_indices[-1]:
            # the clip would run past the end of the video: reduce the stride until it fits
            sampling_rate -= 1
            if sampling_rate == 0:
                # NOTE boundary check: the video is too short even at stride 1
                return None, None
            frame_indices = [i * sampling_rate for i in range(num_frames)]
        frames = vr.get_batch(frame_indices).asnumpy()
        if return_np:
            return frames, fps_vid
        frames = torch.from_numpy(frames).div(255) * 2 - 1
        frames = rearrange(frames, "f h w c -> c f h w").unsqueeze(0)
        return frames, fps_vid
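
    # Control-map helper: Canny edges per frame. Input frames are (f, c, h, w) in
    # [-1, 1]; output is (f, 1, h, w) in [-1, 1] on the ControlNet's device/dtype.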
    def get_canny_edge_map(self, frames):
        # frames: (b f) c h w
        # from tensor to numpy
        inputs = frames.cpu().numpy()
        inputs = rearrange(inputs, "f c h w -> f h w c")
        # inputs from [-1, 1] to [0, 255]
        inputs = (inputs + 1) * 127.5
        inputs = inputs.astype(np.uint8)
        lower_threshold = 100
        higher_threshold = 200
        edge_images = np.stack([cv2.Canny(inp, lower_threshold, higher_threshold) for inp in inputs])
        # from numpy back to tensors: f 1 h w, rescaled to [-1, 1]
        edge_images = torch.from_numpy(edge_images).unsqueeze(1)
        edge_images = edge_images.div(255) * 2 - 1
        return edge_images.to(dtype=self.controlnet.dtype, device=self.controlnet.device)
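
    # Control-map helper: monocular depth. The annotator (e.g. DPTForDepthEstimation)
    # is run on 384x384 resized frames and the prediction is resized back to
    # (height, width), normalized to [0, 1] (or [-1, 1] with return_standard_norm).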
    def get_depth_map(self, frames, height, width, return_standard_norm=False):
        """
        frames should be (f c h w); you may turn b f c h w -> (b f) c h w first.
        """
        h, w = height, width
        inputs = torch.nn.functional.interpolate(
            frames,
            size=(384, 384),
            mode="bicubic",
            antialias=True,
        )
        # cast to the annotator model's dtype and device
        inputs = inputs.to(dtype=self.annotator_model.dtype, device=self.annotator_model.device)
        outputs = self.annotator_model(inputs)
        predicted_depths = outputs.predicted_depth
        # interpolate back to the requested size
        predictions = torch.nn.functional.interpolate(
            predicted_depths.unsqueeze(1),
            size=(h, w),
            mode="bicubic",
        )
        # normalize output
        if return_standard_norm:
            depth_min = torch.amin(predictions, dim=[1, 2, 3], keepdim=True)
            depth_max = torch.amax(predictions, dim=[1, 2, 3], keepdim=True)
            predictions = 2.0 * (predictions - depth_min) / (depth_max - depth_min) - 1.0
        else:
            predictions -= torch.min(predictions)
            predictions /= torch.max(predictions)
        return predictions
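
    # Control-map helper: HED-style soft edges. Accepts either a (b, c, h, w) tensor
    # in [-1, 1] or a single HWC uint8 RGB image; the annotator expects BGR in [0, 1].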
    def get_hed_map(self, frames):
        if isinstance(frames, torch.Tensor):
            # input is a (b, c, h, w) tensor in [-1, 1]; convert to [0, 1]
            frames = (frames + 1) / 2
            # RGB -> BGR
            bgr_frames = frames.clone()
            bgr_frames[:, 0, :, :] = frames[:, 2, :, :]
            bgr_frames[:, 2, :, :] = frames[:, 0, :, :]
            edge = self.annotator_model(bgr_frames)  # output is also in [0, 1]
            return edge
        else:
            assert frames.ndim == 3
            frames = frames[:, :, ::-1].copy()
            with torch.no_grad():
                image_hed = torch.from_numpy(frames).to(
                    next(self.annotator_model.parameters()).device,
                    dtype=next(self.annotator_model.parameters()).dtype,
                )
                image_hed = image_hed / 255.0
                image_hed = rearrange(image_hed, "h w c -> 1 c h w")
                edge = self.annotator_model(image_hed)[0]
                edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
                return edge[0]
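
    # Control-map helper: pose maps, produced by running the annotator (e.g. an
    # OpenPose-style detector) frame by frame; tensors in [-1, 1] are first converted
    # to uint8 HWC images.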
    def get_pose_map(self, frames):
        if isinstance(frames, torch.Tensor):
            # input is a (b, c, h, w) tensor in [-1, 1]; convert to [0, 1]
            frames = (frames + 1) / 2
            np_frames = frames.cpu().numpy() * 255
            np_frames = np.array(np_frames, dtype=np.uint8)
            np_frames = rearrange(np_frames, "f c h w -> f h w c")
            poses = np.stack([self.annotator_model(inp) for inp in np_frames])
        else:
            poses = self.annotator_model(frames)
        return poses
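
    # img2img-style schedule truncation: with strength < 1, only the last
    # `int(strength * num_inference_steps)` denoising steps are kept.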
    def get_timesteps(self, num_inference_steps, strength):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]
        return timesteps, num_inference_steps - t_start
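
    # Main sampling entry point. On top of the usual Stable Diffusion arguments it
    # supports per-frame control maps (`controlnet_hint`, shape (b, c, f, h, w)), fixing
    # the first frame (`fix_first_frame` / `first_frame_output*`), motion-aware noise
    # initialization from the source frames (`init_noise_by_residual_thres` + `images`),
    # and an extra per-frame guidance term weighted by `video_scale`.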
    def __call__(
        self,
        prompt: Union[str, List[str]],
        controlnet_hint=None,
        fps_labels=None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        clip_length: int = 8,  # NOTE clip_length must match the number of frames in `images`
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs=None,
        video_scale: float = 0.0,
        controlnet_conditioning_scale: float = 1.0,
        fix_first_frame=True,
        first_frame_output=None,  # a hand-picked first frame can also be passed in directly
        first_frame_output_latent=None,
        first_frame_control_hint=None,  # control map used to keep the first frame fixed
        add_first_frame_by_concat=False,
        controlhint_in_uncond=False,
        init_same_noise_per_frame=False,
        init_noise_by_residual_thres=0.0,
        images=None,
        in_domain=False,  # whether to use the video model itself to generate the first-frame image
        residual_control_steps=1,
        first_frame_ddim_strength=1.0,
        return_last_latent=False,
    ):
        """
        Add the original video frames (`images`) to get the depth maps.
        """
        if fix_first_frame and first_frame_output is None and first_frame_output_latent is None:
            # generate the first frame with a separate pass conditioned on the first control map
            first_frame_output = self.__call__(
                prompt=prompt,
                controlnet_hint=controlnet_hint[:, :, 0, :, :] if not in_domain else controlnet_hint[:, :, 0:1, :, :],
                # controlnet_hint is b c f h w
                num_inference_steps=20,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                generator=generator,
                fix_first_frame=False,
                controlhint_in_uncond=controlhint_in_uncond,
            ).images[0]
        if first_frame_output is not None:
            if isinstance(first_frame_output, list):
                first_frame_output = first_frame_output[0]
            first_frame_output = torch.from_numpy(np.array(first_frame_output)).div(255) * 2 - 1
            first_frame_output = rearrange(first_frame_output, "h w c -> c h w").unsqueeze(0)  # FIXME multiple batches are not supported yet; batch size is fixed to 1
            first_frame_output = first_frame_output.to(dtype=self.vae.dtype, device=self.vae.device)
            first_frame_output_latent = self.vae.encode(first_frame_output).latent_dist.sample()
            first_frame_output_latent = first_frame_output_latent * 0.18215  # SD VAE latent scaling factor

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        # NOTE: unlike stock diffusers (guidance_scale > 1.0), CFG is only enabled here
        # for guidance_scale > 5.0.
        do_classifier_free_guidance = guidance_scale > 5.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        if controlnet_hint is not None:
            if len(controlnet_hint.shape) == 5:
                clip_length = controlnet_hint.shape[2]
            else:
                clip_length = 0
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            clip_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents_dtype = latents.dtype
        if len(latents.shape) == 5 and init_same_noise_per_frame:
            # share the first frame's initial noise across all frames
            latents[:, :, 1:, :, :] = latents[:, :, 0:1, :, :]
        if len(latents.shape) == 5 and init_noise_by_residual_thres > 0.0 and images is not None:
            images = images.to(device=device, dtype=latents_dtype)  # b c f h w
            image_residual = torch.abs(images[:, :, 1:, :, :] - images[:, :, :-1, :, :])
            images = rearrange(images, "b c f h w -> (b f) c h w")
            # normalize the residual
            image_residual = image_residual / torch.max(image_residual)
            image_residual = rearrange(image_residual, "b c f h w -> (b f) c h w")
            image_residual = torch.nn.functional.interpolate(
                image_residual,
                size=(latents.shape[-2], latents.shape[-1]),
                mode="bilinear",
            )
            image_residual = torch.mean(image_residual, dim=1)
            image_residual_mask = (image_residual > init_noise_by_residual_thres).float()
            image_residual_mask = repeat(image_residual_mask, "(b f) h w -> b f h w", b=batch_size)
            image_residual_mask = repeat(image_residual_mask, "b f h w -> b c f h w", c=latents.shape[1])
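
        # `image_residual_mask` marks latent positions where consecutive source frames
        # differ by more than the threshold; during the first `residual_control_steps`
        # iterations those positions keep their own noise while static regions reuse the
        # previous frame's noise (see the propagation step inside the denoising loop).
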
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            if fix_first_frame:
                if add_first_frame_by_concat:
                    # prepend the first-frame latent (and a matching control map) as an
                    # extra conditioning frame
                    if len(first_frame_output_latent.shape) == 4:
                        latents = torch.cat([first_frame_output_latent.unsqueeze(2), latents], dim=2)
                    else:
                        latents = torch.cat([first_frame_output_latent, latents], dim=2)
                    if first_frame_control_hint is not None:
                        controlnet_hint = torch.cat([first_frame_control_hint, controlnet_hint], dim=2)
                    else:
                        controlnet_hint = torch.cat([controlnet_hint[:, :, 0:1, :, :], controlnet_hint], dim=2)
            if controlhint_in_uncond:
                controlnet_hint = torch.cat([controlnet_hint] * 2) if do_classifier_free_guidance else controlnet_hint
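
            # Denoising loop: optionally propagate noise between frames via the residual
            # mask (first `residual_control_steps` steps only), run ControlNet + UNet,
            # apply classifier-free guidance (plus the optional per-frame `video_scale`
            # term), then step the scheduler.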
            for i, t in enumerate(timesteps):
                if i < residual_control_steps and len(latents.shape) == 5 and init_noise_by_residual_thres > 0.0 and images is not None:
                    if first_frame_ddim_strength < 1.0 and i == 0:
                        # NOTE DDIM-style partial noising of the first-frame latent to get its initial noise
                        first_frame_output_latent_DDIM = first_frame_output_latent.clone()
                        full_noise_timestep, _ = self.get_timesteps(num_inference_steps, strength=first_frame_ddim_strength)
                        latent_timestep = full_noise_timestep[:1].repeat(batch_size * num_images_per_prompt)
                        first_frame_output_latent_DDIM = self.scheduler.add_noise(
                            first_frame_output_latent_DDIM, latents[:, :, 0, :, :], latent_timestep
                        )
                        latents[:, :, 0, :, :] = first_frame_output_latent_DDIM
                    begin_frame = 1
                    for n_frame in range(begin_frame, latents.shape[2]):
                        # where motion is below the threshold, reuse the previous frame's noise
                        latents[:, :, n_frame, :, :] = (
                            latents[:, :, n_frame, :, :] - latents[:, :, n_frame - 1, :, :]
                        ) * image_residual_mask[:, :, n_frame - 1, :, :] + latents[:, :, n_frame - 1, :, :]
                if fix_first_frame:
                    latents[:, :, 0, :, :] = first_frame_output_latent
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                if controlnet_hint is not None:
                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        controlnet_cond=controlnet_hint,
                        return_dict=False,
                    )
                    down_block_res_samples = [
                        down_block_res_sample * controlnet_conditioning_scale
                        for down_block_res_sample in down_block_res_samples
                    ]
                    mid_block_res_sample *= controlnet_conditioning_scale
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    ).sample.to(dtype=latents_dtype)
                else:
                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                    ).sample.to(dtype=latents_dtype)
                if video_scale > 0 and controlnet_hint is not None:
                    bsz = latents.shape[0]
                    f = latents.shape[2]
                    # frame-by-frame prediction: run ControlNet + UNet with the frames
                    # folded into the batch dimension
                    latent_model_input_single_frame = rearrange(latent_model_input, "b c f h w -> (b f) c h w")
                    text_embeddings_single_frame = torch.cat([text_embeddings] * f, dim=0)
                    control_maps_single_frame = rearrange(controlnet_hint, "b c f h w -> (b f) c h w")
                    # use only the first half of the CFG-duplicated batch
                    latent_model_input_single_frame = latent_model_input_single_frame.chunk(2, dim=0)[0]
                    text_embeddings_single_frame = text_embeddings_single_frame.chunk(2, dim=0)[0]
                    if controlhint_in_uncond:
                        control_maps_single_frame = control_maps_single_frame.chunk(2, dim=0)[0]
                    down_block_res_samples_single_frame, mid_block_res_sample_single_frame = self.controlnet(
                        latent_model_input_single_frame,
                        t,
                        encoder_hidden_states=text_embeddings_single_frame,
                        controlnet_cond=control_maps_single_frame,
                        return_dict=False,
                    )
                    down_block_res_samples_single_frame = [
                        down_block_res_sample_single_frame * controlnet_conditioning_scale
                        for down_block_res_sample_single_frame in down_block_res_samples_single_frame
                    ]
                    mid_block_res_sample_single_frame *= controlnet_conditioning_scale
                    noise_pred_single_frame_uncond = self.unet(
                        latent_model_input_single_frame,
                        t,
                        encoder_hidden_states=text_embeddings_single_frame,
                        down_block_additional_residuals=down_block_res_samples_single_frame,
                        mid_block_additional_residual=mid_block_res_sample_single_frame,
                    ).sample
                    noise_pred_single_frame_uncond = rearrange(
                        noise_pred_single_frame_uncond, "(b f) c h w -> b c f h w", f=f
                    )
                # perform guidance
                if do_classifier_free_guidance:
                    if video_scale > 0 and controlnet_hint is not None:
                        # blend the per-frame unconditional baseline with video guidance
                        # (video_scale) and text guidance (guidance_scale)
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = (
                            noise_pred_single_frame_uncond
                            + video_scale * (noise_pred_uncond - noise_pred_single_frame_uncond)
                            + guidance_scale * (noise_pred_text - noise_pred_uncond)
                        )
                    else:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)
        if add_first_frame_by_concat:
            # drop the prepended conditioning frame
            image = image[:, 1:, :, :, :]

        # 9. Run safety checker
        has_nsfw_concept = None

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        if return_last_latent:
            last_latent = latents[:, :, -1, :, :]
            return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), last_latent
        else:
            return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
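

# Usage sketch (not part of the original file; the names below are placeholders for
# already-loaded weights of the corresponding classes):
#
#   pipe = Controlnet3DStableDiffusionPipeline(
#       vae=vae,
#       text_encoder=text_encoder,
#       tokenizer=tokenizer,
#       unet=unet,                    # UNetPseudo3DConditionModel
#       controlnet=controlnet,        # ControlNet3DModel
#       scheduler=scheduler,          # e.g. DDIMScheduler
#       annotator_model=annotator,    # e.g. DPTForDepthEstimation for get_depth_map
#   ).to("cuda")
#
#   # `controlnet_hint` is a (b, c, f, h, w) tensor of per-frame control maps, e.g.
#   # built with get_depth_map / get_canny_edge_map from frames returned by
#   # get_frames_preprocess.
#   out = pipe(prompt="a robot dancing in the desert",
#              controlnet_hint=controlnet_hint,
#              height=512, width=512, num_inference_steps=50)
#   video_frames = out.images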