|
|
|
import base64
import copy
import hmac
import json
import os
import random
from datetime import datetime
from glob import glob

import gradio as gr
import torch
from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from safetensors import safe_open
from transformers import CLIPTextModel, CLIPTokenizer

from animatelcm.scheduler.lcm_scheduler import LCMScheduler
from animatelcm.models.unet import UNet3DConditionModel
from animatelcm.pipelines.pipeline_animation import AnimationPipeline
from animatelcm.utils.util import save_videos_grid
from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatelcm.utils.lcm_utils import convert_lcm_lora
|
|
|
# Number used as the file name of the temporary sample video written by
# AnimateController.animate. NOTE(review): nothing in this file increments it,
# so every request reuses the same file name.
sample_idx = 0

# Sampler names offered in the "Sampling method" dropdown -> scheduler class.
scheduler_dict = dict(LCM=LCMScheduler)
|
|
|
# Extra CSS for the small square icon buttons (refresh / random seed).
# The original declaration was ``margin-buttom: 0em 0em 0em 0em;``. That is not
# a CSS property, and ``margin-bottom`` accepts only one value, so browsers
# discarded the rule. A single valid value makes it take effect.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
|
|
|
# Shared secret that AnimateController.animate requires from callers.
# When the SECRET_TOKEN environment variable is unset, this falls back to the
# publicly visible string 'default_secret', so deployments should always set it.
SECRET_TOKEN = os.environ.get('SECRET_TOKEN', 'default_secret')
|
|
|
class AnimateController:
    """Own the AnimateLCM model components and implement the Gradio callbacks.

    ``__init__`` only scans the local ``models/`` folders and reads
    ``configs/inference.yaml``. The tokenizer, text encoder, VAE and UNet stay
    ``None`` until :meth:`update_stable_diffusion` loads them.
    """

    def __init__(self):
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(
            self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(
            self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(
            self.basedir, "models", "DreamBooth_LoRA")
        # One output folder per app start, named after the start time.
        self.savedir = os.path.join(
            self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        # Relative path, i.e. resolved against the working directory (== basedir).
        self.lcm_lora_path = "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors"
        # animate() writes its videos into savedir_sample. Create it (and thereby
        # savedir) up front instead of relying on the video writer to do so.
        os.makedirs(self.savedir_sample, exist_ok=True)

        # Checkpoints discovered on disk; filled by the refresh_* methods.
        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # Model components, populated by update_stable_diffusion().
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        # Tensors of the selected LoRA file, filled by update_lora_model().
        # NOTE(review): animate() never applies these tensors (nor lora_alpha_slider).
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference.yaml")

    def refresh_stable_diffusion(self):
        """Rescan ``stable_diffusion_dir`` for model sub-directories (full paths)."""
        self.stable_diffusion_list = glob(
            os.path.join(self.stable_diffusion_dir, "*/"))

    def refresh_motion_module(self):
        """Rescan ``motion_module_dir`` for ``*.ckpt`` files (file names only)."""
        motion_module_list = glob(os.path.join(
            self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [
            os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        """Rescan ``personalized_model_dir`` for ``*.safetensors`` (file names only)."""
        personalized_model_list = glob(os.path.join(
            self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = [
            os.path.basename(p) for p in personalized_model_list]

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        """Load tokenizer, text encoder, VAE and UNet from a diffusers-format folder.

        Args:
            stable_diffusion_dropdown: folder name, resolved against
                ``stable_diffusion_dir`` with ``os.path.join`` (so an absolute
                path is used as-is).

        The text encoder, VAE and UNet are moved to CUDA. The UNet is built with
        ``UNet3DConditionModel.from_pretrained_2d`` using the
        ``unet_additional_kwargs`` of the inference config.

        Returns:
            A no-op ``gr.Dropdown.update()``.
        """
        model_dir = os.path.join(self.stable_diffusion_dir, stable_diffusion_dropdown)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            model_dir, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_dir, subfolder="text_encoder").cuda()
        self.vae = AutoencoderKL.from_pretrained(
            model_dir, subfolder="vae").cuda()
        self.unet = UNet3DConditionModel.from_pretrained_2d(
            model_dir, subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
        return gr.Dropdown.update()

    def update_motion_module(self, motion_module_dropdown):
        """Load a motion-module checkpoint from ``motion_module_dir`` into the UNet.

        Returns:
            ``gr.Dropdown.update(value=None)`` (clearing the selection) when no
            base model is loaded yet, otherwise a no-op dropdown update.

        Raises:
            AssertionError: if the checkpoint contains keys the UNet does not have.
        """
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)

        motion_module_path = os.path.join(
            self.motion_module_dir, motion_module_dropdown)
        # NOTE: torch.load unpickles the file. Only load trusted local checkpoints.
        motion_module_state_dict = torch.load(
            motion_module_path, map_location="cpu")
        # Non-strict: the checkpoint covers only part of the UNet, so missing
        # keys are expected.
        _missing, unexpected = self.unet.load_state_dict(
            motion_module_state_dict, strict=False)
        del motion_module_state_dict
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if len(unexpected) != 0:
            raise AssertionError(
                f"Motion module has {len(unexpected)} unexpected key(s), e.g. {list(unexpected)[:5]}")
        return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        """Overwrite the VAE and UNet weights with a personalized (DreamBooth) checkpoint.

        The ``.safetensors`` file is read from ``personalized_model_dir`` and
        converted with ``convert_ldm_vae_checkpoint`` /
        ``convert_ldm_unet_checkpoint``. The UNet is loaded non-strictly.

        Returns:
            ``gr.Dropdown.update(value=None)`` when no base model is loaded yet,
            otherwise a no-op dropdown update.
        """
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)

        base_model_path = os.path.join(
            self.personalized_model_dir, base_model_dropdown)
        base_model_state_dict = {}
        with safe_open(base_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                base_model_state_dict[key] = f.get_tensor(key)

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
            base_model_state_dict, self.vae.config)
        self.vae.load_state_dict(converted_vae_checkpoint)

        converted_unet_checkpoint = convert_ldm_unet_checkpoint(
            base_model_state_dict, self.unet.config)
        self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        # Release the CPU copies as soon as they have been loaded.
        del converted_unet_checkpoint
        del converted_vae_checkpoint
        del base_model_state_dict
        return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        """Read a LoRA ``.safetensors`` file into ``self.lora_model_state_dict``.

        Selecting ``"none"`` (or clearing the dropdown) resets the dict to empty.

        Returns:
            A no-op ``gr.Dropdown.update()``.
        """
        self.lora_model_state_dict = {}
        # Test the raw dropdown value. The original compared the path *after*
        # joining it with the directory, so it could never equal "none" and
        # tried to open ".../none".
        if lora_model_dropdown is None or lora_model_dropdown == "none":
            return gr.Dropdown.update()

        lora_model_path = os.path.join(
            self.personalized_model_dir, lora_model_dropdown)
        with safe_open(lora_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()

    @staticmethod
    def _seed_rng(seed_textbox):
        """Seed torch's global RNG from the seed field and return the seed in use.

        Gradio textboxes deliver strings, so ``-1`` may arrive as ``"-1"``.
        ``None``, an empty/blank value, or -1 (int or string) requests a fresh
        non-deterministic seed. Any other value is parsed as an int.
        """
        if isinstance(seed_textbox, float):
            seed_textbox = int(seed_textbox)
        text = "" if seed_textbox is None else str(seed_textbox).strip()
        if text in ("", "-1"):
            torch.seed()
        else:
            torch.manual_seed(int(text))
        return torch.initial_seed()

    @torch.no_grad()
    def animate(
        self,
        secret_token,
        lora_alpha_slider,
        spatial_lora_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        """Generate a video and return it as a ``data:video/mp4;base64,...`` URI.

        Args:
            secret_token: must equal ``SECRET_TOKEN``.
            lora_alpha_slider: currently unused by this method.
            spatial_lora_slider: strength passed to ``convert_lcm_lora``.
            prompt_textbox, negative_prompt_textbox: text prompts.
            sampler_dropdown: key into ``scheduler_dict``.
            sample_step_slider: number of inference steps.
            width_slider, height_slider, length_slider: output size / frame count.
            cfg_scale_slider: guidance scale.
            seed_textbox: seed; empty or -1 means random (see ``_seed_rng``).

        Raises:
            gr.Error: if the secret token is wrong.
        """
        # Constant-time comparison so response timing does not leak the token.
        provided_token = "" if secret_token is None else str(secret_token)
        if not hmac.compare_digest(provided_token.encode("utf-8"), SECRET_TOKEN.encode("utf-8")):
            raise gr.Error(
                'Invalid secret token. Please fork the original space if you want to use it for yourself.')

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](
                **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to("cuda")

        # Snapshot the non-motion-module UNet weights so the LCM-LoRA conversion
        # below can be undone after this request.
        original_state_dict = {
            k: v.cpu().clone()
            for k, v in pipeline.unet.state_dict().items()
            if "motion_modules." not in k
        }
        try:
            pipeline.unet = convert_lcm_lora(pipeline.unet, self.lcm_lora_path, spatial_lora_slider)
            pipeline.to("cuda")

            self._seed_rng(seed_textbox)

            with torch.autocast("cuda"):
                sample = pipeline(
                    prompt_textbox,
                    negative_prompt=negative_prompt_textbox,
                    num_inference_steps=sample_step_slider,
                    guidance_scale=cfg_scale_slider,
                    width=width_slider,
                    height=height_slider,
                    video_length=length_slider,
                ).videos
        finally:
            # Restore even when generation fails. pipeline.unet starts out as
            # self.unet, and the original code restored only on success, so a
            # failure left the converted weights in place for the next request.
            pipeline.unet.load_state_dict(original_state_dict, strict=False)
            del original_state_dict

        # NOTE(review): sample_idx is never incremented, so concurrent requests
        # would share this file name. Safe only while requests run one at a time.
        save_sample_path = os.path.join(
            self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)

        try:
            with open(save_sample_path, "rb") as video_file:
                video_base64 = base64.b64encode(video_file.read()).decode('utf-8')
        finally:
            # The original removed an undefined name (``video_path``), which
            # raised NameError after every generation.
            os.remove(save_sample_path)

        return 'data:video/mp4;base64,' + video_base64
|
|
|
|
|
controller = AnimateController()

# Eagerly load the default checkpoints at import time. Order matters: the
# motion module and base-model callbacks return early unless the UNet from
# update_stable_diffusion already exists.
_DEFAULT_LOADS = (
    (controller.update_stable_diffusion, "stable-diffusion-v1-5"),
    (controller.update_motion_module, "sd15_t2v_beta_motion.ckpt"),
    (controller.update_base_model, "realistic2.safetensors"),
)
for _load, _choice in _DEFAULT_LOADS:
    _load(_choice)
del _DEFAULT_LOADS, _load, _choice
|
|
|
|
|
def ui():
    """Build the Gradio Blocks interface wired to the module-level ``controller``.

    Returns:
        The ``gr.Blocks`` app. ``controller.animate`` is bound to the Generate
        button and writes its base64 data URI into a text box.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
# [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)<br>
[arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm)
"""
            '''
Important Notes:
1. The generation speed is around few seconds. There is delay in the space.
2. Increase the sampling step and cfg and set proper negative prompt if you want more fancy videos.
'''
        )
        with gr.Column():
            with gr.Row():
                secret_token = gr.Text(label='Secret Token', max_lines=1)

                base_model_dropdown = gr.Dropdown(
                    label="Select base Dreambooth model (required)",
                    choices=controller.personalized_model_list,
                    interactive=True,
                    value="realistic2.safetensors"
                )
                base_model_dropdown.change(fn=controller.update_base_model, inputs=[
                                           base_model_dropdown], outputs=[base_model_dropdown])

                lora_model_dropdown = gr.Dropdown(
                    label="Select LoRA model (optional)",
                    choices=["none",],
                    value="none",
                    interactive=True,
                )
                lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[
                                           lora_model_dropdown], outputs=[lora_model_dropdown])

                lora_alpha_slider = gr.Slider(
                    label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
                spatial_lora_slider = gr.Slider(
                    label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True)

                personalized_refresh_button = gr.Button(
                    value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    """Rescan the personalized-model folder and refresh both dropdowns."""
                    controller.refresh_personalized_model()
                    return [
                        gr.Dropdown.update(
                            choices=controller.personalized_model_list),
                        gr.Dropdown.update(
                            choices=["none"] + controller.personalized_model_list)
                    ]
                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[
                                                  base_model_dropdown, lora_model_dropdown])

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
### 2. Configs for AnimateLCM.
"""
            )

            prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="a boy holding a rabbit")
            negative_prompt_textbox = gr.Textbox(
                label="Negative prompt", lines=2, value="bad quality")

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
                            scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(
                            label="Sampling steps", value=6, minimum=1, maximum=25, step=1)

                    width_slider = gr.Slider(
                        label="Width", value=512, minimum=256, maximum=1024, step=64)
                    height_slider = gr.Slider(
                        label="Height", value=512, minimum=256, maximum=1024, step=64)
                    length_slider = gr.Slider(
                        label="Animation length", value=16, minimum=12, maximum=20, step=1)
                    cfg_scale_slider = gr.Slider(
                        label="CFG Scale", value=1.5, minimum=1, maximum=2)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(
                            value="\U0001F3B2", elem_classes="toolbutton")
                        # Integer upper bound: random.randint rejects float
                        # arguments such as 1e8 on Python >= 3.12.
                        seed_button.click(fn=lambda: gr.Textbox.update(
                            value=random.randint(1, 10**8)), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(
                        value="Generate", variant='primary')

                result_video_base_64 = gr.Text()

            generate_button.click(
                fn=controller.animate,
                inputs=[
                    secret_token,
                    lora_alpha_slider,
                    spatial_lora_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                # Must match the component name above. The original referenced
                # the undefined ``result_video_base64`` and raised NameError.
                outputs=[result_video_base_64]
            )

    return demo
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the interface, enable the request queue (with its API exposed),
    # then start the server.
    app = ui()
    app.queue(api_open=True)
    app.launch()
|
|