File size: 4,049 Bytes
a45988a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright 2024 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-11-28
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/RollingDepth#-citation
# More information about the method can be found at https://rollingdepth.github.io
# ---------------------------------------------------------------------------------
import logging
from os import PathLike
from typing import Optional

import av
import numpy as np
from tqdm import tqdm


def get_video_fps(video_path: PathLike) -> float:
    """Return the average frame rate of the first video stream in a file.

    Args:
        video_path: Path of the video file, passed to ``av.open``.

    Returns:
        The first video stream's ``average_rate`` converted to ``float``.

    Raises:
        IndexError: If the file contains no video stream.
    """
    # The context manager closes the container even when the stream lookup
    # or the conversion below raises.
    with av.open(video_path) as container:
        video_stream = container.streams.video[0]
        # average_rate is the rate reported by the stream itself; no frames
        # are decoded here.
        return float(video_stream.average_rate)  # type: ignore


def write_video_from_numpy(
    frames: np.ndarray,  # shape [n h w 3]
    output_path: PathLike,
    fps: int = 30,
    codec: Optional[str] = None,  # None: try a built-in list of fallback codecs
    crf: int = 23,
    preset: str = "medium",
    verbose: bool = False,
) -> None:
    """Encode a stack of RGB frames into a video file using PyAV.

    Args:
        frames: ``uint8`` array of shape ``[n, height, width, 3]``; each frame
            is handed to the encoder as ``rgb24``.
        output_path: Destination file, opened with ``av.open`` in write mode.
        fps: Frame rate passed to the output stream.
        codec: Codec name to use. When ``None``, ``"libx264"``, ``"h264"``,
            ``"mpeg4"`` and ``"mjpeg"`` are tried in that order and the first
            one PyAV recognizes is used.
        crf: Value of the ``crf`` stream option; only applied when the chosen
            codec is ``"libx264"`` or ``"h264"``.
        preset: Value of the ``preset`` stream option; same restriction as
            ``crf``.
        verbose: If true, log the chosen codec and show a tqdm progress bar.

    Raises:
        ValueError: If ``frames`` has the wrong shape or dtype, or if none of
            the candidate codecs is available.
    """
    if frames.ndim != 4 or frames.shape[-1] != 3:
        raise ValueError(f"Expected shape [n, height, width, 3], got {frames.shape}")
    if frames.dtype != np.uint8:
        raise ValueError(f"Expected dtype uint8, got {frames.dtype}")

    n_frames, height, width, _ = frames.shape

    # Fall back through a list of common codecs when none is requested.
    if codec is None:
        codecs_to_try = ["libx264", "h264", "mpeg4", "mjpeg"]
    else:
        codecs_to_try = [codec]

    # Open the output once and keep every later step inside try/finally so the
    # container is always closed, including when no codec is available or the
    # stream setup fails.
    container = av.open(output_path, mode="w")
    try:
        stream = None
        used_codec = None
        for try_codec in codecs_to_try:
            try:
                stream = container.add_stream(try_codec, rate=fps)
            except av.codec.codec.UnknownCodecError:  # type: ignore
                continue
            used_codec = try_codec
            break

        if stream is None:
            raise ValueError(
                f"No working codec found. Tried: {codecs_to_try}. "
                "Please install ffmpeg with necessary codecs."
            )
        if verbose:
            logging.info("Using codec: %s", used_codec)

        stream.width = width  # type: ignore
        stream.height = height  # type: ignore
        # NOTE(review): yuv420p subsamples chroma 2x2, so some encoders may
        # reject odd widths/heights — confirm callers pass even sizes.
        stream.pix_fmt = "yuv420p"  # type: ignore

        # Only set these options for x264-compatible codecs
        if used_codec in ("libx264", "h264"):
            stream.options = {"crf": str(crf), "preset": preset}  # type: ignore

        frames_iterable = range(n_frames)
        if verbose:
            frames_iterable = tqdm(
                frames_iterable, desc="Writing video", total=n_frames
            )

        for frame_idx in frames_iterable:
            # Build the frame through the documented constructor instead of
            # writing into to_ndarray(): that call is not guaranteed to return
            # a view of the frame buffer, in which case the pixel data would
            # silently never reach the encoder.
            video_frame = av.VideoFrame.from_ndarray(
                np.ascontiguousarray(frames[frame_idx]), format="rgb24"
            )
            container.mux(stream.encode(video_frame))  # type: ignore

        # Flush packets still buffered inside the encoder.
        container.mux(stream.encode(None))  # type: ignore
    finally:
        container.close()