"""
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
"""
import io
import logging
import math
import os
import random

import av
import cv2
import decord
import imageio
import numpy as np
import torch
from decord import VideoReader
# from dataloader import KVReader  # needed only by read_frames_hdfs
# import tensorflow as tf          # needed only by read_frames_hdfs

decord.bridge.set_bridge("torch")  # make decord return torch tensors

logger = logging.getLogger(__name__)

def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
    """
    Converts a presentation timestamp (pts) with the given time base and
    start_pts offset to seconds.
    Returns:
        time_in_seconds (float): The corresponding time in seconds.
    https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
    """
    if pts == math.inf:
        return math.inf
    return int(pts - start_pts) * time_base
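
# Worked example (hypothetical values, not from the original source): with a
# 90 kHz time base, a pts of 180000 and a start_pts of 0 map to 2.0 seconds:
#   pts_to_secs(180000, 1 / 90000, 0)  ->  2.0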

def get_pyav_video_duration(video_reader):
    video_stream = video_reader.streams.video[0]
    video_duration = pts_to_secs(
        video_stream.duration,
        video_stream.time_base,
        video_stream.start_time
    )
    return float(video_duration)


def get_frame_indices_by_fps():
    # placeholder, not implemented; fps-based sampling is handled by the
    # "fps" branch of get_frame_indices below
    pass

def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    if sample in ["rand", "middle"]:  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except IndexError:  # an interval can be empty when vlen is small
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError
        if len(frame_indices) < num_frames:  # pad with the last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # e.g. fps0.5: sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames; also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    else:
        raise ValueError(f"Unknown sample strategy: {sample}")
    return frame_indices
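
# Usage sketch (values worked out by hand, not from the original source):
# sampling 4 of 100 frames with sample='middle' takes the center of each of
# the 4 equal intervals,
#   get_frame_indices(4, 100, sample='middle')  ->  [12, 37, 62, 87]
# while sample='fps0.5' on a 100-frame, 25 fps video (4 s) places one frame
# every 2 s, centered in each window:
#   get_frame_indices(4, 100, sample='fps0.5', input_fps=25)  ->  [25, 75]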

def read_frames_av(
        video_path, num_frames, sample='rand', fix_start=None,
        max_num_frames=-1, client=None, clip=None,
):
    reader = av.open(video_path)
    # decode every frame up front, then sample; memory-heavy for long videos
    frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]
    vlen = len(frames)
    duration = get_pyav_video_duration(reader)
    fps = vlen / float(duration)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    frames = torch.stack([frames[idx] for idx in frame_indices])  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, fps
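
# Usage sketch ("video.mp4" is a hypothetical path): uniformly sample 8
# frames from a local clip via the PyAV backend,
#   frames, idxs, fps = read_frames_av("video.mp4", num_frames=8, sample='middle')
#   frames.shape  ->  torch.Size([8, 3, H, W])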

def read_frames_gif(
        video_path, num_frames, sample='rand', fix_start=None,
        max_num_frames=-1, client=None, clip=None,
):
    if video_path.startswith('s3') or video_path.startswith('p2'):
        video_bytes = client.get(video_path)
        gif = imageio.get_reader(io.BytesIO(video_bytes))
    else:
        gif = imageio.get_reader(video_path)
    vlen = len(gif)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        max_num_frames=max_num_frames
    )
    frames = []
    # gif readers are sequential, so iterate and keep only sampled indices
    for index, frame in enumerate(gif):
        if index in frame_indices:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            frame = torch.from_numpy(frame).byte()
            frame = frame.permute(2, 0, 1)  # (H, W, C) to (C, H, W)
            frames.append(frame)
    frames = torch.stack(frames)  # .float() / 255
    return frames, frame_indices, 25.  # assume 25 fps for TGIF

def read_frames_hdfs(ind_file, vid, num_frames, sample='rand', fix_start=None,
                     max_num_frames=-1, client=None, clip=None):
    # requires the optional imports commented out at the top of this file:
    # `tensorflow` (tf) and `KVReader` from the internal dataloader package
    _context_features = {'title': tf.io.FixedLenFeature([], dtype=tf.string)}
    _sequence_features = {'data': tf.io.FixedLenSequenceFeature([], dtype=tf.string)}
    num_parallel_reader = 1
    filename, extension = os.path.splitext(ind_file)
    reader = KVReader(filename, num_parallel_reader)
    key = vid
    values = reader.read_many([key])
    item = values[0]
    contexts, sequences = tf.io.parse_single_sequence_example(
        serialized=item,
        context_features=_context_features,
        sequence_features=_sequence_features)
    # text = contexts['title'].numpy().decode("utf-8")
    rawframes = sequences['data']
    vlen = len(rawframes)
    sample = "rand"  # NOTE: overrides the `sample` argument; HDFS reads always sample randomly
    frame_indices = get_frame_indices(num_frames, vlen, sample=sample,
                                      fix_start=fix_start,
                                      max_num_frames=max_num_frames)

    def read_image(raw_data):
        return tf.image.decode_jpeg(raw_data, channels=3, dct_method='INTEGER_ACCURATE').numpy()

    frames = []
    for index, frame in enumerate(rawframes):
        if index in frame_indices:
            frame = read_image(frame)
            frame = torch.as_tensor(frame)
            frames.append(frame)
    frames = torch.stack(frames)
    frames = frames.permute(0, 3, 1, 2)
    return frames, frame_indices, 25  # fps unknown for HDFS-indexed frames; assume 25

def read_frames_decord(
        video_path, num_frames, sample='rand', fix_start=None,
        max_num_frames=-1, client=None, clip=None
):
    if video_path.startswith('s3') or video_path.startswith('p2'):
        video_bytes = client.get(video_path)
        video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
    else:
        video_reader = VideoReader(video_path, num_threads=1)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    if clip:
        start, end = clip
        duration = end - start
        vlen = int(duration * fps)
        start_index = int(start * fps)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    if clip:
        frame_indices = [f + start_index for f in frame_indices]
    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, float(fps)
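
# Usage sketch ("video.mp4" is a hypothetical path): the `clip` argument
# restricts sampling to a (start, end) window in seconds; the sampled indices
# are shifted back to absolute frame numbers before decoding:
#   frames, idxs, fps = read_frames_decord("video.mp4", num_frames=8,
#                                          sample='rand', clip=(2.0, 6.0))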

VIDEO_READER_FUNCS = {
    'av': read_frames_av,
    'decord': read_frames_decord,
    'gif': read_frames_gif,
    'hdfs': read_frames_hdfs,
}
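
if __name__ == "__main__":
    # Minimal dispatch sketch ("example.mp4" is a hypothetical path): callers
    # pick a backend by key and invoke it with the shared signature above.
    reader_fn = VIDEO_READER_FUNCS["decord"]
    frames, frame_indices, fps = reader_fn("example.mp4", num_frames=8, sample="middle")
    print(frames.shape, frame_indices, fps)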