rollingdepth / app.py
toshas's picture
adjust examples
ce6c94f
raw
history blame
9.21 kB
# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import functools
import os
import sys
import tempfile
import av
import numpy as np
import spaces
import gradio as gr
import torch as torch
import einops
from huggingface_hub import login
from colorize import colorize_depth_multi_thread
from video_io import get_video_fps, write_video_from_numpy
# Forwarded as the ``verbose=`` argument to the pipeline, the depth colorizer
# and the video writer calls made in process().
VERBOSE = False
# Upper bound on the number of frames process() requests from the pipeline for
# one video; longer inputs trigger a UI warning and are truncated.
MAX_FRAMES = 100
def _probe_video(path_input, max_frames):
    """Return ``(total_frames, fps)`` for the first video stream of ``path_input``.

    ``total_frames`` is estimated as stream duration times average frame rate.
    When the container reports no usable duration or rate, frames are decoded
    and counted instead, stopping as soon as ``max_frames + 1`` frames have
    been seen (enough to know whether the clip exceeds the cap). ``fps`` is
    ``0.0`` when the stream reports no average rate.
    """
    # The context manager closes the container (and its file handle) even if
    # probing raises; the original code never closed it.
    with av.open(path_input) as container:
        stream = container.streams.video[0]
        fps = float(stream.average_rate) if stream.average_rate else 0.0

        total_frames = 0
        if stream.duration and fps > 0:
            duration_sec = float(stream.duration * stream.time_base)
            total_frames = int(duration_sec * fps)

        if total_frames <= 0:
            # No reliable metadata: count decoded frames, but only as far as
            # needed to compare against the cap.
            total_frames = 0
            for _ in container.decode(video=0):
                total_frames += 1
                if total_frames > max_frames:
                    break

    return total_frames, fps


def process(pipe, device, path_input):
    """Run depth estimation on a video and write two preview videos.

    Args:
        pipe: Callable pipeline; invoked with the keyword arguments below and
            expected to return an object exposing ``depth_pred`` and
            ``input_rgb`` tensors (a ``RollingDepthOutput`` per the original
            annotation).
        device: Device on which the seeded ``torch.Generator`` is created.
        path_input: Path of the input video file.

    Returns:
        Tuple ``(path_out_in, path_out_vis)``: paths of the re-encoded input
        video and of the colorized depth video, both inside a fresh temporary
        directory.
    """
    print(f"Processing {path_input}")

    # mkdtemp() already creates the directory, so no extra makedirs is needed.
    path_output_dir = tempfile.mkdtemp()

    name_base = os.path.splitext(os.path.basename(path_input))[0]
    path_out_in = os.path.join(path_output_dir, f"{name_base}_depth_input.mp4")
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")

    output_fps = int(get_video_fps(path_input))
    total_frames, fps = _probe_video(path_input, MAX_FRAMES)

    if total_frames > MAX_FRAMES:
        # Only quote a duration when the frame rate is known (avoids a
        # division by zero for streams without an average rate).
        approx_sec = f" (~{MAX_FRAMES / fps:.1f} sec.)" if fps > 0 else ""
        gr.Warning(
            f"Only the first {MAX_FRAMES} frames{approx_sec} will be processed for demonstration; "
            "use the code from GitHub for full processing"
        )

    # Fixed seed so repeated runs on the same input use the same generator state.
    generator = torch.Generator(device=device)
    generator.manual_seed(2024)

    pipe_out = pipe(
        # input setting
        input_video_path=path_input,
        start_frame=0,
        # 0 would mean "all frames"; total_frames is now a real count even for
        # videos without duration metadata, so the cap cannot be bypassed.
        frame_count=min(MAX_FRAMES, total_frames),
        processing_res=768,
        # infer setting
        dilations=[1, 25],
        cap_dilation=True,
        snippet_lengths=[3],
        init_infer_steps=[1],
        strides=[1],
        coalign_kwargs=None,
        refine_step=0,  # 0 = off
        max_vae_bs=8,  # batch size for encoder/decoder
        # other settings
        generator=generator,
        verbose=VERBOSE,
        # output settings
        restore_res=False,
        unload_snippet=False,
    )

    depth_pred = pipe_out.depth_pred  # [N 1 H W]

    # Colorize results
    cmap = "Spectral_r"
    colored_np = colorize_depth_multi_thread(
        depth=depth_pred.numpy(),
        valid_mask=None,
        chunk_size=4,
        num_threads=4,
        color_map=cmap,
        verbose=VERBOSE,
    )  # [n h w 3], in [0, 255]

    write_video_from_numpy(
        frames=colored_np,
        output_path=path_out_vis,
        fps=output_fps,
        crf=23,
        preset="medium",
        verbose=VERBOSE,
    )

    # Save rgb
    # NOTE(review): the *255 scaling assumes input_rgb is in [0, 1] - confirm.
    rgb = (pipe_out.input_rgb.numpy() * 255).astype(np.uint8)  # [N 3 H W]
    rgb = einops.rearrange(rgb, "n c h w -> n h w c")
    write_video_from_numpy(
        frames=rgb,
        output_path=path_out_in,
        fps=output_fps,
        crf=23,
        preset="medium",
        verbose=VERBOSE,
    )

    return path_out_in, path_out_vis
def run_demo_server(pipe, device):
    """Assemble the RollingDepth Gradio page and serve it on 0.0.0.0:7860.

    ``pipe`` and ``device`` are bound into ``process`` with
    ``functools.partial``; the resulting one-argument callable is wrapped by
    ``spaces.GPU`` and used both for the example gallery and for the
    "Generate" button.
    """
    # NOTE(review): SOURCE arrived without indentation, so the nesting of the
    # layout blocks below is reconstructed - confirm against the deployed app.
    # NOTE(review): duration=120 is presumably the per-call GPU time budget in
    # seconds - confirm against the `spaces` package documentation.
    gpu_process = spaces.GPU(functools.partial(process, pipe, device), duration=120)
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    # Center all heading levels used on the page.
    page_css = """
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
"""

    # Title, badge links and a short usage blurb shown above the widgets.
    header_html = """
<h1>🛹 RollingDepth: Video Depth without Video Models</h1>
<div style="text-align: center; margin-top: 20px;">
<a title="Website" href="https://rollingdepth.github.io" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="Website Badge">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2411.xxxxx" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arXiv Badge">
</a>
<a title="GitHub" href="https://github.com/prs-eth/rollingdepth" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
<img src="https://img.shields.io/github/stars/prs-eth/rollingdepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="GitHub Stars Badge">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</div>
<p style="margin-top: 20px; text-align: justify;">
RollingDepth is the state-of-the-art depth estimator for videos in the wild. Upload your video into the
<b>left</b> pane, or click any of the <b>examples</b> below. The result preview will be computed and
appear in the <b>right</b> panes. For full functionality, use the code on GitHub.
<b>TIP:</b> When running out of GPU time, fork the demo.
</p>
"""

    # Both result players share every option except their label.
    preview_kwargs = dict(
        interactive=False,
        autoplay=True,
        loop=True,
        show_share_button=True,
        scale=5,
    )

    example_files = ("files/gokart.mp4", "files/horse.mp4", "files/walking.mp4")

    with gr.Blocks(analytics_enabled=False, title="RollingDepth", css=page_css) as demo:
        gr.HTML(header_html)

        # Top row: upload pane on the left, the two result players on the right.
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                video_in = gr.Video(label="Input Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    video_rgb_out = gr.Video(label="Preprocessed video", **preview_kwargs)
                    video_depth_out = gr.Video(label="Generated Depth Video", **preview_kwargs)

        # Second row: the button sits under the upload pane; the empty wider
        # column keeps it aligned with the row above.
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    run_button = gr.Button("Generate")
            with gr.Column(scale=2):
                pass

        gr.Examples(
            examples=[[path] for path in example_files],
            inputs=[video_in],
            outputs=[video_rgb_out, video_depth_out],
            fn=gpu_process,
            cache_examples=True,
        )

        run_button.click(
            fn=gpu_process,
            inputs=[video_in],
            outputs=[video_rgb_out, video_depth_out],
        )

    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(server_name="0.0.0.0", server_port=7860)
def _run_pip(*args):
    """Run ``python -m pip <args>`` with the current interpreter.

    Uses an argument list (no shell) and ``sys.executable`` so the packages
    are changed in the environment this process actually runs in. Output is
    inherited from the parent process. Returns the pip exit code.
    """
    import subprocess  # local import: the file's import block does not provide it

    return subprocess.run([sys.executable, "-m", "pip", *args], check=False).returncode


def main():
    """Prepare the environment, load the RollingDepth pipeline and start the demo.

    Steps: replace the installed ``diffusers`` with the copy bundled under
    ``rollingdepth_src/diffusers``, optionally log in to the Hugging Face Hub
    (when ``HF_TOKEN_LOGIN`` is set), pick a device (cuda, then mps, then
    cpu), load the ``prs-eth/rollingdepth-v1-0`` checkpoint in float16, try
    to enable xformers attention, and hand the pipeline to
    ``run_demo_server``.
    """
    # Resolve the bundled sources relative to this file, so both the pip
    # install and the sys.path entry point at the same directory regardless
    # of the current working directory.
    src_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "rollingdepth_src")

    _run_pip("freeze")
    _run_pip("uninstall", "-y", "diffusers")
    install_rc = _run_pip("install", os.path.join(src_dir, "diffusers"))
    if install_rc != 0:
        # Kept non-fatal like the original, but no longer silent: a failed
        # install explains a later import error.
        print(
            f"WARNING: installing bundled diffusers failed (pip exit code {install_rc})",
            file=sys.stderr,
        )
    _run_pip("freeze")

    if "HF_TOKEN_LOGIN" in os.environ:
        login(token=os.environ["HF_TOKEN_LOGIN"])

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # The rollingdepth package lives in the bundled source directory and must
    # be imported only after the custom diffusers has been installed.
    sys.path.append(src_dir)
    from rollingdepth import RollingDepthOutput, RollingDepthPipeline

    pipe: RollingDepthPipeline = RollingDepthPipeline.from_pretrained(
        "prs-eth/rollingdepth-v1-0",
        torch_dtype=torch.float16,
    )
    pipe.set_progress_bar_config(disable=True)

    try:
        import xformers  # noqa: F401 - imported only to probe availability

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        # Best effort: run without xformers, but let Ctrl-C / SystemExit
        # propagate (a bare ``except:`` would swallow them).
        print(f"xformers not enabled, using default attention: {exc}")

    pipe = pipe.to(device)

    run_demo_server(pipe, device)
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()