"""
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
"""
import io
import logging
import math
import os
import random

import av
import cv2
import decord
import imageio
import numpy as np
import torch
from decord import VideoReader

decord.bridge.set_bridge("torch")

logger = logging.getLogger(__name__)

def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
    """
    Converts a presentation timestamp (pts), given its time base and a
    start_pts offset, to seconds.
    Returns:
        time_in_seconds (float): The corresponding time in seconds.
    https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
    """
    if pts == math.inf:
        return math.inf
    return int(pts - start_pts) * time_base

def get_pyav_video_duration(video_reader):
    """Return the duration of the first video stream, in seconds."""
    video_stream = video_reader.streams.video[0]
    video_duration = pts_to_secs(
        video_stream.duration,
        video_stream.time_base,
        video_stream.start_time
    )
    return float(video_duration)

def get_frame_indices_by_fps():
    # Unused placeholder; fps-based sampling is handled inside get_frame_indices.
    pass

def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """Select frame indices from a video of `vlen` frames.

    `sample` is one of 'rand', 'middle', or 'fps{x}' (e.g. 'fps0.5' samples
    sequentially at 0.5 frames per second of source video). When `fix_start`
    is given together with 'middle', each interval's start plus that fixed
    offset is taken instead of the midpoint.
    """
    if sample in ["rand", "middle"]:  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except IndexError:
                # an interval was empty (vlen is close to num_frames); fall back
                # to a sorted random permutation of all frames.
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError(f"Unsupported sampling scheme: {sample}")
        if len(frame_indices) < num_frames:  # pad with the last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # e.g. fps0.5: sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames; also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f"Unknown sampling scheme: {sample}")
    return frame_indices
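# Worked example of the sampling schemes above (deterministic cases):
#   get_frame_indices(4, vlen=100, sample='middle') partitions [0, 100) into
#   four equal intervals and returns their midpoints -> [12, 37, 62, 87].
#   get_frame_indices(4, vlen=100, sample='fps1', input_fps=10) samples one
#   frame per second of a 10 s clip -> [5, 15, 25, 35, 45, 55, 65, 75, 85, 95];
#   note the fps scheme ignores `num_frames` and is only capped by `max_num_frames`.
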
def read_frames_av(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1):
    reader = av.open(video_path)
    # Note: this decodes the entire video into memory before sampling.
    frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]
    vlen = len(frames)
    duration = get_pyav_video_duration(reader)
    reader.close()
    fps = vlen / float(duration)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    frames = torch.stack([frames[idx] for idx in frame_indices])  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, duration
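# Example (illustrative; "clip.mp4" is a placeholder for any local video):
#   frames, frame_indices, duration = read_frames_av("clip.mp4", num_frames=8, sample="middle")
#   frames.shape -> torch.Size([8, 3, H, W])
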
def read_frames_gif(
        video_path, num_frames, sample='rand', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False,
    ):
    if 's3://' in video_path:
        video_bytes = client.get(video_path)
        gif = imageio.get_reader(io.BytesIO(video_bytes))
    else:
        gif = imageio.get_reader(video_path)
    vlen = len(gif)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        max_num_frames=max_num_frames
    )
    # Collect the selected frames by index, then assemble them in sampling
    # order so that duplicated indices (from padding) are honored.
    selected = {}
    for index, frame in enumerate(gif):
        if index in frame_indices:
            if frame.ndim == 3 and frame.shape[2] == 4:  # drop alpha channel if present
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            frame = torch.from_numpy(np.asarray(frame)).byte()
            frame = frame.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
            selected[index] = frame
    frames = torch.stack([selected[idx] for idx in frame_indices])  # (T, C, H, W), torch.uint8
    return frames, frame_indices, None

def read_frames_decord(
        video_path, num_frames, sample='rand', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False
    ):
    num_threads = 1 if video_path.endswith('.webm') else 0  # make ssv2 happy
    if "s3://" in video_path:
        video_bytes = client.get(video_path)
        if video_bytes is None:
            logger.warning(f"Failed to load {video_path}")
            raise RuntimeError(f"Could not fetch video bytes for {video_path}")
        video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=num_threads)
    else:
        video_reader = VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    # only use the first 30 seconds
    if trimmed30 and duration > 30:
        duration = 30
        vlen = int(30 * float(fps))
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, duration
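# Note on the s3:// paths above: judging from its usage here, `client` is expected
# to expose a `get(path) -> bytes` interface (plus `list(path)` in read_frames_img
# below); for local files it can stay None, e.g.:
#   frames, frame_indices, duration = read_frames_decord("video.webm", num_frames=4, sample="rand")
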
def read_frames_img(
        video_path, num_frames, sample='rand', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False
    ):
    # Frames are stored as individual image files named 'img*' inside `video_path`;
    # sort the listing so indices follow temporal order (assumes zero-padded names).
    if "s3://" in video_path:
        img_list = sorted(path for path in client.list(video_path) if path.startswith('img'))
    else:
        img_list = sorted(path for path in os.listdir(video_path) if path.startswith('img'))
    vlen = len(img_list)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        max_num_frames=max_num_frames
    )
    imgs = []
    for idx in frame_indices:
        frame_fname = os.path.join(video_path, img_list[idx])
        if "s3://" in video_path:
            img_bytes = client.get(frame_fname)
        else:
            with open(frame_fname, 'rb') as f:
                img_bytes = f.read()
        img_np = np.frombuffer(img_bytes, np.uint8)
        img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # BGR -> RGB in place
        imgs.append(img)
    frames = torch.tensor(np.array(imgs), dtype=torch.uint8).permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, None

VIDEO_READER_FUNCS = {
'av': read_frames_av,
'decord': read_frames_decord,
'gif': read_frames_gif,
'img': read_frames_img,
}
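
if __name__ == "__main__":
    # Minimal smoke test (illustrative only): "sample.mp4" is a placeholder
    # path, not a file shipped with this module; point it at any local video.
    frames, frame_indices, duration = VIDEO_READER_FUNCS["decord"](
        "sample.mp4", num_frames=8, sample="middle"
    )
    print(frames.shape, frame_indices, duration)  # e.g. torch.Size([8, 3, H, W])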