# Spaces: Running on Zero
# Copyright 2024 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-11-28
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/RollingDepth#-citation
# More information about the method can be found at https://rollingdepth.github.io
# ---------------------------------------------------------------------------------
import logging
from os import PathLike
from typing import Optional

import av
import numpy as np
from tqdm import tqdm
def get_video_fps(video_path: PathLike) -> float:
    """Return the average frame rate of the first video stream in a file.

    Args:
        video_path: Path of the video file to inspect; passed to ``av.open``.

    Returns:
        The first video stream's ``average_rate`` converted to ``float``.

    Errors raised while opening the file, selecting the first video stream,
    or converting the rate propagate to the caller unchanged.
    """
    # The context manager closes the container even if reading the stream
    # fails, so a bad file does not leak an open handle.
    with av.open(video_path) as container:
        video_stream = container.streams.video[0]
        return float(video_stream.average_rate)  # type: ignore
def write_video_from_numpy(
    frames: np.ndarray,  # shape [n h w 3]
    output_path: PathLike,
    fps: int = 30,
    codec: Optional[str] = None,  # None: try the built-in fallback list
    crf: int = 23,
    preset: str = "medium",
    verbose: bool = False,
) -> None:
    """Encode a stack of RGB frames into a video file using PyAV.

    Args:
        frames: uint8 array of shape [n, height, width, 3]; each frame is
            handed to PyAV as "rgb24" and encoded as "yuv420p".
        output_path: Destination file, opened by PyAV in write mode.
        fps: Frame rate given to the output stream.
        codec: Encoder name. When None, "libx264", "h264", "mpeg4" and
            "mjpeg" are tried in that order and the first one that PyAV
            does not reject as unknown is used.
        crf: Value of the "crf" stream option; only applied when the
            selected codec is "libx264" or "h264".
        preset: Value of the "preset" stream option; same condition as
            ``crf``.
        verbose: If True, log the selected codec and show a tqdm progress
            bar while writing.

    Raises:
        ValueError: If ``frames`` has the wrong shape or dtype, or if every
            candidate codec is unknown to PyAV.
    """
    if frames.ndim != 4 or frames.shape[-1] != 3:
        raise ValueError(f"Expected shape [n, height, width, 3], got {frames.shape}")
    if frames.dtype != np.uint8:
        raise ValueError(f"Expected dtype uint8, got {frames.dtype}")
    n_frames, height, width, _ = frames.shape

    if codec is None:
        codecs_to_try = ["libx264", "h264", "mpeg4", "mjpeg"]
    else:
        codecs_to_try = [codec]

    container = None
    try:
        unknown_codec_error = None
        for try_codec in codecs_to_try:
            container = av.open(output_path, mode="w")
            try:
                stream = container.add_stream(try_codec, rate=fps)
            except av.codec.codec.UnknownCodecError as err:  # type: ignore
                # Close the container opened for this failed attempt before
                # opening another one (or giving up).
                container.close()
                container = None
                unknown_codec_error = err
                continue
            break
        else:
            raise ValueError(
                f"No working codec found. Tried: {codecs_to_try}. "
                "Please install ffmpeg with necessary codecs."
            ) from unknown_codec_error

        if verbose:
            logging.info(f"Using codec: {try_codec}")

        stream.width = width  # type: ignore
        stream.height = height  # type: ignore
        stream.pix_fmt = "yuv420p"  # type: ignore
        # Only set these options for x264-compatible codecs
        if try_codec in ("libx264", "h264"):
            stream.options = {"crf": str(crf), "preset": preset}  # type: ignore

        frame_indices = range(n_frames)
        if verbose:
            frame_indices = tqdm(frame_indices, desc="Writing video", total=n_frames)

        for frame_idx in frame_indices:
            # Build the frame through from_ndarray instead of assigning into
            # to_ndarray(): the latter only works if it returns a writable
            # view of the frame buffer, otherwise the pixels are lost.
            video_frame = av.VideoFrame.from_ndarray(
                np.ascontiguousarray(frames[frame_idx]), format="rgb24"
            )
            container.mux(stream.encode(video_frame))  # type: ignore

        # Flush the frames still buffered inside the encoder
        container.mux(stream.encode(None))  # type: ignore
    finally:
        # Covers every failure after a successful open, not just encoding.
        if container is not None:
            container.close()