import os import json import torch import random import gradio as gr from glob import glob from omegaconf import OmegaConf from datetime import datetime from safetensors import safe_open from diffusers import AutoencoderKL from diffusers.utils.import_utils import is_xformers_available from transformers import CLIPTextModel, CLIPTokenizer from animatelcm.scheduler.lcm_scheduler import LCMScheduler from animatelcm.models.unet import UNet3DConditionModel from animatelcm.pipelines.pipeline_animation import AnimationPipeline from animatelcm.utils.util import save_videos_grid from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora from animatelcm.utils.lcm_utils import convert_lcm_lora import copy sample_idx = 0 scheduler_dict = { "LCM": LCMScheduler, } css = """ .toolbutton { margin-buttom: 0em 0em 0em 0em; max-width: 2.5em; min-width: 2.5em !important; height: 2.5em; } """ SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret') class AnimateController: def __init__(self): # config dirs self.basedir = os.getcwd() self.stable_diffusion_dir = os.path.join( self.basedir, "models", "StableDiffusion") self.motion_module_dir = os.path.join( self.basedir, "models", "Motion_Module") self.personalized_model_dir = os.path.join( self.basedir, "models", "DreamBooth_LoRA") self.savedir = os.path.join( self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) self.savedir_sample = os.path.join(self.savedir, "sample") self.lcm_lora_path = "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors" os.makedirs(self.savedir, exist_ok=True) self.stable_diffusion_list = [] self.motion_module_list = [] self.personalized_model_list = [] self.refresh_stable_diffusion() self.refresh_motion_module() self.refresh_personalized_model() # config models self.tokenizer = None self.text_encoder = None self.vae = None self.unet = None self.pipeline = None self.lora_model_state_dict = {} self.inference_config = OmegaConf.load("configs/inference.yaml") def refresh_stable_diffusion(self): self.stable_diffusion_list = glob( os.path.join(self.stable_diffusion_dir, "*/")) def refresh_motion_module(self): motion_module_list = glob(os.path.join( self.motion_module_dir, "*.ckpt")) self.motion_module_list = [ os.path.basename(p) for p in motion_module_list] def refresh_personalized_model(self): personalized_model_list = glob(os.path.join( self.personalized_model_dir, "*.safetensors")) self.personalized_model_list = [ os.path.basename(p) for p in personalized_model_list] def update_stable_diffusion(self, stable_diffusion_dropdown): stable_diffusion_dropdown = os.path.join(self.stable_diffusion_dir,stable_diffusion_dropdown) self.tokenizer = CLIPTokenizer.from_pretrained( stable_diffusion_dropdown, subfolder="tokenizer") self.text_encoder = CLIPTextModel.from_pretrained( stable_diffusion_dropdown, subfolder="text_encoder").cuda() self.vae = AutoencoderKL.from_pretrained( stable_diffusion_dropdown, subfolder="vae").cuda() self.unet = UNet3DConditionModel.from_pretrained_2d( stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda() return gr.Dropdown.update() def update_motion_module(self, motion_module_dropdown): if self.unet is None: gr.Info(f"Please select a pretrained model path.") return gr.Dropdown.update(value=None) else: motion_module_dropdown = os.path.join( self.motion_module_dir, motion_module_dropdown) motion_module_state_dict = torch.load( motion_module_dropdown, map_location="cpu") missing, unexpected = self.unet.load_state_dict( motion_module_state_dict, strict=False) del motion_module_state_dict assert len(unexpected) == 0 return gr.Dropdown.update() def update_base_model(self, base_model_dropdown): if self.unet is None: gr.Info(f"Please select a pretrained model path.") return gr.Dropdown.update(value=None) else: base_model_dropdown = os.path.join( self.personalized_model_dir, base_model_dropdown) base_model_state_dict = {} with safe_open(base_model_dropdown, framework="pt", device="cpu") as f: for key in f.keys(): base_model_state_dict[key] = f.get_tensor(key) converted_vae_checkpoint = convert_ldm_vae_checkpoint( base_model_state_dict, self.vae.config) self.vae.load_state_dict(converted_vae_checkpoint) converted_unet_checkpoint = convert_ldm_unet_checkpoint( base_model_state_dict, self.unet.config) self.unet.load_state_dict(converted_unet_checkpoint, strict=False) del converted_unet_checkpoint del converted_vae_checkpoint del base_model_state_dict # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict) return gr.Dropdown.update() def update_lora_model(self, lora_model_dropdown): lora_model_dropdown = os.path.join( self.personalized_model_dir, lora_model_dropdown) self.lora_model_state_dict = {} if lora_model_dropdown == "none": pass else: with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f: for key in f.keys(): self.lora_model_state_dict[key] = f.get_tensor(key) return gr.Dropdown.update() @torch.no_grad() def animate( self, secret_token, lora_alpha_slider, spatial_lora_slider, prompt_textbox, negative_prompt_textbox, sampler_dropdown, sample_step_slider, width_slider, length_slider, height_slider, cfg_scale_slider, seed_textbox ): if secret_token != SECRET_TOKEN: raise gr.Error( f'Invalid secret token. Please fork the original space if you want to use it for yourself.') if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention() pipeline = AnimationPipeline( vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, scheduler=scheduler_dict[sampler_dropdown]( **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) ).to("cuda") original_state_dict = {k: v.cpu().clone() for k, v in pipeline.unet.state_dict().items() if "motion_modules." not in k} pipeline.unet = convert_lcm_lora(pipeline.unet, self.lcm_lora_path, spatial_lora_slider) pipeline.to("cuda") if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) else: torch.seed() seed = torch.initial_seed() with torch.autocast("cuda"): sample = pipeline( prompt_textbox, negative_prompt=negative_prompt_textbox, num_inference_steps=sample_step_slider, guidance_scale=cfg_scale_slider, width=width_slider, height=height_slider, video_length=length_slider, ).videos pipeline.unet.load_state_dict(original_state_dict,strict=False) del original_state_dict save_sample_path = os.path.join( self.savedir_sample, f"{sample_idx}.mp4") save_videos_grid(sample, save_sample_path) # sample_config = { # "prompt": prompt_textbox, # "n_prompt": negative_prompt_textbox, # "sampler": sampler_dropdown, # "num_inference_steps": sample_step_slider, # "guidance_scale": cfg_scale_slider, # "width": width_slider, # "height": height_slider, # "video_length": length_slider, # "seed": seed # } # json_str = json.dumps(sample_config, indent=4) # with open(os.path.join(self.savedir, "logs.json"), "a") as f: # f.write(json_str) # f.write("\n\n") # return gr.Video.update(value=save_sample_path) # Read the content of the video file and encode it to base64 with open(save_sample_path, "rb") as video_file: video_base64 = base64.b64encode(video_file.read()).decode('utf-8') # Prepend the appropriate data URI header with MIME type video_data_uri = 'data:video/mp4;base64,' + video_base64 # clean-up (otherwise there is a risk of "ghosting", eg. someone seeing the previous generated video", # of one of the steps go wrong) os.remove(video_path) return video_data_uri controller = AnimateController() controller.update_stable_diffusion("stable-diffusion-v1-5") controller.update_motion_module("sd15_t2v_beta_motion.ckpt") controller.update_base_model("realistic2.safetensors") def ui(): with gr.Blocks(css=css) as demo: gr.Markdown( """ # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769) Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)
[arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm) """ ''' Important Notes: 1. The generation speed is around few seconds. There is delay in the space. 2. Increase the sampling step and cfg and set proper negative prompt if you want more fancy videos. ''' ) with gr.Column(): with gr.Row(): secret_token = gr.Text(label='Secret Token', max_lines=1) base_model_dropdown = gr.Dropdown( label="Select base Dreambooth model (required)", choices=controller.personalized_model_list, interactive=True, value="realistic2.safetensors" ) base_model_dropdown.change(fn=controller.update_base_model, inputs=[ base_model_dropdown], outputs=[base_model_dropdown]) lora_model_dropdown = gr.Dropdown( label="Select LoRA model (optional)", choices=["none",], value="none", interactive=True, ) lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[ lora_model_dropdown], outputs=[lora_model_dropdown]) lora_alpha_slider = gr.Slider( label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True) spatial_lora_slider = gr.Slider( label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True) personalized_refresh_button = gr.Button( value="\U0001F503", elem_classes="toolbutton") def update_personalized_model(): controller.refresh_personalized_model() return [ gr.Dropdown.update( choices=controller.personalized_model_list), gr.Dropdown.update( choices=["none"] + controller.personalized_model_list) ] personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[ base_model_dropdown, lora_model_dropdown]) with gr.Column(variant="panel"): gr.Markdown( """ ### 2. Configs for AnimateLCM. """ ) prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="a boy holding a rabbit") negative_prompt_textbox = gr.Textbox( label="Negative prompt", lines=2, value="bad quality") with gr.Row().style(equal_height=False): with gr.Column(): with gr.Row(): sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list( scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) sample_step_slider = gr.Slider( label="Sampling steps", value=6, minimum=1, maximum=25, step=1) width_slider = gr.Slider( label="Width", value=512, minimum=256, maximum=1024, step=64) height_slider = gr.Slider( label="Height", value=512, minimum=256, maximum=1024, step=64) length_slider = gr.Slider( label="Animation length", value=16, minimum=12, maximum=20, step=1) cfg_scale_slider = gr.Slider( label="CFG Scale", value=1.5, minimum=1, maximum=2) with gr.Row(): seed_textbox = gr.Textbox(label="Seed", value=-1) seed_button = gr.Button( value="\U0001F3B2", elem_classes="toolbutton") seed_button.click(fn=lambda: gr.Textbox.update( value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox]) generate_button = gr.Button( value="Generate", variant='primary') result_video_base_64 = gr.Text() generate_button.click( fn=controller.animate, inputs=[ secret_token, lora_alpha_slider, spatial_lora_slider, prompt_textbox, negative_prompt_textbox, sampler_dropdown, sample_step_slider, width_slider, length_slider, height_slider, cfg_scale_slider, seed_textbox, ], outputs=[result_video_base64] ) return demo if __name__ == "__main__": demo = ui() # gr.close_all() # restart demo.queue(api_open=True) demo.launch()