File size: 4,631 Bytes
c58ca4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from typing import Optional
import av
import numpy as np
import torch
from av import AudioFrame
@dataclass
class VideoInfo:
    """Container for one decoded video and the frame data sampled from it.

    ``height`` and ``width`` are derived from the first entry of ``all_frames``.
    They therefore only work when ``all_frames`` holds at least one frame.
    When it is ``None``, indexing it raises ``TypeError``. When it is an empty
    list, indexing raises ``IndexError``.
    """

    duration_sec: float
    fps: Fraction  # used as the video stream rate by reencode_with_audio
    # NOTE(review): presumably the stacked frames that read_frames samples at two
    # target rates -- confirm against the caller that builds this object.
    clip_frames: torch.Tensor
    sync_frames: torch.Tensor
    all_frames: Optional[list[np.ndarray]]  # every kept frame, or None when not kept

    @property
    def _frame_shape(self):
        """Shape tuple of the first frame in ``all_frames``."""
        return self.all_frames[0].shape

    @property
    def height(self):
        """Size of axis 0 (rows) of the first frame."""
        return self._frame_shape[0]

    @property
    def width(self):
        """Size of axis 1 (columns) of the first frame."""
        return self._frame_shape[1]
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
    """Decode the first video stream and resample its frames at several target rates.

    Only frames whose timestamp lies in ``[start_sec, end_sec]`` are used.
    Each target rate ``r`` in ``list_of_fps`` gets a sampling grid
    ``start_sec, start_sec + 1/r, start_sec + 2/r, ...``. Each grid point
    receives the first in-range decoded frame whose timestamp is at or after
    that point. Grid points past the last in-range frame are not emitted.

    Args:
        video_path: file to decode with PyAV.
        list_of_fps: target sampling rates; one output array is produced per entry.
        start_sec: frames earlier than this are skipped; the sampling grid starts here.
        end_sec: decoding stops at the first frame later than this.
        need_all_frames: if true, also return every in-range frame (not resampled).

    Returns:
        ``(sampled, all_frames, fps)``. ``sampled[i]`` is ``np.stack`` of the
        ``rgb24`` frames chosen for ``list_of_fps[i]``. ``all_frames`` is the
        list of every in-range ``rgb24`` frame, or an empty list when
        ``need_all_frames`` is false. ``fps`` is the stream's ``guessed_rate``.

    Raises:
        ValueError: if no frame was sampled for some target rate, e.g. when no
            frame falls inside ``[start_sec, end_sec]``.
    """
    output_frames = [[] for _ in list_of_fps]
    # The grid starts at start_sec rather than 0. Otherwise, for start_sec > 0,
    # the first in-range frame would be repeated to fill every slot in [0, start_sec].
    next_frame_time_for_each_fps = [float(start_sec) for _ in list_of_fps]
    time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
    all_frames = []

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        fps = stream.guessed_rate
        stream.thread_type = 'AUTO'
        # container.decode() yields the frames of every demuxed packet, including the
        # final decoder flush. A single `break` below therefore stops all demuxing and
        # decoding once we are past end_sec, instead of only leaving the current packet.
        for frame in container.decode(stream):
            frame_time = frame.time
            if frame_time < start_sec:
                continue
            if frame_time > end_sec:
                break
            # Convert lazily: only pay for to_ndarray when the frame is actually kept.
            frame_np = None
            if need_all_frames:
                frame_np = frame.to_ndarray(format='rgb24')
                all_frames.append(frame_np)
            for i, delta in enumerate(time_delta_for_each_fps):
                # Fill every not-yet-filled grid point at or before this frame's time.
                # If the source is slower than the target rate, this repeats the frame.
                while frame_time >= next_frame_time_for_each_fps[i]:
                    if frame_np is None:
                        frame_np = frame.to_ndarray(format='rgb24')
                    output_frames[i].append(frame_np)
                    next_frame_time_for_each_fps[i] += delta

    for target_fps, frames in zip(list_of_fps, output_frames):
        if not frames:
            raise ValueError(f'No frames sampled at {target_fps} fps from {video_path} '
                             f'between {start_sec}s and {end_sec}s')
    output_frames = [np.stack(frames) for frames in output_frames]
    return output_frames, all_frames, fps
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
                        sampling_rate: int):
    """Encode ``video_info.all_frames`` as H.264 video and ``audio`` as AAC into ``output_path``.

    Args:
        video_info: supplies the frames (``all_frames``, which must not be
            ``None``), the stream rate (``fps``) and the output ``width``/``height``.
        output_path: destination file, opened for writing.
        audio: audio samples, encoded as a single mono packed-float (``'flt'``)
            frame. NOTE(review): presumably shaped ``(1, num_samples)`` by callers.
            A 1-D ``(num_samples,)`` tensor is also accepted as one channel.
        sampling_rate: audio sample rate passed to the AAC stream and the frame.
    """
    # `with` closes the container even if encoding raises part-way through,
    # so the output file handle is never leaked.
    with av.open(output_path, 'w') as container:
        output_video_stream = container.add_stream('h264', video_info.fps)
        # 10 Mbps. Use an int: the codec's bit_rate is an integer field, and the
        # original assigned the float 1e7.
        output_video_stream.codec_context.bit_rate = 10_000_000
        output_video_stream.width = video_info.width
        output_video_stream.height = video_info.height
        output_video_stream.pix_fmt = 'yuv420p'

        output_audio_stream = container.add_stream('aac', sampling_rate)

        # encode video, then flush the encoder (encode() with no frame)
        for image in video_info.all_frames:
            video_frame = av.VideoFrame.from_ndarray(image)
            packet = output_video_stream.encode(video_frame)
            container.mux(packet)
        for packet in output_video_stream.encode():
            container.mux(packet)

        # Convert the audio tensor to a contiguous float32 numpy array.
        # detach()/cpu()/float() make this work for tensors that require grad, live
        # on an accelerator, or use dtypes numpy lacks (e.g. bfloat16). They are
        # no-ops for a plain CPU float32 tensor.
        audio_np = np.ascontiguousarray(audio.detach().cpu().float().numpy())
        # Packed mono frames take a (1, num_samples) array; lift 1-D input to that shape.
        audio_np = np.atleast_2d(audio_np)
        audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate
        for packet in output_audio_stream.encode(audio_frame):
            container.mux(packet)
        for packet in output_audio_stream.encode():
            container.mux(packet)
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
    """Copy the video stream of ``video_path`` without re-encoding and add ``audio`` as AAC.

    NOTE: I don't think we can get the exact video duration right without re-encoding
    so we are not using this but keeping it here for reference.

    Args:
        video_path: source file; only its first video stream is copied.
        audio: audio samples, encoded as a single mono packed-float (``'flt'``)
            frame. NOTE(review): presumably shaped ``(1, num_samples)``. A 1-D
            tensor is also accepted as one channel.
        output_path: destination file, opened for writing.
        sampling_rate: audio sample rate passed to the AAC stream and the frame.
    """
    # Both containers are closed exactly once, even on error. The original leaked
    # them on exceptions and called output.close() twice.
    with av.open(video_path) as video, av.open(output_path, 'w') as output:
        input_video_stream = video.streams.video[0]
        output_video_stream = output.add_stream(template=input_video_stream)
        output_audio_stream = output.add_stream('aac', sampling_rate)

        for packet in video.demux(input_video_stream):
            # We need to skip the "flushing" packets that `demux` generates.
            if packet.dts is None:
                continue
            # We need to assign the packet to the new stream.
            packet.stream = output_video_stream
            output.mux(packet)

        # Convert the audio tensor to a contiguous float32 numpy array.
        # detach()/cpu()/float() handle grad-tracking, accelerator-resident or
        # bfloat16 tensors, and are no-ops for a plain CPU float32 tensor.
        audio_np = np.ascontiguousarray(audio.detach().cpu().float().numpy())
        # Packed mono frames take a (1, num_samples) array; lift 1-D input to that shape.
        audio_np = np.atleast_2d(audio_np)
        audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
        audio_frame.sample_rate = sampling_rate
        for packet in output_audio_stream.encode(audio_frame):
            output.mux(packet)
        for packet in output_audio_stream.encode():
            output.mux(packet)
|