from enum import Enum
import gc
import os

import numpy as np
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.schedulers import DDIMScheduler
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor

import utils
import gradio_utils

on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"


class ModelType(Enum):
    ControlNetPose = 5


class Model:
    def __init__(self, device, dtype, **kwargs):
        self.device = device
        self.dtype = dtype
        self.generator = torch.Generator(device=device)
        self.pipe_dict = {
            ModelType.ControlNetPose: StableDiffusionControlNetPipeline,
        }
        self.pipe = None
        self.model_type = None
        self.states = {}
        self.model_name = ""

    def set_model(self, model_type: ModelType, model_id: str, **kwargs):
        # Release any previously loaded pipeline before loading a new one.
        if hasattr(self, "pipe") and self.pipe is not None:
            del self.pipe
            torch.cuda.empty_cache()
            gc.collect()
        print('kwargs', kwargs)
        print('device', self.device)
        safety_checker = kwargs.pop('safety_checker', None)
        controlnet = kwargs.pop('controlnet', None)
        self.pipe = self.pipe_dict[model_type].from_pretrained(
            model_id,
            safety_checker=safety_checker,
            controlnet=controlnet,
            torch_dtype=torch.float16).to(self.device)
        # Cross-frame attention: keys/values are taken from the anchor frame;
        # batch_size=2 accounts for classifier-free guidance.
        self.pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
        self.pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
        self.model_type = model_type
        self.model_name = model_id

    def inference_chunk(self, frame_ids, **kwargs):
        if not hasattr(self, "pipe") or self.pipe is None:
            return

        # Slice prompts, latents, and control images down to the requested frames.
        prompt = np.array(kwargs.pop('prompt'))
        negative_prompt = np.array(kwargs.pop('negative_prompt', ''))
        latents = None
        if 'latents' in kwargs:
            latents = kwargs.pop('latents')[frame_ids]
        if 'image' in kwargs:
            kwargs['image'] = kwargs['image'][frame_ids]
        if 'video_length' in kwargs:
            kwargs['video_length'] = len(frame_ids)
        return self.pipe(prompt=prompt[frame_ids].tolist(),
                         negative_prompt=negative_prompt[frame_ids].tolist(),
                         latents=latents,
                         generator=self.generator,
                         **kwargs)

    def inference(self, **kwargs):
        if not hasattr(self, "pipe") or self.pipe is None:
            return

        seed = kwargs.pop('seed', 0)
        if seed < 0:
            seed = self.generator.seed()
        kwargs.pop('generator', '')

        if 'image' in kwargs:
            f = kwargs['image'].shape[0]
        else:
            f = kwargs['video_length']

        assert 'prompt' in kwargs
        prompt = [kwargs.pop('prompt')] * f
        negative_prompt = [kwargs.pop('negative_prompt', '')] * f

        frames_counter = 0

        # Process frame by frame: each chunk pairs the anchor frame 0 with frame i
        # so cross-frame attention can attend to the anchor; only frame i is kept.
        result = []
        for i in range(f):
            frame_ids = [0] + [i]
            self.generator.manual_seed(seed)
            print(f'Processing frame {i + 1} / {f}')
            result.append(self.inference_chunk(frame_ids=frame_ids,
                                               prompt=prompt,
                                               negative_prompt=negative_prompt,
                                               **kwargs).images[1:])
            frames_counter += 1
            if on_huggingspace and frames_counter >= 80:
                break
        result = np.concatenate(result)
        return result

    def process_controlnet_pose(self,
                                video_path,
                                prompt,
                                num_inference_steps=20,
                                controlnet_conditioning_scale=1.0,
                                guidance_scale=9.0,
                                seed=42,
                                eta=0.0,
                                resolution=512,
                                use_cf_attn=True,
                                save_path=None):
        print("Module Pose")
        video_path = gradio_utils.motion_to_video_path(video_path)
        if self.model_type != ModelType.ControlNetPose:
            controlnet = ControlNetModel.from_pretrained(
                "fusing/stable-diffusion-v1-5-controlnet-openpose",
                torch_dtype=torch.float16)
            self.set_model(ModelType.ControlNetPose,
                           model_id="runwayml/stable-diffusion-v1-5",
                           controlnet=controlnet)
            self.pipe.scheduler = DDIMScheduler.from_config(
                self.pipe.scheduler.config)

        video_path = gradio_utils.motion_to_video_path(
            video_path) if 'Motion' in video_path else video_path

        added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth'
        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'

        video, fps = utils.prepare_video(
            video_path, resolution, self.device, self.dtype, False, output_fps=4)
        control = utils.pre_process_pose(
            video, apply_pose_detect=False).to(self.device).to(self.dtype)
        f, _, h, w = video.shape

        # Sample one initial latent and repeat it across all frames so every frame
        # starts from the same noise.
        self.generator.manual_seed(seed)
        latents = torch.randn((1, 4, h // 8, w // 8), dtype=self.dtype,
                              device=self.device, generator=self.generator)
        latents = latents.repeat(f, 1, 1, 1)
        result = self.inference(image=control,
                                prompt=prompt + ', ' + added_prompt,
                                height=h,
                                width=w,
                                negative_prompt=negative_prompts,
                                num_inference_steps=num_inference_steps,
                                guidance_scale=guidance_scale,
                                controlnet_conditioning_scale=controlnet_conditioning_scale,
                                eta=eta,
                                latents=latents,
                                seed=seed,
                                output_type='numpy',
                                )
        return utils.create_gif(result, fps, path=save_path)
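

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how the class above might be driven. It assumes a
# CUDA device, float16 weights, and that the input path points to a pose-skeleton
# video that `utils.prepare_video` / `utils.pre_process_pose` can read; the file
# name and prompt below are illustrative, not taken from the original code.
if __name__ == "__main__":
    model = Model(device="cuda", dtype=torch.float16)
    gif_path = model.process_controlnet_pose(
        video_path="dance1_corr.mp4",          # hypothetical pose-skeleton video
        prompt="an astronaut dancing in outer space",
        num_inference_steps=20,
        guidance_scale=9.0,
        seed=42,
        resolution=512,
        save_path="output.gif",
    )
    print("Saved result to", gif_path)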