Align3R / third_party /sam2 /training /scripts /sav_frame_extraction_submitit.py
cyun9286's picture
Add application file
f53b39e
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
import argparse
import os
from pathlib import Path
import cv2
import numpy as np
import submitit
import tqdm
def get_args_parser():
parser = argparse.ArgumentParser(
description="[SA-V Preprocessing] Extracting JPEG frames",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# ------------
# DATA
# ------------
data_parser = parser.add_argument_group(
title="SA-V dataset data root",
description="What data to load and how to process it.",
)
data_parser.add_argument(
"--sav-vid-dir",
type=str,
required=True,
help=("Where to find the SAV videos"),
)
data_parser.add_argument(
"--sav-frame-sample-rate",
type=int,
default=4,
help="Rate at which to sub-sample frames",
)
# ------------
# LAUNCH
# ------------
launch_parser = parser.add_argument_group(
title="Cluster launch settings",
description="Number of jobs and retry settings.",
)
launch_parser.add_argument(
"--n-jobs",
type=int,
required=True,
help="Shard the run over this many jobs.",
)
launch_parser.add_argument(
"--timeout", type=int, required=True, help="SLURM timeout parameter in minutes."
)
launch_parser.add_argument(
"--partition", type=str, required=True, help="Partition to launch on."
)
launch_parser.add_argument(
"--account", type=str, required=True, help="Partition to launch on."
)
launch_parser.add_argument("--qos", type=str, required=True, help="QOS.")
# ------------
# OUTPUT
# ------------
output_parser = parser.add_argument_group(
title="Setting for results output", description="Where and how to save results."
)
output_parser.add_argument(
"--output-dir",
type=str,
required=True,
help=("Where to dump the extracted jpeg frames"),
)
output_parser.add_argument(
"--slurm-output-root-dir",
type=str,
required=True,
help=("Where to save slurm outputs"),
)
return parser
def decode_video(video_path: str):
assert os.path.exists(video_path)
video = cv2.VideoCapture(video_path)
video_frames = []
while video.isOpened():
ret, frame = video.read()
if ret:
video_frames.append(frame)
else:
break
return video_frames
def extract_frames(video_path, sample_rate):
frames = decode_video(video_path)
return frames[::sample_rate]
def submitit_launch(video_paths, sample_rate, save_root):
for path in tqdm.tqdm(video_paths):
frames = extract_frames(path, sample_rate)
output_folder = os.path.join(save_root, Path(path).stem)
if not os.path.exists(output_folder):
os.makedirs(output_folder)
for fid, frame in enumerate(frames):
frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg")
cv2.imwrite(frame_path, frame)
print(f"Saved output to {save_root}")
if __name__ == "__main__":
parser = get_args_parser()
args = parser.parse_args()
sav_vid_dir = args.sav_vid_dir
save_root = args.output_dir
sample_rate = args.sav_frame_sample_rate
# List all SA-V videos
mp4_files = sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")])
mp4_files = np.array(mp4_files)
chunked_mp4_files = [x.tolist() for x in np.array_split(mp4_files, args.n_jobs)]
print(f"Processing videos in: {sav_vid_dir}")
print(f"Processing {len(mp4_files)} files")
print(f"Beginning processing in {args.n_jobs} processes")
# Submitit params
jobs_dir = os.path.join(args.slurm_output_root_dir, "%j")
cpus_per_task = 4
executor = submitit.AutoExecutor(folder=jobs_dir)
executor.update_parameters(
timeout_min=args.timeout,
gpus_per_node=0,
tasks_per_node=1,
slurm_array_parallelism=args.n_jobs,
cpus_per_task=cpus_per_task,
slurm_partition=args.partition,
slurm_account=args.account,
slurm_qos=args.qos,
)
executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"])
# Launch
jobs = []
with executor.batch():
for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)):
job = executor.submit(
submitit_launch,
video_paths=mp4_chunk,
sample_rate=sample_rate,
save_root=save_root,
)
jobs.append(job)
for j in jobs:
print(f"Slurm JobID: {j.job_id}")
print(f"Saving outputs to {save_root}")
print(f"Slurm outputs at {args.slurm_output_root_dir}")