Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py | |
""" | |
import os | |
import random | |
import io | |
import av | |
import cv2 | |
import decord | |
import imageio | |
from decord import VideoReader | |
import torch | |
import numpy as np | |
import math | |
decord.bridge.set_bridge("torch") | |
import logging | |
logger = logging.getLogger(__name__) | |
def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float: | |
""" | |
Converts a present time with the given time base and start_pts offset to seconds. | |
Returns: | |
time_in_seconds (float): The corresponding time in seconds. | |
https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64 | |
""" | |
if pts == math.inf: | |
return math.inf | |
return int(pts - start_pts) * time_base | |
def get_pyav_video_duration(video_reader): | |
video_stream = video_reader.streams.video[0] | |
video_duration = pts_to_secs( | |
video_stream.duration, | |
video_stream.time_base, | |
video_stream.start_time | |
) | |
return float(video_duration) | |
def get_frame_indices_by_fps(): | |
pass | |
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): | |
if sample in ["rand", "middle"]: # uniform sampling | |
acc_samples = min(num_frames, vlen) | |
# split the video into `acc_samples` intervals, and sample from each interval. | |
intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) | |
ranges = [] | |
for idx, interv in enumerate(intervals[:-1]): | |
ranges.append((interv, intervals[idx + 1] - 1)) | |
if sample == 'rand': | |
try: | |
frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] | |
except: | |
frame_indices = np.random.permutation(vlen)[:acc_samples] | |
frame_indices.sort() | |
frame_indices = list(frame_indices) | |
elif fix_start is not None: | |
frame_indices = [x[0] + fix_start for x in ranges] | |
elif sample == 'middle': | |
frame_indices = [(x[0] + x[1]) // 2 for x in ranges] | |
else: | |
raise NotImplementedError | |
if len(frame_indices) < num_frames: # padded with last frame | |
padded_frame_indices = [frame_indices[-1]] * num_frames | |
padded_frame_indices[:len(frame_indices)] = frame_indices | |
frame_indices = padded_frame_indices | |
elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps | |
output_fps = float(sample[3:]) | |
duration = float(vlen) / input_fps | |
delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents | |
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) | |
frame_indices = np.around(frame_seconds * input_fps).astype(int) | |
frame_indices = [e for e in frame_indices if e < vlen] | |
if max_num_frames > 0 and len(frame_indices) > max_num_frames: | |
frame_indices = frame_indices[:max_num_frames] | |
# frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames) | |
else: | |
raise ValueError | |
return frame_indices | |
def read_frames_av(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1): | |
reader = av.open(video_path) | |
frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)] | |
vlen = len(frames) | |
duration = get_pyav_video_duration(reader) | |
fps = vlen / float(duration) | |
frame_indices = get_frame_indices( | |
num_frames, vlen, sample=sample, fix_start=fix_start, | |
input_fps=fps, max_num_frames=max_num_frames | |
) | |
frames = torch.stack([frames[idx] for idx in frame_indices]) # (T, H, W, C), torch.uint8 | |
frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 | |
return frames, frame_indices, duration | |
def read_frames_gif( | |
video_path, num_frames, sample='rand', fix_start=None, | |
max_num_frames=-1, client=None, trimmed30=False, | |
): | |
if 's3://' in video_path: | |
video_bytes = client.get(video_path) | |
gif = imageio.get_reader(io.BytesIO(video_bytes)) | |
else: | |
gif = imageio.get_reader(video_path) | |
vlen = len(gif) | |
frame_indices = get_frame_indices( | |
num_frames, vlen, sample=sample, fix_start=fix_start, | |
max_num_frames=max_num_frames | |
) | |
frames = [] | |
for index, frame in enumerate(gif): | |
# for index in frame_idxs: | |
if index in frame_indices: | |
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) | |
frame = torch.from_numpy(frame).byte() | |
# # (H x W x C) to (C x H x W) | |
frame = frame.permute(2, 0, 1) | |
frames.append(frame) | |
frames = torch.stack(frames) # .float() / 255 | |
return frames, frame_indices, None | |
def read_frames_decord( | |
video_path, num_frames, sample='rand', fix_start=None, | |
max_num_frames=-1, client=None, trimmed30=False | |
): | |
num_threads = 1 if video_path.endswith('.webm') else 0 # make ssv2 happy | |
if "s3://" in video_path: | |
video_bytes = client.get(video_path) | |
# print(f"\033[1;31;40m {video_path} ok: {video_bytes is None} \033[0m") | |
if video_bytes is None: | |
logger.warning(f"Failed to load {video_path}") | |
video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=num_threads) | |
else: | |
video_reader = VideoReader(video_path, num_threads=num_threads) | |
vlen = len(video_reader) | |
fps = video_reader.get_avg_fps() | |
duration = vlen / float(fps) | |
# only use top 30 seconds | |
if trimmed30 and duration > 30: | |
duration = 30 | |
vlen = int(30 * float(fps)) | |
frame_indices = get_frame_indices( | |
num_frames, vlen, sample=sample, fix_start=fix_start, | |
input_fps=fps, max_num_frames=max_num_frames | |
) | |
frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 | |
frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 | |
return frames, frame_indices, duration | |
def read_frames_img( | |
video_path, num_frames, sample='rand', fix_start=None, | |
max_num_frames=-1, client=None, trimmed30=False | |
): | |
img_list=[] | |
if "s3://" in video_path: | |
for path in client.list(video_path): | |
if path.startswith('img'): | |
img_list.append(path) | |
else: | |
for path in os.listdir(video_path): | |
if path.startswith('img'): | |
img_list.append(path) | |
vlen = len(img_list) | |
frame_indices = get_frame_indices( | |
num_frames, vlen, sample=sample, fix_start=fix_start, | |
max_num_frames=max_num_frames | |
) | |
imgs = [] | |
for idx in frame_indices: | |
frame_fname = os.path.join(video_path, img_list[idx]) | |
if "s3://" in video_path: | |
img_bytes = client.get(frame_fname) | |
else: | |
with open(frame_fname, 'rb') as f: | |
img_bytes = f.read() | |
img_np = np.frombuffer(img_bytes, np.uint8) | |
img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) | |
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) | |
imgs.append(img) | |
frames = torch.tensor(np.array(imgs), dtype=torch.uint8).permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 | |
return frames, frame_indices, None | |
VIDEO_READER_FUNCS = { | |
'av': read_frames_av, | |
'decord': read_frames_decord, | |
'gif': read_frames_gif, | |
'img': read_frames_img, | |
} | |