# NOTE: scraped page banner (Hugging Face Spaces status: "Runtime error", shown twice); not part of the module.
import gc
import os
from enum import Enum

import numpy as np
import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    StableDiffusionInstructPix2PixPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
from diffusers.schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler
from einops import rearrange

import gradio_utils
import utils

# True only when the SPACE_AUTHOR_NAME environment variable equals "PAIR";
# used below to cap how many frames a single request may generate.
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
class ModelType(Enum):
    """Identifies which pipeline a ``Model`` instance currently has loaded."""

    # A trailing comma here would silently turn the value into the tuple
    # ``(5,)``; keep it a plain integer.
    ControlNetPose = 5
class Model:
    """Holds one diffusers pipeline and renders a video frame by frame with it.

    Attributes:
        device: Device the pipeline, generator and tensors are placed on.
        dtype: Torch dtype used for the weights, latents and control images.
        generator: ``torch.Generator`` re-seeded before every frame.
        pipe_dict: Maps a ``ModelType`` to the pipeline class that serves it.
        pipe: The currently loaded pipeline, or ``None`` when nothing is loaded.
        model_type: ``ModelType`` of the loaded pipeline, or ``None``.
        states: Empty dict created in ``__init__``; not read by this class.
        model_name: Model id of the loaded pipeline, or ``""``.
    """

    # Maximum number of frames generated per call when ``on_huggingspace`` is true.
    _HOSTED_FRAME_LIMIT = 80

    def __init__(self, device, dtype, **kwargs):
        """Store the device/dtype and create an unloaded model.

        Extra keyword arguments are accepted and ignored.
        """
        self.device = device
        self.dtype = dtype
        self.generator = torch.Generator(device=device)
        self.pipe_dict = {
            ModelType.ControlNetPose: StableDiffusionControlNetPipeline,
        }
        self.pipe = None
        self.model_type = None
        self.states = {}
        self.model_name = ""

    def set_model(self, model_type: "ModelType", model_id: str, **kwargs):
        """Replace the current pipeline with ``model_id`` loaded as ``model_type``.

        Args:
            model_type: Key into ``self.pipe_dict`` selecting the pipeline class.
            model_id: Identifier forwarded to ``from_pretrained``.
            **kwargs: ``safety_checker`` and ``controlnet`` are extracted and
                forwarded to ``from_pretrained`` (both default to ``None``);
                anything else is only printed.

        Raises:
            KeyError: If ``model_type`` is not in ``self.pipe_dict``.
        """
        # Forget the old pipeline *and* its type before loading: if loading
        # fails, callers comparing ``self.model_type`` will try again instead
        # of assuming a pipeline that no longer exists.
        self.pipe = None
        self.model_type = None
        self.model_name = ""
        # Collect first so tensors kept alive only by garbage cycles are
        # actually freed before the CUDA cache is released.
        gc.collect()
        torch.cuda.empty_cache()

        print('kwargs', kwargs)
        print('device', self.device)
        safety_checker = kwargs.pop('safety_checker', None)
        controlnet = kwargs.pop('controlnet', None)

        # Load with the instance dtype so the weights match the latents and
        # control images, which are created with ``self.dtype`` elsewhere.
        pipe = self.pipe_dict[model_type].from_pretrained(
            model_id,
            safety_checker=safety_checker,
            controlnet=controlnet,
            torch_dtype=self.dtype,
        ).to(self.device)
        # batch_size=2 matches the two-frame chunks built in ``inference``.
        pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
        pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))

        # Publish the new state only once the pipeline is fully configured.
        self.pipe = pipe
        self.model_type = model_type
        self.model_name = model_id

    def inference_chunk(self, frame_ids, **kwargs):
        """Run the pipeline on the frames selected by ``frame_ids``.

        Args:
            frame_ids: Indices used to select entries of ``prompt``,
                ``negative_prompt``, ``latents`` and ``image``.
            **kwargs: Must contain ``prompt`` (one entry per frame). Optional
                ``negative_prompt``, ``latents``, ``image`` and
                ``video_length`` are sliced/adjusted for the chunk; everything
                else is forwarded to the pipeline unchanged.

        Returns:
            The pipeline's output, or ``None`` when no pipeline is loaded.
        """
        if getattr(self, "pipe", None) is None:
            return None

        prompt = np.array(kwargs.pop('prompt'))
        negative_prompt = np.array(kwargs.pop('negative_prompt', ''))
        if negative_prompt.ndim == 0:
            # A single string (including the default '') cannot be indexed by
            # frame id; repeat it once per prompt entry instead of failing.
            negative_prompt = np.full(prompt.shape, negative_prompt.item())

        latents = None
        if 'latents' in kwargs:
            latents = kwargs.pop('latents')[frame_ids]
        if 'image' in kwargs:
            kwargs['image'] = kwargs['image'][frame_ids]
        if 'video_length' in kwargs:
            kwargs['video_length'] = len(frame_ids)

        return self.pipe(prompt=prompt[frame_ids].tolist(),
                         negative_prompt=negative_prompt[frame_ids].tolist(),
                         latents=latents,
                         generator=self.generator,
                         **kwargs)

    def inference(self, **kwargs):
        """Generate all frames, pairing each frame with frame 0.

        The frame count is ``image.shape[0]`` when ``image`` is given,
        otherwise ``video_length``. Each frame ``i`` is generated in a chunk
        ``[0, i]`` and only the second image of the chunk is kept. The
        generator is re-seeded with the same seed before every chunk.

        Args:
            **kwargs: Must contain ``prompt`` and either ``image`` or
                ``video_length``. ``seed`` (default 0; a negative value draws
                a fresh seed from the generator) and ``generator`` are
                consumed here; ``negative_prompt`` defaults to ``''``. The
                rest is forwarded to ``inference_chunk``.

        Returns:
            The kept images concatenated with ``np.concatenate``, or ``None``
            when no pipeline is loaded.
        """
        if getattr(self, "pipe", None) is None:
            return None

        seed = kwargs.pop('seed', 0)
        if seed < 0:
            seed = self.generator.seed()
        # The instance generator is always used; drop any caller-supplied one.
        kwargs.pop('generator', '')

        if 'image' in kwargs:
            f = kwargs['image'].shape[0]
        else:
            f = kwargs['video_length']

        assert 'prompt' in kwargs
        prompt = [kwargs.pop('prompt')] * f
        negative_prompt = [kwargs.pop('negative_prompt', '')] * f

        # Cap the number of frames when ``on_huggingspace`` is true.
        frame_count = min(f, self._HOSTED_FRAME_LIMIT) if on_huggingspace else f

        # Processing frame_by_frame
        result = []
        for i in range(frame_count):
            self.generator.manual_seed(seed)
            print(f'Processing frame {i + 1} / {f}')
            chunk = self.inference_chunk(frame_ids=[0, i],
                                         prompt=prompt,
                                         negative_prompt=negative_prompt,
                                         **kwargs)
            # Index 0 of the chunk is the anchor frame; keep only frame i.
            result.append(chunk.images[1:])
        return np.concatenate(result)

    def process_controlnet_pose(self,
                                video_path,
                                prompt,
                                num_inference_steps=20,
                                controlnet_conditioning_scale=1.0,
                                guidance_scale=9.0,
                                seed=42,
                                eta=0.0,
                                resolution=512,
                                use_cf_attn=True,
                                save_path=None):
        """Generate a video from a pose video and a text prompt.

        Args:
            video_path: Value resolved through
                ``gradio_utils.motion_to_video_path`` and then read by
                ``utils.prepare_video``.
            prompt: User prompt; a fixed quality suffix is appended to it.
            num_inference_steps: Forwarded to the pipeline.
            controlnet_conditioning_scale: Forwarded to the pipeline.
            guidance_scale: Forwarded to the pipeline.
            seed: Seeds the shared initial latent and every frame.
            eta: Forwarded to the pipeline.
            resolution: Forwarded to ``utils.prepare_video``.
            use_cf_attn: Currently unused by this method; cross-frame
                attention is always installed by ``set_model``.
            save_path: Forwarded to ``utils.create_gif`` as ``path``.

        Returns:
            Whatever ``utils.create_gif`` returns for the generated frames.
        """
        print("Module Pose")
        video_path = gradio_utils.motion_to_video_path(video_path)
        if self.model_type != ModelType.ControlNetPose:
            # Load the ControlNet with the same dtype as the main pipeline.
            controlnet = ControlNetModel.from_pretrained(
                "fusing/stable-diffusion-v1-5-controlnet-openpose", torch_dtype=self.dtype)
            self.set_model(ModelType.ControlNetPose,
                           model_id="runwayml/stable-diffusion-v1-5", controlnet=controlnet)
            self.pipe.scheduler = DDIMScheduler.from_config(
                self.pipe.scheduler.config)
        # NOTE(review): the path was already converted above; this second,
        # conditional conversion looks redundant — confirm against
        # gradio_utils.motion_to_video_path before removing either call.
        video_path = gradio_utils.motion_to_video_path(
            video_path) if 'Motion' in video_path else video_path

        added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth'
        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'

        video, fps = utils.prepare_video(
            video_path, resolution, self.device, self.dtype, False, output_fps=4)
        control = utils.pre_process_pose(
            video, apply_pose_detect=False).to(self.device).to(self.dtype)
        f, _, h, w = video.shape

        # One random latent shared by every frame, at 1/8 of the frame size.
        self.generator.manual_seed(seed)
        latents = torch.randn((1, 4, h // 8, w // 8), dtype=self.dtype,
                              device=self.device, generator=self.generator)
        latents = latents.repeat(f, 1, 1, 1)

        result = self.inference(image=control,
                                prompt=prompt + ', ' + added_prompt,
                                height=h,
                                width=w,
                                negative_prompt=negative_prompts,
                                num_inference_steps=num_inference_steps,
                                guidance_scale=guidance_scale,
                                controlnet_conditioning_scale=controlnet_conditioning_scale,
                                eta=eta,
                                latents=latents,
                                seed=seed,
                                output_type='numpy',
                                )
        return utils.create_gif(result, fps, path=save_path)