# rollingdepth / video_io.py
# Commit a45988a: "initial commit" (toshas)
# Copyright 2024 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-11-28
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/RollingDepth#-citation
# More information about the method can be found at https://rollingdepth.github.io
# ---------------------------------------------------------------------------------
import logging
from fractions import Fraction
from os import PathLike
from typing import Optional

import av
import numpy as np
from tqdm import tqdm


def get_video_fps(video_path: PathLike) -> float:
    """Return the frame rate of the first video stream in ``video_path``.

    The rate is taken from the stream's ``average_rate``. If the container
    does not report one, the stream's ``guessed_rate`` (when available) is
    used instead.

    Args:
        video_path: Path to a video file that PyAV can open.

    Returns:
        Frames per second as a float.

    Raises:
        IndexError: If the file contains no video stream.
        ValueError: If no frame rate can be determined for the stream.
    """
    # The context manager closes the container even if stream access raises.
    with av.open(video_path) as container:
        video_stream = container.streams.video[0]
        rate = video_stream.average_rate  # type: ignore
        if rate is None:
            # Some streams do not report an average rate; fall back if possible.
            rate = getattr(video_stream, "guessed_rate", None)
    if rate is None:
        raise ValueError(f"Could not determine frame rate of {video_path}")
    return float(rate)
def _open_output_with_codec(
    output_path: PathLike,
    codecs_to_try: list,
    rate: Fraction,
    verbose: bool,
) -> tuple:
    """Open ``output_path`` for writing with the first codec PyAV accepts.

    Args:
        output_path: Destination file; it is (re)opened once per attempt.
        codecs_to_try: Codec names tried in order.
        rate: Frame rate passed to ``add_stream``.
        verbose: Log the chosen codec when True.

    Returns:
        ``(container, stream, codec_name)`` for the first codec for which
        ``add_stream`` succeeds. The caller owns and must close the container.

    Raises:
        ValueError: If every codec in ``codecs_to_try`` is unknown to PyAV.
    """
    for try_codec in codecs_to_try:
        container = av.open(output_path, mode="w")
        try:
            stream = container.add_stream(try_codec, rate=rate)
        except av.codec.codec.UnknownCodecError:  # type: ignore
            # Release this attempt's file handle before trying the next codec.
            container.close()
            continue
        except BaseException:
            container.close()
            raise
        if verbose:
            logging.info(f"Using codec: {try_codec}")
        return container, stream, try_codec
    raise ValueError(
        f"No working codec found. Tried: {codecs_to_try}. "
        "Please install ffmpeg with necessary codecs."
    )


def write_video_from_numpy(
    frames: np.ndarray,  # shape [n h w 3]
    output_path: PathLike,
    fps: float = 30,
    codec: Optional[str] = None,  # Let PyAV choose default codec
    crf: int = 23,
    preset: str = "medium",
    verbose: bool = False,
) -> None:
    """Encode a stack of RGB frames into a video file.

    Args:
        frames: uint8 array of shape ``[n, height, width, 3]`` (RGB).
        output_path: Destination file path.
        fps: Positive frame rate. Floats such as 29.97002997 are converted
            to a rational with denominator <= 1001 (e.g. 30000/1001).
        codec: Codec name. If None, tries "libx264", "h264", "mpeg4" and
            "mjpeg" in that order.
        crf: x264 constant rate factor (only used for "libx264"/"h264").
        preset: x264 preset (only used for "libx264"/"h264").
        verbose: Log the chosen codec and show a tqdm progress bar.

    Raises:
        ValueError: On wrong shape or dtype, non-positive fps, or when no
            requested codec is available.
    """
    if len(frames.shape) != 4 or frames.shape[-1] != 3:
        raise ValueError(f"Expected shape [n, height, width, 3], got {frames.shape}")
    if frames.dtype != np.uint8:
        raise ValueError(f"Expected dtype uint8, got {frames.dtype}")
    if fps <= 0:
        # PyAV replaces a zero rate with its own default, so reject it explicitly.
        raise ValueError(f"Expected positive fps, got {fps}")
    n_frames, height, width, _ = frames.shape

    # PyAV stores the rate as a rational; bound the denominator so float
    # inputs (e.g. from get_video_fps) map to sane values like 30000/1001.
    rate = Fraction(fps).limit_denominator(1001)

    codecs_to_try = (
        [codec] if codec is not None else ["libx264", "h264", "mpeg4", "mjpeg"]
    )
    container, stream, chosen_codec = _open_output_with_codec(
        output_path, codecs_to_try, rate, verbose
    )

    frame_indices = range(n_frames)
    if verbose:
        frame_indices = tqdm(frame_indices, desc="Writing video", total=n_frames)

    try:
        stream.width = width  # type: ignore
        stream.height = height  # type: ignore
        # NOTE(review): yuv420p usually requires even width/height with x264 —
        # confirm behavior for odd-sized inputs.
        stream.pix_fmt = "yuv420p"  # type: ignore
        # Only set these options for x264-compatible codecs
        if chosen_codec in ("libx264", "h264"):
            stream.options = {"crf": str(crf), "preset": preset}  # type: ignore

        for frame_idx in frame_indices:
            # from_ndarray copies into a correctly-strided frame buffer. Writing
            # into to_ndarray()'s result is unreliable, because that result is
            # a copy whenever the plane has line padding.
            video_frame = av.VideoFrame.from_ndarray(
                np.ascontiguousarray(frames[frame_idx]), format="rgb24"
            )
            container.mux(stream.encode(video_frame))  # type: ignore

        # Flush frames still buffered inside the encoder.
        container.mux(stream.encode(None))  # type: ignore
    finally:
        container.close()  # type: ignore