import gradio as gr
import logging
import os
import random
import tempfile
import time

import spaces
from easydict import EasyDict
import numpy as np
import torch

from dav.pipelines import DAVPipeline
from dav.models import UNetSpatioTemporalRopeConditionModel
from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler
from dav.utils import img_utils
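
# `dav` bundles the Depth Any Video pipeline, models, and image utilities;
# `spaces` is the Hugging Face Spaces helper package, presumably imported
# for ZeroGPU support.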


def seed_all(seed: int = 0):
    """
    Set random seeds for reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Initialize logging
logging.basicConfig(level=logging.INFO)


# Load models once to avoid reloading on every inference
def load_models(model_base, device):
    vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder="vae")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        model_base, subfolder="scheduler"
    )
    unet = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet"
    )
    unet_interp = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet_interp"
    )
    pipe = DAVPipeline(
        vae=vae,
        unet=unet,
        unet_interp=unet_interp,
        scheduler=scheduler,
    )
    pipe = pipe.to(device)
    return pipe


# Load models at startup
MODEL_BASE = "hhyangcs/depth-any-video"
DEVICE_TYPE = "cuda"
DEVICE = torch.device(DEVICE_TYPE)
pipe = load_models(MODEL_BASE, DEVICE)
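

# On ZeroGPU Spaces the inference handler below would typically be wrapped
# with `@spaces.GPU` so a GPU is attached per call (assumption based on the
# `spaces` import; a plain CUDA instance does not need it).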


def depth_any_video(
    file,
    denoise_steps=3,
    num_frames=32,
    decode_chunk_size=16,
    num_interp_frames=16,
    num_overlap_frames=6,
    max_resolution=1024,
):
    """
    Perform depth estimation on an uploaded video or image and return the
    path of the rendered side-by-side result (input | colored disparity).
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save the uploaded file into the working directory; basename() guards
        # against Gradio handing over an absolute temp path
        input_path = os.path.join(tmp_dir, os.path.basename(file.name))
        with open(input_path, "wb") as f:
            f.write(file.read())

        # Write results outside tmp_dir so they survive the context manager's
        # cleanup and Gradio can still serve the returned path
        output_dir = tempfile.mkdtemp(prefix="dav_output_")
        os.makedirs(output_dir, exist_ok=True)

        # Prepare configuration
        cfg = EasyDict(
            {
                "model_base": MODEL_BASE,
                "data_path": input_path,
                "output_dir": output_dir,
                "denoise_steps": denoise_steps,
                "num_frames": num_frames,
                "decode_chunk_size": decode_chunk_size,
                "num_interp_frames": num_interp_frames,
                "num_overlap_frames": num_overlap_frames,
                "max_resolution": max_resolution,
                "seed": 666,
            }
        )

        seed_all(cfg.seed)

        file_name = os.path.splitext(os.path.basename(cfg.data_path))[0]
        is_video = cfg.data_path.lower().endswith((".mp4", ".avi", ".mov", ".mkv"))

        if is_video:
            num_interp_frames = cfg.num_interp_frames
            num_overlap_frames = cfg.num_overlap_frames
            num_frames = cfg.num_frames
            assert num_frames % 2 == 0, "num_frames should be even."
            assert (
                2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2
            ), "Invalid frame overlap."
            max_frames = (num_interp_frames + 2 - num_overlap_frames) * (
                num_frames // 2
            )
            image, fps = img_utils.read_video(cfg.data_path, max_frames=max_frames)
        else:
            image = img_utils.read_image(cfg.data_path)
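
        # Resize so the longer side fits within max_resolution and crop to
        # dimensions the pipeline accepts (img_utils handles the exact multiple).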
        image = img_utils.imresize_max(image, cfg.max_resolution)
        image = img_utils.imcrop_multi(image)
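
        # Stack frames into an (N, 3, H, W) float tensor in [0, 1]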
        image_tensor = np.ascontiguousarray(
            [_img.transpose(2, 0, 1) / 255.0 for _img in image]
        )
        image_tensor = torch.from_numpy(image_tensor).to(DEVICE)
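
        # Run the diffusion pipeline without gradients, in fp16 autocast to
        # reduce GPU memory use.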
        with torch.no_grad(), torch.autocast(
            device_type=DEVICE_TYPE, dtype=torch.float16
        ):
            pipe_out = pipe(
                image_tensor,
                num_frames=cfg.num_frames,
                num_overlap_frames=cfg.num_overlap_frames,
                num_interp_frames=cfg.num_interp_frames,
                decode_chunk_size=cfg.decode_chunk_size,
                num_inference_steps=cfg.denoise_steps,
            )

        disparity = pipe_out.disparity
        disparity_colored = pipe_out.disparity_colored
        image = pipe_out.image
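
        # Place the input frames and the colorized disparity side by side
        # (concatenation along the width axis).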
        # (N, H, 2 * W, 3)
        merged = np.concatenate(
            [
                image,
                disparity_colored,
            ],
            axis=2,
        )

        if is_video:
            output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.mp4")
            img_utils.write_video(
                output_path,
                merged,
                fps,
            )
            return output_path
        else:
            output_path = os.path.join(cfg.output_dir, f"{file_name}_depth.png")
            img_utils.write_image(
                output_path,
                merged[0],
            )
            return output_path


# Define Gradio interface
title = "Depth Any Video with Scalable Synthetic Data"
description = """
Upload a video or image to perform depth estimation using the Depth Any Video model.
Adjust the parameters as needed to control the inference process.
"""

iface = gr.Interface(
    fn=depth_any_video,
    inputs=[
        gr.File(label="Upload Video/Image"),
        gr.Slider(1, 10, step=1, value=3, label="Denoise Steps"),
        gr.Slider(16, 64, step=1, value=32, label="Number of Frames"),
        gr.Slider(8, 32, step=1, value=16, label="Decode Chunk Size"),
        gr.Slider(8, 32, step=1, value=16, label="Number of Interpolation Frames"),
        gr.Slider(2, 10, step=1, value=6, label="Number of Overlap Frames"),
        gr.Slider(512, 2048, step=32, value=1024, label="Maximum Resolution"),
    ],
    outputs=gr.Video(label="Depth Enhanced Video/Image"),
    title=title,
    description=description,
    examples=[["demos/arch_2.jpg"], ["demos/wooly_mammoth.mp4"]],
    allow_flagging="never",
    analytics_enabled=False,
)
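
# A request queue is often enabled for long-running inference on Spaces,
# e.g. `iface.queue(max_size=20)` before launching (optional; not required
# for the interface above to work).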

if __name__ == "__main__":
    iface.launch(share=True)