StableVideo / app.py
wchai's picture
convert to CPU
d515943
import os
import shutil
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import torch.optim as optim
import random
import imageio
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import time
import scipy.interpolate
from tqdm import tqdm
from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.midas import MidasDetector
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler
from stablevideo.atlas_data import AtlasData
from stablevideo.atlas_utils import get_grid_indices, get_atlas_bounding_box
from stablevideo.aggnet import AGGNet
class StableVideo:
    """Edit a video through its layered atlases using ControlNet models.

    The background atlas is edited with a depth-conditioned model
    (``edit_background``), the foreground atlas is rebuilt from edited key
    frames with a canny-conditioned model (``advanced_edit_foreground``), and
    ``render`` maps both atlases back onto every frame and alpha-composites
    them into an output video.
    """

    def __init__(self, base_cfg, canny_model_cfg, depth_model_cfg, save_memory=False):
        """Store configuration paths; models and video data are loaded lazily.

        Args:
            base_cfg: Path of the model config (yaml).
            canny_model_cfg: Path of the canny ControlNet checkpoint.
            depth_model_cfg: Path of the depth ControlNet checkpoint.
            save_memory: When true, models are moved back to CPU after an edit.
        """
        self.base_cfg = base_cfg
        self.canny_model_cfg = canny_model_cfg
        self.depth_model_cfg = depth_model_cfg
        self.img2img_model = None
        self.canny_model = None
        self.depth_model = None
        self.b_atlas = None
        self.f_atlas = None
        # Original atlases are set by load_video(); initialised here so that
        # "video not loaded" can be detected instead of raising AttributeError.
        self.f_atlas_origin = None
        self.b_atlas_origin = None
        self.data = None
        self.crops = None
        self.save_memory = save_memory

    @staticmethod
    def _module_device(module):
        """Return the device of ``module``'s first parameter (CPU if it has none)."""
        first_param = next(module.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def load_canny_model(
        self,
        base_cfg='ckpt/cldm_v15.yaml',
        canny_model_cfg='ckpt/control_sd15_canny.pth',
    ):
        """Create the canny-conditioned model on CPU and its DDIM sampler."""
        self.apply_canny = CannyDetector()
        canny_model = create_model(base_cfg).cpu()
        canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='cpu'), strict=False)
        self.canny_ddim_sampler = DDIMSampler(canny_model)
        self.canny_model = canny_model

    def load_depth_model(
        self,
        base_cfg='ckpt/cldm_v15.yaml',
        depth_model_cfg='ckpt/control_sd15_depth.pth',
    ):
        """Create the depth-conditioned model on CPU and its DDIM sampler."""
        self.apply_midas = MidasDetector()
        depth_model = create_model(base_cfg).cpu()
        depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='cpu'), strict=False)
        self.depth_ddim_sampler = DDIMSampler(depth_model)
        self.depth_model = depth_model

    def load_video(self, video_name):
        """Load the atlas data of ``video_name`` and expose its original atlases.

        The original video is written to ``data/<name>/<name>.mp4`` the first
        time it is requested.

        Returns:
            Tuple ``(video_path, foreground_atlas, background_atlas)`` where the
            atlases are PIL images.
        """
        self.data = AtlasData(video_name)
        save_name = f"data/{video_name}/{video_name}.mp4"
        if not os.path.exists(save_name):
            imageio.mimwrite(save_name, self.data.original_video.cpu().permute(0, 2, 3, 1))
            print("original video saved.")
        toIMG = transforms.ToPILImage()
        self.f_atlas_origin = toIMG(self.data.cropped_foreground_atlas[0])
        self.b_atlas_origin = toIMG(self.data.background_grid_atlas[0])
        return save_name, self.f_atlas_origin, self.b_atlas_origin

    @torch.no_grad()
    def depth_edit(self, input_image=None,
                   prompt="",
                   a_prompt="best quality, extremely detailed",
                   n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                   image_resolution=512,
                   detect_resolution=384,
                   ddim_steps=20,
                   scale=9,
                   seed=-1,
                   eta=0,
                   num_samples=1):
        """Generate an edited image conditioned on the depth map of ``input_image``.

        Args:
            input_image: PIL image to edit.
            prompt / a_prompt / n_prompt: Prompt, appended prompt, negative prompt.
            image_resolution: Resolution passed to ``resize_image`` for sampling.
            detect_resolution: Resolution passed to ``resize_image`` for depth detection.
            ddim_steps, scale, eta: DDIM steps, guidance scale and eta.
            seed: Random seed; ``-1`` picks one at random.
            num_samples: Number of samples drawn (only the first one is kept).

        Returns:
            The first sample as a PIL image resized to the input size; it is
            also stored in ``self.b_atlas``.

        Raises:
            ValueError: If no input image is given or the depth model is not loaded.
        """
        if input_image is None:
            raise ValueError("Please provide an input image")
        if self.depth_model is None:
            raise ValueError("Please load depth model first")

        size = input_image.size
        model = self.depth_model
        ddim_sampler = self.depth_ddim_sampler
        apply_midas = self.apply_midas
        device = self._module_device(model)

        input_image = np.array(input_image)
        input_image = HWC3(input_image)
        detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        control = torch.from_numpy(detected_map.copy()).float() / 255.0
        # One control map per requested sample so the batch sizes of the
        # control and the text conditioning agree.
        control = torch.stack([control] * num_samples, dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone().to(device)

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)
        x_samples = model.decode_first_stage(samples)
        # Map the decoder output from [-1, 1] to uint8 [0, 255].
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        results = [x_samples[i] for i in range(num_samples)]
        self.b_atlas = Image.fromarray(results[0]).resize(size)
        return self.b_atlas

    @torch.no_grad()
    def edit_background(self, *args, **kwargs):
        """Edit the original background atlas; arguments are forwarded to ``depth_edit``.

        Returns:
            The edited background atlas (PIL image).

        Raises:
            ValueError: If no video has been loaded.
        """
        if self.b_atlas_origin is None:
            raise ValueError("Please load video first")
        input_image = self.b_atlas_origin
        self.depth_edit(input_image, *args, **kwargs)
        if self.save_memory:
            self.depth_model = self.depth_model.cpu()
        return self.b_atlas

    @torch.no_grad()
    def advanced_edit_foreground(self,
                                 keyframes="0",
                                 res=2000,
                                 prompt="",
                                 a_prompt="best quality, extremely detailed",
                                 n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                                 image_resolution=512,
                                 low_threshold=100,
                                 high_threshold=200,
                                 ddim_steps=20,
                                 s=0.9,
                                 scale=9,
                                 seed=-1,
                                 eta=0,
                                 if_net=False,
                                 num_samples=1):
        """Edit key frames of the foreground and aggregate them into one atlas.

        The first key frame is sampled from noise. Every later key frame starts
        from the previous key frame's atlas mapped into its view, noised to
        ``int(ddim_steps * s)`` steps and denoised again. Each edited frame is
        scattered into an atlas of size ``res`` x ``res``; the per-keyframe
        atlases are combined by a median and optionally refined by ``AGGNet``.

        Args:
            keyframes: Comma-separated frame indices, e.g. ``"20, 40, 60"``.
            res: Side length of the square working atlas.
            prompt / a_prompt / n_prompt: Prompt, appended prompt, negative prompt.
            image_resolution: Resolution passed to ``resize_image``.
            low_threshold, high_threshold: Canny thresholds.
            ddim_steps, scale, eta: DDIM steps, guidance scale and eta.
            s: Fraction of ``ddim_steps`` used to re-noise mapped key frames.
            seed: Random seed; ``-1`` picks one at random.
            if_net: Refine the median atlas with a trained aggregation network.
            num_samples: Forwarded to the sampler for the first key frame.
                NOTE(review): the initial latent and the control both have
                batch size 1, so values other than 1 are presumably unsupported.

        Returns:
            The edited foreground atlas (PIL image), also stored in ``self.f_atlas``.

        Raises:
            ValueError: If no video or canny model is loaded, or ``keyframes``
                cannot be parsed as integers.
        """
        if self.data is None:
            raise ValueError("Please load video first")
        if self.canny_model is None:
            raise ValueError("Please load canny model first")

        keyframes = [int(x) for x in keyframes.split(",")]
        self.crops = self.data.get_global_crops_multi(keyframes, res)
        n_keyframes = len(keyframes)
        indices = get_grid_indices(0, 0, res, res)

        model = self.canny_model
        sampler = self.canny_ddim_sampler
        uvs = self.crops['foreground_uvs']
        align_corners = self.data.config["align_corners"]
        # Derive devices from the objects actually used instead of assuming
        # CUDA: diffusion tensors follow the model, atlas tensors follow the
        # UV grids they are sampled with.
        model_device = self._module_device(model)
        atlas_device = uvs[0].device

        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        f_atlas = torch.zeros(size=(n_keyframes, res, res, 3), device=atlas_device)
        img_list = [to_pil(i[0]) for i in self.crops['original_foreground_crops']]
        result_list = []
        # Query points of the atlas grid are the same for every key frame.
        atlas_points = indices.reshape(-1, 2).cpu().numpy()

        # initial setting
        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)
        sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=eta, verbose=False)
        c_crossattn = [model.get_learned_conditioning([prompt + ', ' + a_prompt])]
        uc_crossattn = [model.get_learned_conditioning([n_prompt])]

        for i in range(n_keyframes):
            # get current keyframe
            current_img = img_list[i]
            img = resize_image(HWC3(np.array(current_img)), image_resolution)
            H, W, C = img.shape
            shape = (4, H // 8, W // 8)

            # get canny control
            detected_map = self.apply_canny(img, low_threshold, high_threshold)
            detected_map = HWC3(detected_map)
            control = torch.from_numpy(detected_map.copy()).float() / 255.0
            control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone().to(model_device)
            cond = {"c_concat": [control], "c_crossattn": c_crossattn}
            un_cond = {"c_concat": [control], "c_crossattn": uc_crossattn}

            if i == 0:
                # First key frame: sample from pure noise.
                latent = torch.randn((1, 4, H // 8, W // 8)).to(model_device)
                samples, _ = sampler.sample(ddim_steps, num_samples,
                                            shape, cond, verbose=False, eta=eta,
                                            unconditional_guidance_scale=scale,
                                            unconditional_conditioning=un_cond,
                                            x_T=latent)
            else:
                # Later key frames: start from the previous atlas mapped into
                # this frame, partially noised and then denoised.
                last_atlas = f_atlas[i-1:i].permute(0, 3, 2, 1)
                mapped_img = F.grid_sample(last_atlas, uvs[i].reshape(1, -1, 1, 2), mode="bilinear", align_corners=align_corners).clamp(min=0.0, max=1.0).reshape((3, current_img.size[1], current_img.size[0]))
                mapped_img = to_pil(mapped_img)
                mapped_img = mapped_img.resize((W, H))
                mapped_img = np.array(mapped_img).astype(np.float32) / 255.0
                mapped_img = mapped_img[None].transpose(0, 3, 1, 2)
                mapped_img = torch.from_numpy(mapped_img).to(model_device)
                mapped_img = 2. * mapped_img - 1.
                latent = model.get_first_stage_encoding(model.encode_first_stage(mapped_img))
                t_enc = int(ddim_steps * s)
                latent = sampler.stochastic_encode(latent, torch.tensor([t_enc], device=latent.device))
                samples = sampler.decode(x_latent=latent,
                                         cond=cond,
                                         t_start=t_enc,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=un_cond)

            x_samples = model.decode_first_stage(samples)
            result = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            result = Image.fromarray(result[0])
            result = result.resize(current_img.size)
            result = to_tensor(result)

            # times alpha
            alpha = self.crops['foreground_alpha'][i][0].cpu()
            result = alpha * result

            # buffer for training (reconstruction targets need no gradient)
            result_list.append(result.clone().to(atlas_device))

            # map to atlas
            uv = (uvs[i].reshape(-1, 2) * 0.5 + 0.5) * res
            uv_points = uv.cpu().numpy()
            for c in range(3):
                interpolated = scipy.interpolate.griddata(
                    points=uv_points,
                    values=result[c].reshape(-1, 1).cpu().numpy(),
                    xi=atlas_points,
                    method="linear",
                ).reshape(res, res)
                interpolated = torch.from_numpy(interpolated).float()
                # griddata yields NaN outside the convex hull of the samples.
                interpolated[interpolated.isnan()] = 0.0
                f_atlas[i, :, :, c] = interpolated

        f_atlas = f_atlas.permute(0, 3, 2, 1)

        # aggregate via simple median as begining
        agg_atlas, _ = torch.median(f_atlas, dim=0)

        if if_net:
            #####################################
            #           aggregate net           #
            #####################################
            lr, n_epoch = 1e-3, 500
            agg_net = AGGNet().to(atlas_device)
            loss_fn = nn.L1Loss()
            optimizer = optim.SGD(agg_net.parameters(), lr=lr, momentum=0.9)
            # This method runs under torch.no_grad(); gradients must be
            # re-enabled here or loss.backward() has no graph to traverse.
            with torch.enable_grad():
                for _ in range(n_epoch):
                    loss = 0.
                    for i in range(n_keyframes):
                        e_img = result_list[i]
                        temp_agg_atlas = agg_net(agg_atlas)
                        rec_img = F.grid_sample(temp_agg_atlas[None],
                                                uvs[i].reshape(1, -1, 1, 2),
                                                mode="bilinear",
                                                align_corners=align_corners)
                        rec_img = rec_img.clamp(min=0.0, max=1.0).reshape(e_img.shape)
                        loss += loss_fn(rec_img, e_img)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            # Final forward pass runs without gradient tracking again.
            agg_atlas = agg_net(agg_atlas).detach()
            #####################################

        agg_atlas, _ = get_atlas_bounding_box(self.data.mask_boundaries, agg_atlas, self.data.foreground_all_uvs)
        self.f_atlas = to_pil(agg_atlas)

        if self.save_memory:
            self.canny_model = self.canny_model.cpu()
        return self.f_atlas

    @torch.no_grad()
    def render(self, f_atlas, b_atlas):
        """Render the output video from a foreground and a background atlas.

        Args:
            f_atlas: ``None`` to use the original foreground atlas, otherwise a
                mapping with ``"image"`` (edited atlas) and ``"mask"``; where
                the mask is set the original atlas is restored.
            b_atlas: ``None`` to use the original background atlas, otherwise
                the edited background atlas image.

        Returns:
            Path of the written video, ``log/<timestamp>/video.mp4``.

        Raises:
            ValueError: If no video has been loaded.
        """
        if self.data is None:
            raise ValueError("Please load video first")

        to_tensor = transforms.ToTensor()
        align_corners = self.data.config["align_corners"]
        video_shape = (
            self.data.config["maximum_number_of_frames"],
            self.data.config["resy"],
            self.data.config["resx"],
            3,
        )

        # foreground
        if f_atlas is None:
            f_atlas = to_tensor(self.f_atlas_origin).unsqueeze(0)
        else:
            f_atlas, mask = f_atlas["image"], f_atlas["mask"]
            f_atlas_origin = to_tensor(self.f_atlas_origin).unsqueeze(0)
            f_atlas = to_tensor(f_atlas).unsqueeze(0)
            mask = to_tensor(mask).unsqueeze(0)
            if f_atlas.shape != mask.shape:
                print("Warning: truncating mask to atlas shape {}".format(f_atlas.shape))
                mask = mask[:f_atlas.shape[0], :f_atlas.shape[1], :f_atlas.shape[2], :f_atlas.shape[3]]
            f_atlas = f_atlas * (1 - mask) + f_atlas_origin * mask

        # Pad the cropped atlas back to the extent of the full foreground grid
        # atlas, using the bounding-box offsets stored with the data.
        bbox = self.data.foreground_atlas_bbox
        grid_shape = self.data.foreground_grid_atlas.shape
        f_atlas = torch.nn.functional.pad(
            f_atlas,
            pad=(
                bbox[1],
                grid_shape[-1] - (bbox[1] + bbox[3]),
                bbox[0],
                grid_shape[-2] - (bbox[0] + bbox[2]),
            ),
            mode="replicate",
        )
        foreground_edit = F.grid_sample(
            f_atlas, self.data.scaled_foreground_uvs, mode="bilinear", align_corners=align_corners
        ).clamp(min=0.0, max=1.0)
        foreground_edit = foreground_edit.squeeze().t()  # shape (batch, 3)
        foreground_edit = (
            foreground_edit.reshape(*video_shape)
            .permute(0, 3, 1, 2)
            .clamp(min=0.0, max=1.0)
        )

        # background
        if b_atlas is None:
            b_atlas = self.b_atlas_origin
        b_atlas = to_tensor(b_atlas).unsqueeze(0)
        background_edit = F.grid_sample(
            b_atlas, self.data.scaled_background_uvs, mode="bilinear", align_corners=align_corners
        ).clamp(min=0.0, max=1.0)
        background_edit = background_edit.squeeze().t()  # shape (batch, 3)
        background_edit = (
            background_edit.reshape(*video_shape)
            .permute(0, 3, 1, 2)
            .clamp(min=0.0, max=1.0)
        )

        # Alpha-composite foreground over background.
        output_video = (
            self.data.all_alpha * foreground_edit
            + (1 - self.data.all_alpha) * background_edit
        )

        run_id = time.time()
        # makedirs also creates the parent "log" directory on first use.
        os.makedirs(f"log/{run_id}", exist_ok=True)
        save_name = f"log/{run_id}/video.mp4"
        imageio.mimwrite(save_name, (255 * output_video.detach().cpu()).to(torch.uint8).permute(0, 2, 3, 1))
        return save_name
if __name__ == '__main__':
    # Build the editor and load both ControlNet models before the UI starts.
    stablevideo = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
                              canny_model_cfg="ckpt/control_sd15_canny.pth",
                              depth_model_cfg="ckpt/control_sd15_depth.pth",
                              save_memory=True)
    stablevideo.load_canny_model()
    stablevideo.load_depth_model()

    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            gr.Markdown("## StableVideo")
        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column():
                original_video = gr.Video(label="Original Video", interactive=False)
                with gr.Row():
                    foreground_atlas = gr.Image(label="Foreground Atlas", type="pil")
                    background_atlas = gr.Image(label="Background Atlas", type="pil")
                gr.Markdown("### Step 1. select one example video and click **Load Video** buttom and wait for 10 sec.")
                # Every sub-directory of "data" is offered as an example video.
                avail_video = sorted(f.name for f in os.scandir("data") if f.is_dir())
                # Keep "car-turn" as the default when present, otherwise fall
                # back to the first available video (or no selection).
                if "car-turn" in avail_video:
                    default_video = "car-turn"
                else:
                    default_video = avail_video[0] if avail_video else None
                video_name = gr.Radio(choices=avail_video,
                                      label="Select Example Videos",
                                      value=default_video)
                load_video_button = gr.Button("Load Video")
                gr.Markdown("### Step 2. write text prompt and advanced options for background and foreground.")
                with gr.Row():
                    f_prompt = gr.Textbox(label="Foreground Prompt", value="a picture of an orange suv")
                    b_prompt = gr.Textbox(label="Background Prompt", value="winter scene, snowy scene, beautiful snow")
                with gr.Row():
                    with gr.Accordion("Advanced Foreground Options", open=False):
                        adv_keyframes = gr.Textbox(label="keyframe", value="20, 40, 60")
                        adv_atlas_resolution = gr.Slider(label="Atlas Resolution", minimum=1000, maximum=3000, value=2000, step=100)
                        adv_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                        adv_low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
                        adv_high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
                        adv_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                        adv_s = gr.Slider(label="Noise Scale", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
                        adv_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=15.0, value=9.0, step=0.1)
                        adv_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                        adv_eta = gr.Number(label="eta (DDIM)", value=0.0)
                        adv_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, no background')
                        adv_n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                        # Use the public component name rather than the
                        # incidental ``gr.gradio`` attribute.
                        adv_if_net = gr.Checkbox(label="if use agg net", value=False)
                    with gr.Accordion("Background Options", open=False):
                        b_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                        b_detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=512, step=1)
                        b_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                        b_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                        b_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                        b_eta = gr.Number(label="eta (DDIM)", value=0.0)
                        b_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                        b_n_prompt = gr.Textbox(label="Negative Prompt",
                                                value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                gr.Markdown("### Step 3. edit each one and render.")
                with gr.Row():
                    f_advance_run_button = gr.Button("Advanced Edit Foreground (slower, better)")
                    b_run_button = gr.Button("Edit Background")
                run_button = gr.Button("Render")
            # Right column: results. The foreground atlas is an editable mask
            # image so parts of the original atlas can be restored in render().
            with gr.Column():
                output_video = gr.Video(label="Output Video", interactive=False)
                output_foreground_atlas = gr.ImageMask(label="Editable Output Foreground Atlas", type="pil", tool="sketch", interactive=True)
                output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False)

        # edit param
        # Order must match the positional parameters of
        # StableVideo.advanced_edit_foreground.
        f_adv_edit_param = [adv_keyframes,
                            adv_atlas_resolution,
                            f_prompt,
                            adv_a_prompt,
                            adv_n_prompt,
                            adv_image_resolution,
                            adv_low_threshold,
                            adv_high_threshold,
                            adv_ddim_steps,
                            adv_s,
                            adv_scale,
                            adv_seed,
                            adv_eta,
                            adv_if_net]
        # Order must match the parameters of StableVideo.depth_edit after
        # ``input_image`` (edit_background supplies the image itself).
        b_edit_param = [b_prompt,
                        b_a_prompt,
                        b_n_prompt,
                        b_image_resolution,
                        b_detect_resolution,
                        b_ddim_steps,
                        b_scale,
                        b_seed,
                        b_eta]

        # action
        load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
        f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
        b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
        run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])

    block.launch()