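"""Gradio demo app for StableVideo.

Loads an example video together with its pre-computed foreground/background
atlases, lets the user edit the background atlas with a depth-conditioned
ControlNet and the foreground atlas with a canny-conditioned ControlNet, and
re-renders the edited video from the edited atlases.
"""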
import os
import shutil
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import torch.optim as optim
import random
import imageio
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import time
import scipy.interpolate
from tqdm import tqdm
from pytorch_lightning import seed_everything

from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.midas import MidasDetector
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler
from stablevideo.atlas_data import AtlasData
from stablevideo.atlas_utils import get_grid_indices, get_atlas_bounding_box
from stablevideo.aggnet import AGGNet


class StableVideo:
    def __init__(self, base_cfg, canny_model_cfg, depth_model_cfg, save_memory=False):
        self.base_cfg = base_cfg
        self.canny_model_cfg = canny_model_cfg
        self.depth_model_cfg = depth_model_cfg
        self.img2img_model = None
        self.canny_model = None
        self.depth_model = None
        self.b_atlas = None
        self.f_atlas = None
        self.data = None
        self.crops = None
        self.save_memory = save_memory
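
    # The two loaders below build ControlNet models from the shared base config
    # (cldm_v15.yaml): the canny model is conditioned on edge maps and the depth
    # model on MiDaS depth maps. Both are created on the CPU and wrapped in a
    # DDIM sampler.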
    def load_canny_model(
        self,
        base_cfg='ckpt/cldm_v15.yaml',
        canny_model_cfg='ckpt/control_sd15_canny.pth',
    ):
        self.apply_canny = CannyDetector()
        canny_model = create_model(base_cfg).cpu()
        canny_model.load_state_dict(load_state_dict(canny_model_cfg, location='cpu'), strict=False)
        self.canny_ddim_sampler = DDIMSampler(canny_model)
        self.canny_model = canny_model

    def load_depth_model(
        self,
        base_cfg='ckpt/cldm_v15.yaml',
        depth_model_cfg='ckpt/control_sd15_depth.pth',
    ):
        self.apply_midas = MidasDetector()
        depth_model = create_model(base_cfg).cpu()
        depth_model.load_state_dict(load_state_dict(depth_model_cfg, location='cpu'), strict=False)
        self.depth_ddim_sampler = DDIMSampler(depth_model)
        self.depth_model = depth_model

    def load_video(self, video_name):
        self.data = AtlasData(video_name)
        save_name = f"data/{video_name}/{video_name}.mp4"
        if not os.path.exists(save_name):
            imageio.mimwrite(save_name, self.data.original_video.cpu().permute(0, 2, 3, 1))
            print("original video saved.")
        toIMG = transforms.ToPILImage()
        self.f_atlas_origin = toIMG(self.data.cropped_foreground_atlas[0])
        self.b_atlas_origin = toIMG(self.data.background_grid_atlas[0])
        return save_name, self.f_atlas_origin, self.b_atlas_origin
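
    # Background editing follows the usual ControlNet depth pipeline: estimate a
    # MiDaS depth map from the input atlas, use it as the spatial condition, and
    # sample a new atlas image with DDIM under classifier-free guidance.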
    def depth_edit(self, input_image=None,
                   prompt="",
                   a_prompt="best quality, extremely detailed",
                   n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                   image_resolution=512,
                   detect_resolution=384,
                   ddim_steps=20,
                   scale=9,
                   seed=-1,
                   eta=0,
                   num_samples=1):
        size = input_image.size
        model = self.depth_model
        ddim_sampler = self.depth_ddim_sampler
        apply_midas = self.apply_midas

        input_image = np.array(input_image)
        input_image = HWC3(input_image)
        detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        control = torch.from_numpy(detected_map.copy()).float() / 255.0
        control = torch.stack([control for _ in range(1)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)

        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)

        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        results = [x_samples[i] for i in range(num_samples)]
        self.b_atlas = Image.fromarray(results[0]).resize(size)
        return self.b_atlas

    def edit_background(self, *args, **kwargs):
        input_image = self.b_atlas_origin
        self.depth_edit(input_image, *args, **kwargs)
        if self.save_memory:
            self.depth_model = self.depth_model.cpu()
        return self.b_atlas
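
    # Foreground editing works on a set of keyframes: each keyframe crop is edited
    # with the canny ControlNet, splatted back onto the atlas through its UV
    # mapping, and the per-keyframe atlases are then aggregated (median, optionally
    # refined by a small aggregation network).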
    def advanced_edit_foreground(self,
                                 keyframes="0",
                                 res=2000,
                                 prompt="",
                                 a_prompt="best quality, extremely detailed",
                                 n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                                 image_resolution=512,
                                 low_threshold=100,
                                 high_threshold=200,
                                 ddim_steps=20,
                                 s=0.9,
                                 scale=9,
                                 seed=-1,
                                 eta=0,
                                 if_net=False,
                                 num_samples=1):
        keyframes = [int(x) for x in keyframes.split(",")]
        if self.data is None:
            raise ValueError("Please load a video first")
        self.crops = self.data.get_global_crops_multi(keyframes, res)
        n_keyframes = len(keyframes)
        indices = get_grid_indices(0, 0, res, res)
        f_atlas = torch.zeros(size=(n_keyframes, res, res, 3)).to("cuda")
        img_list = [transforms.ToPILImage()(i[0]) for i in self.crops['original_foreground_crops']]
        result_list = []

        # initial settings
        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)
        self.canny_ddim_sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=eta, verbose=False)
        c_crossattn = [self.canny_model.get_learned_conditioning([prompt + ', ' + a_prompt])]
        uc_crossattn = [self.canny_model.get_learned_conditioning([n_prompt])]
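
        # Edit keyframes sequentially: the first keyframe is sampled from pure
        # noise, while every later keyframe is initialised from the previous edited
        # atlas mapped through its UV coordinates and then partially re-noised
        # (img2img with strength s), which keeps the edits consistent across
        # keyframes.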
        for i in range(n_keyframes):
            # get the current keyframe
            current_img = img_list[i]
            img = resize_image(HWC3(np.array(current_img)), image_resolution)
            H, W, C = img.shape
            shape = (4, H // 8, W // 8)

            # get the canny edge control map
            detected_map = self.apply_canny(img, low_threshold, high_threshold)
            detected_map = HWC3(detected_map)
            control = torch.from_numpy(detected_map.copy()).float() / 255.0
            control = einops.rearrange(control.unsqueeze(0), 'b h w c -> b c h w').clone()
            cond = {"c_concat": [control], "c_crossattn": c_crossattn}
            un_cond = {"c_concat": [control], "c_crossattn": uc_crossattn}

            if i == 0:
                # first keyframe: sample from pure noise
                latent = torch.randn((1, 4, H // 8, W // 8))
                samples, _ = self.canny_ddim_sampler.sample(ddim_steps, num_samples,
                                                            shape, cond, verbose=False, eta=eta,
                                                            unconditional_guidance_scale=scale,
                                                            unconditional_conditioning=un_cond,
                                                            x_T=latent)
            else:
                # later keyframes: map the previous edited atlas into this keyframe
                # and use it as the img2img starting point
                last_atlas = f_atlas[i - 1:i].permute(0, 3, 2, 1)
                mapped_img = F.grid_sample(last_atlas, self.crops['foreground_uvs'][i].reshape(1, -1, 1, 2),
                                           mode="bilinear",
                                           align_corners=self.data.config["align_corners"]
                                           ).clamp(min=0.0, max=1.0).reshape((3, current_img.size[1], current_img.size[0]))
                mapped_img = transforms.ToPILImage()(mapped_img)
                mapped_img = mapped_img.resize((W, H))
                mapped_img = np.array(mapped_img).astype(np.float32) / 255.0
                mapped_img = mapped_img[None].transpose(0, 3, 1, 2)
                mapped_img = torch.from_numpy(mapped_img)
                mapped_img = 2. * mapped_img - 1.
                latent = self.canny_model.get_first_stage_encoding(self.canny_model.encode_first_stage(mapped_img))
                t_enc = int(ddim_steps * s)
                latent = self.canny_ddim_sampler.stochastic_encode(latent, torch.tensor([t_enc]).to("cuda"))
                samples = self.canny_ddim_sampler.decode(x_latent=latent,
                                                         cond=cond,
                                                         t_start=t_enc,
                                                         unconditional_guidance_scale=scale,
                                                         unconditional_conditioning=un_cond)

            x_samples = self.canny_model.decode_first_stage(samples)
            result = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            result = Image.fromarray(result[0])
            result = result.resize(current_img.size)
            result = transforms.ToTensor()(result)

            # multiply by the foreground alpha mask
            alpha = self.crops['foreground_alpha'][i][0].cpu()
            result = alpha * result

            # buffer the per-keyframe result for the aggregation step
            result_copy = result.clone()
            result_copy.requires_grad = True
            result_list.append(result_copy)

            # splat the edited keyframe back onto the atlas grid
            uv = (self.crops['foreground_uvs'][i].reshape(-1, 2) * 0.5 + 0.5) * res
            for c in range(3):
                interpolated = scipy.interpolate.griddata(
                    points=uv.cpu().numpy(),
                    values=result[c].reshape(-1, 1).cpu().numpy(),
                    xi=indices.reshape(-1, 2).cpu().numpy(),
                    method="linear",
                ).reshape(res, res)
                interpolated = torch.from_numpy(interpolated).float()
                interpolated[interpolated.isnan()] = 0.0
                f_atlas[i, :, :, c] = interpolated

        f_atlas = f_atlas.permute(0, 3, 2, 1)
        # aggregate the per-keyframe atlases, starting from a simple per-pixel median
        agg_atlas, _ = torch.median(f_atlas, dim=0)

        if if_net:
            #####################################
            #          aggregation net          #
            #####################################
            lr, n_epoch = 1e-3, 500
            agg_net = AGGNet()
            loss_fn = nn.L1Loss()
            optimizer = optim.SGD(agg_net.parameters(), lr=lr, momentum=0.9)
            for _ in range(n_epoch):
                loss = 0.
                for i in range(n_keyframes):
                    e_img = result_list[i]
                    temp_agg_atlas = agg_net(agg_atlas)
                    rec_img = F.grid_sample(temp_agg_atlas[None],
                                            self.crops['foreground_uvs'][i].reshape(1, -1, 1, 2),
                                            mode="bilinear",
                                            align_corners=self.data.config["align_corners"])
                    rec_img = rec_img.clamp(min=0.0, max=1.0).reshape(e_img.shape)
                    loss += loss_fn(rec_img, e_img)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            agg_atlas = agg_net(agg_atlas)
            #####################################

        agg_atlas, _ = get_atlas_bounding_box(self.data.mask_boundaries, agg_atlas, self.data.foreground_all_uvs)
        self.f_atlas = transforms.ToPILImage()(agg_atlas)
        if self.save_memory:
            self.canny_model = self.canny_model.cpu()
        return self.f_atlas
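
    # Rendering maps the (possibly edited) atlases back to every frame: each
    # frame's pixels look up their UV coordinates in the padded foreground atlas
    # and the background atlas via grid_sample, and the two layers are composited
    # with the per-frame alpha masks.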
    def render(self, f_atlas, b_atlas):
        # foreground
        if f_atlas is None:
            f_atlas = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
        else:
            f_atlas, mask = f_atlas["image"], f_atlas["mask"]
            f_atlas_origin = transforms.ToTensor()(self.f_atlas_origin).unsqueeze(0)
            f_atlas = transforms.ToTensor()(f_atlas).unsqueeze(0)
            mask = transforms.ToTensor()(mask).unsqueeze(0)
            if f_atlas.shape != mask.shape:
                print("Warning: truncating mask to atlas shape {}".format(f_atlas.shape))
                mask = mask[:f_atlas.shape[0], :f_atlas.shape[1], :f_atlas.shape[2], :f_atlas.shape[3]]
            # keep the original atlas in the sketched (masked) regions
            f_atlas = f_atlas * (1 - mask) + f_atlas_origin * mask
        # pad the cropped foreground atlas back to the full atlas grid
        f_atlas = torch.nn.functional.pad(
            f_atlas,
            pad=(
                self.data.foreground_atlas_bbox[1],
                self.data.foreground_grid_atlas.shape[-1] - (self.data.foreground_atlas_bbox[1] + self.data.foreground_atlas_bbox[3]),
                self.data.foreground_atlas_bbox[0],
                self.data.foreground_grid_atlas.shape[-2] - (self.data.foreground_atlas_bbox[0] + self.data.foreground_atlas_bbox[2]),
            ),
            mode="replicate",
        )
        foreground_edit = F.grid_sample(
            f_atlas, self.data.scaled_foreground_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
        ).clamp(min=0.0, max=1.0)
        foreground_edit = foreground_edit.squeeze().t()  # shape (frames * H * W, 3)
        foreground_edit = (
            foreground_edit.reshape(self.data.config["maximum_number_of_frames"], self.data.config["resy"], self.data.config["resx"], 3)
            .permute(0, 3, 1, 2)
            .clamp(min=0.0, max=1.0)
        )

        # background
        if b_atlas is None:
            b_atlas = self.b_atlas_origin
        b_atlas = transforms.ToTensor()(b_atlas).unsqueeze(0)
        background_edit = F.grid_sample(
            b_atlas, self.data.scaled_background_uvs, mode="bilinear", align_corners=self.data.config["align_corners"]
        ).clamp(min=0.0, max=1.0)
        background_edit = background_edit.squeeze().t()  # shape (frames * H * W, 3)
        background_edit = (
            background_edit.reshape(self.data.config["maximum_number_of_frames"], self.data.config["resy"], self.data.config["resx"], 3)
            .permute(0, 3, 1, 2)
            .clamp(min=0.0, max=1.0)
        )

        # alpha-composite foreground over background for every frame
        output_video = (
            self.data.all_alpha * foreground_edit
            + (1 - self.data.all_alpha) * background_edit
        )

        save_id = time.time()
        os.makedirs(f"log/{save_id}", exist_ok=True)  # create the log directory if needed
        save_name = f"log/{save_id}/video.mp4"
        imageio.mimwrite(save_name, (255 * output_video.detach().cpu()).to(torch.uint8).permute(0, 2, 3, 1))
        return save_name
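

# Gradio UI: Step 1 loads a pre-decomposed example video, Step 2 collects prompts
# and ControlNet options, and Step 3 edits each atlas and re-renders the video.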
if __name__ == '__main__':
    stablevideo = StableVideo(base_cfg="ckpt/cldm_v15.yaml",
                              canny_model_cfg="ckpt/control_sd15_canny.pth",
                              depth_model_cfg="ckpt/control_sd15_depth.pth",
                              save_memory=True)
    stablevideo.load_canny_model()
    stablevideo.load_depth_model()

    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            gr.Markdown("## StableVideo")
        with gr.Row():
            with gr.Column():
                original_video = gr.Video(label="Original Video", interactive=False)
                with gr.Row():
                    foreground_atlas = gr.Image(label="Foreground Atlas", type="pil")
                    background_atlas = gr.Image(label="Background Atlas", type="pil")
gr.Markdown("### Step 1. select one example video and click **Load Video** buttom and wait for 10 sec.") | |
avail_video = [f.name for f in os.scandir("data") if f.is_dir()] | |
video_name = gr.Radio(choices=avail_video, | |
label="Select Example Videos", | |
value="car-turn") | |
load_video_button = gr.Button("Load Video") | |
gr.Markdown("### Step 2. write text prompt and advanced options for background and foreground.") | |
with gr.Row(): | |
f_prompt = gr.Textbox(label="Foreground Prompt", value="a picture of an orange suv") | |
b_prompt = gr.Textbox(label="Background Prompt", value="winter scene, snowy scene, beautiful snow") | |
with gr.Row(): | |
with gr.Accordion("Advanced Foreground Options", open=False): | |
adv_keyframes = gr.Textbox(label="keyframe", value="20, 40, 60") | |
adv_atlas_resolution = gr.Slider(label="Atlas Resolution", minimum=1000, maximum=3000, value=2000, step=100) | |
adv_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256) | |
adv_low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1) | |
adv_high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1) | |
adv_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) | |
adv_s = gr.Slider(label="Noise Scale", minimum=0.0, maximum=1.0, value=0.8, step=0.01) | |
adv_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=15.0, value=9.0, step=0.1) | |
adv_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) | |
adv_eta = gr.Number(label="eta (DDIM)", value=0.0) | |
adv_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed, no background') | |
adv_n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') | |
adv_if_net = gr.gradio.Checkbox(label="if use agg net", value=False) | |
with gr.Accordion("Background Options", open=False): | |
b_image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256) | |
b_detect_resolution = gr.Slider(label="Depth Resolution", minimum=128, maximum=1024, value=512, step=1) | |
b_ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) | |
b_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1) | |
b_seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True) | |
b_eta = gr.Number(label="eta (DDIM)", value=0.0) | |
b_a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') | |
b_n_prompt = gr.Textbox(label="Negative Prompt", | |
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality') | |
gr.Markdown("### Step 3. edit each one and render.") | |
with gr.Row(): | |
f_advance_run_button = gr.Button("Advanced Edit Foreground (slower, better)") | |
b_run_button = gr.Button("Edit Background") | |
run_button = gr.Button("Render") | |
with gr.Column(): | |
output_video = gr.Video(label="Output Video", interactive=False) | |
# output_foreground_atlas = gr.Image(label="Output Foreground Atlas", type="pil", interactive=False) | |
output_foreground_atlas = gr.ImageMask(label="Editable Output Foreground Atlas", type="pil", tool="sketch", interactive=True) | |
output_background_atlas = gr.Image(label="Output Background Atlas", type="pil", interactive=False) | |

        # edit param
        f_adv_edit_param = [adv_keyframes,
                            adv_atlas_resolution,
                            f_prompt,
                            adv_a_prompt,
                            adv_n_prompt,
                            adv_image_resolution,
                            adv_low_threshold,
                            adv_high_threshold,
                            adv_ddim_steps,
                            adv_s,
                            adv_scale,
                            adv_seed,
                            adv_eta,
                            adv_if_net]
        b_edit_param = [b_prompt,
                        b_a_prompt,
                        b_n_prompt,
                        b_image_resolution,
                        b_detect_resolution,
                        b_ddim_steps,
                        b_scale,
                        b_seed,
                        b_eta]

        # action
        load_video_button.click(fn=stablevideo.load_video, inputs=video_name, outputs=[original_video, foreground_atlas, background_atlas])
        f_advance_run_button.click(fn=stablevideo.advanced_edit_foreground, inputs=f_adv_edit_param, outputs=[output_foreground_atlas])
        b_run_button.click(fn=stablevideo.edit_background, inputs=b_edit_param, outputs=[output_background_atlas])
        run_button.click(fn=stablevideo.render, inputs=[output_foreground_atlas, output_background_atlas], outputs=[output_video])

    block.launch()