import os
import json
import torch
import random
import base64
import gradio as gr
from glob import glob
from omegaconf import OmegaConf
from datetime import datetime
from safetensors import safe_open
from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

from animatelcm.scheduler.lcm_scheduler import LCMScheduler
from animatelcm.models.unet import UNet3DConditionModel
from animatelcm.pipelines.pipeline_animation import AnimationPipeline
from animatelcm.utils.util import save_videos_grid
from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatelcm.utils.lcm_utils import convert_lcm_lora
import copy

sample_idx = 0
scheduler_dict = {
    "LCM": LCMScheduler,
}

SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')


class AnimateController:
    def __init__(self):

        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(
            self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(
            self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(
            self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(
            self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        self.lcm_lora_path = "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors"
        os.makedirs(self.savedir, exist_ok=True)

        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference.yaml")

    def refresh_stable_diffusion(self):
        self.stable_diffusion_list = glob(
            os.path.join(self.stable_diffusion_dir, "*/"))

    def refresh_motion_module(self):
        motion_module_list = glob(os.path.join(
            self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [
            os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        personalized_model_list = glob(os.path.join(
            self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = [
            os.path.basename(p) for p in personalized_model_list]

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        stable_diffusion_dropdown = os.path.join(
            self.stable_diffusion_dir, stable_diffusion_dropdown)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            stable_diffusion_dropdown, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            stable_diffusion_dropdown, subfolder="text_encoder").cuda()
        self.vae = AutoencoderKL.from_pretrained(
            stable_diffusion_dropdown, subfolder="vae").cuda()
        self.unet = UNet3DConditionModel.from_pretrained_2d(
            stable_diffusion_dropdown, subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
        return gr.Dropdown.update()

    def update_motion_module(self, motion_module_dropdown):
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            motion_module_dropdown = os.path.join(
                self.motion_module_dir, motion_module_dropdown)
            motion_module_state_dict = torch.load(
                motion_module_dropdown, map_location="cpu")
            missing, unexpected = self.unet.load_state_dict(
                motion_module_state_dict, strict=False)
            del motion_module_state_dict
            assert len(unexpected) == 0
            return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            base_model_dropdown = os.path.join(
                self.personalized_model_dir, base_model_dropdown)
            base_model_state_dict = {}
            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    base_model_state_dict[key] = f.get_tensor(key)

            converted_vae_checkpoint = convert_ldm_vae_checkpoint(
                base_model_state_dict, self.vae.config)
            self.vae.load_state_dict(converted_vae_checkpoint)

            converted_unet_checkpoint = convert_ldm_unet_checkpoint(
                base_model_state_dict, self.unet.config)
            self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

            del converted_unet_checkpoint
            del converted_vae_checkpoint
            del base_model_state_dict

            # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
            return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        lora_model_dropdown = os.path.join(
            self.personalized_model_dir, lora_model_dropdown)
        self.lora_model_state_dict = {}
        if lora_model_dropdown == "none":
            pass
        else:
            with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()

    @torch.no_grad()
    def animate(
        self,
        secret_token,
        lora_alpha_slider,
        spatial_lora_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        if secret_token != SECRET_TOKEN:
            raise gr.Error(
                'Invalid secret token. Please fork the original space if you want to use it for yourself.')

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](
                **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to("cuda")

        # Snapshot the spatial UNet weights (motion modules excluded) so the
        # LCM LoRA fusion below can be reverted after sampling.
        original_state_dict = {k: v.cpu().clone() for k, v in pipeline.unet.state_dict().items()
                               if "motion_modules." not in k}
        pipeline.unet = convert_lcm_lora(
            pipeline.unet, self.lcm_lora_path, spatial_lora_slider)

        pipeline.to("cuda")

        if seed_textbox != -1 and seed_textbox != "":
            torch.manual_seed(int(seed_textbox))
        else:
            torch.seed()
        seed = torch.initial_seed()

        with torch.autocast("cuda"):
            sample = pipeline(
                prompt_textbox,
                negative_prompt=negative_prompt_textbox,
                num_inference_steps=sample_step_slider,
                guidance_scale=cfg_scale_slider,
                width=width_slider,
                height=height_slider,
                video_length=length_slider,
            ).videos

        # Restore the original spatial weights now that sampling is done.
        pipeline.unet.load_state_dict(original_state_dict, strict=False)
        del original_state_dict

        save_sample_path = os.path.join(
            self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)

        # sample_config = {
        #     "prompt": prompt_textbox,
        #     "n_prompt": negative_prompt_textbox,
        #     "sampler": sampler_dropdown,
        #     "num_inference_steps": sample_step_slider,
        #     "guidance_scale": cfg_scale_slider,
        #     "width": width_slider,
        #     "height": height_slider,
        #     "video_length": length_slider,
        #     "seed": seed
        # }
        # json_str = json.dumps(sample_config, indent=4)
        # with open(os.path.join(self.savedir, "logs.json"), "a") as f:
        #     f.write(json_str)
        #     f.write("\n\n")
        # return gr.Video.update(value=save_sample_path)

        # Read the content of the video file and encode it to base64
        with open(save_sample_path, "rb") as video_file:
            video_base64 = base64.b64encode(video_file.read()).decode('utf-8')

        # Prepend the appropriate data URI header with MIME type
        video_data_uri = 'data:video/mp4;base64,' + video_base64

        # Clean up the file (otherwise there is a risk of "ghosting", e.g. someone
        # seeing the previously generated video if one of the steps goes wrong).
        os.remove(save_sample_path)

        return video_data_uri


# Eagerly load the default SD base, motion module and personalized checkpoint at startup.
controller = AnimateController()
controller.update_stable_diffusion("stable-diffusion-v1-5")
controller.update_motion_module("sd15_t2v_beta_motion.ckpt")
controller.update_base_model("realistic2.safetensors")


def ui():
    with gr.Blocks() as demo:
        gr.HTML("""

This space is a REST API to programmatically generate MP4 videos.

Interested in using it? Look no further than the original space!

""") with gr.Column(): with gr.Row(): secret_token = gr.Text(label='Secret Token', max_lines=1) # TODO: find a way to use this to filter the dropdown #base_model = gr.Text(label="Base model") base_model_dropdown = gr.Dropdown( label="Select base Dreambooth model (required)", choices=controller.personalized_model_list, interactive=True, value="cartoon3d.safetensors" # value="realistic2.safetensors" ) base_model_dropdown.change(fn=controller.update_base_model, inputs=[ base_model_dropdown], outputs=[base_model_dropdown]) lora_model_dropdown = gr.Dropdown( label="Select LoRA model (optional)", choices=["none",], value="none", interactive=True, ) lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[ lora_model_dropdown], outputs=[lora_model_dropdown]) lora_alpha_slider = gr.Slider( label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True) spatial_lora_slider = gr.Slider( label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True) personalized_refresh_button = gr.Button( value="\U0001F503", elem_classes="toolbutton") def update_personalized_model(): controller.refresh_personalized_model() return [ gr.Dropdown.update( choices=controller.personalized_model_list), gr.Dropdown.update( choices=["none"] + controller.personalized_model_list) ] personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[ base_model_dropdown, lora_model_dropdown]) with gr.Column(variant="panel"): prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="a boy holding a rabbit") negative_prompt_textbox = gr.Textbox( label="Negative prompt", lines=2, value="bad quality") with gr.Row().style(equal_height=False): with gr.Column(): with gr.Row(): sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list( scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) sample_step_slider = gr.Slider( label="Sampling steps", value=6, minimum=1, maximum=25, step=1) width_slider = gr.Slider( label="Width", value=512, minimum=256, maximum=1024, step=64) height_slider = gr.Slider( label="Height", value=512, minimum=256, maximum=1024, step=64) length_slider = gr.Slider( label="Animation length", value=16, minimum=12, maximum=20, step=1) cfg_scale_slider = gr.Slider( label="CFG Scale", value=1.5, minimum=1, maximum=2) with gr.Row(): seed_textbox = gr.Textbox(label="Seed", value=-1) seed_button = gr.Button( value="\U0001F3B2", elem_classes="toolbutton") seed_button.click(fn=lambda: gr.Textbox.update( value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox]) generate_button = gr.Button( value="Generate", variant='primary') result_video_base64 = gr.Text() generate_button.click( fn=controller.animate, inputs=[ secret_token, lora_alpha_slider, spatial_lora_slider, prompt_textbox, negative_prompt_textbox, sampler_dropdown, sample_step_slider, width_slider, length_slider, height_slider, cfg_scale_slider, seed_textbox, ], outputs=[result_video_base64] ) return demo if __name__ == "__main__": demo = ui() # gr.close_all() # restart demo.queue(max_size=32, api_open=True).launch()