|
import torch |
|
import numpy as np |
|
import supervision as sv |
|
import cv2 |
|
import os |
|
from glob import glob |
|
from tqdm import tqdm |
|
|
|
|
|
def plot_predictions(
    image: str,
    labels: list[str],
    scores: list[float],
    boxes: list[float],
) -> np.ndarray:
    """Draw labelled bounding boxes on an image loaded from disk.

    Args:
        image: Path to the image file; it is loaded with ``cv2.imread``.
        labels: One text label per box.
        scores: One score per box; rendered with two decimals after its label.
        boxes: Box coordinates in ``(x_min, y_min, x_max, y_max)`` order, four
            numbers per box. Accepts a nested list / ``np.ndarray`` of shape
            ``(N, 4)``, a single flat box of four numbers, or a
            ``torch.Tensor`` (detached and moved to CPU first).

    Returns:
        The annotated image as an ``np.ndarray`` in OpenCV's BGR channel
        order (suitable for ``cv2.imwrite``).

    Raises:
        FileNotFoundError: If ``image`` cannot be read by OpenCV.
    """
    # cv2.imread already yields BGR, which is what the returned frame uses, so
    # no colour-space conversion is needed.
    frame = cv2.imread(image)
    if frame is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image}")

    if isinstance(boxes, torch.Tensor):
        # supervision works on NumPy arrays; tensors may live on the GPU.
        boxes = boxes.detach().cpu().numpy()
    # sv.Detections requires an (N, 4) ndarray; reshape also turns a single
    # flat box into (1, 4) and an empty input into a valid (0, 4) array.
    xyxy = np.asarray(boxes, dtype=float).reshape(-1, 4)
    detections = sv.Detections(xyxy=xyxy)

    captions = [f"{phrase} {logit:.2f}" for phrase, logit in zip(labels, scores)]

    # INDEX colour lookup gives each box its own colour by position.
    box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)

    annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
    annotated_frame = label_annotator.annotate(
        scene=annotated_frame, detections=detections, labels=captions
    )
    return annotated_frame
|
|
|
def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
    """Extract the frames of a video into numbered PNG files using ffmpeg.

    The frame rate reported by OpenCV is truncated to an integer, and ffmpeg
    resamples the video to exactly that rate. Each frame is resized by
    ``scale_factor`` in both dimensions. Files are written as
    ``<save_path>/frame%08d.png`` (``frame00000001.png``, ...). The
    zero-padded names sort correctly as plain strings.

    Args:
        input_path: Path to the input video (read by both OpenCV and ffmpeg).
        save_path: Directory for the PNG frames; created if it does not exist.
        scale_factor: Positive multiplier applied to width and height.

    Returns:
        The integer frame rate the frames were extracted at. Use it as the
        ``fps`` when re-encoding the frames (e.g. with ``vid_stitcher``).

    Raises:
        ValueError: If ``scale_factor`` is not positive, or if the frame rate
            cannot be determined (e.g. OpenCV could not open the video).
        subprocess.CalledProcessError: If ffmpeg exits with an error.
        FileNotFoundError: If the ``ffmpeg`` executable is not on PATH.
    """
    import subprocess

    if scale_factor <= 0:
        raise ValueError(f"scale_factor must be positive, got {scale_factor}")

    capture = cv2.VideoCapture(input_path)
    try:
        fps = int(capture.get(cv2.CAP_PROP_FPS))
    finally:
        capture.release()
    if fps <= 0:
        # An unopenable file reports 0 fps, which ffmpeg's fps filter rejects.
        raise ValueError(f"Could not determine frame rate of {input_path!r}")

    os.makedirs(save_path, exist_ok=True)
    # Arguments are passed as a list (no shell), so paths containing spaces or
    # shell metacharacters are handled safely; check=True surfaces failures.
    subprocess.run(
        [
            "ffmpeg",
            "-i",
            input_path,
            "-vf",
            f"scale=iw*{scale_factor}:ih*{scale_factor},fps={fps}",
            os.path.join(save_path, "frame%08d.png"),
        ],
        check=True,
    )
    return fps
|
|
|
def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
    """Encode the ``frame*.png`` images in a directory into a video file.

    Frames are ordered by sorted filename, so names must be zero-padded (as
    ``frame%08d.png`` is). The output resolution is taken from the first
    frame. Frames are decoded with a thread pool in bounded batches, so memory
    use does not grow with the number of frames.

    Args:
        frames_dir: Directory containing the ``frame*.png`` files.
        output_path: Destination video path, encoded with the ``mp4v`` FourCC.
        fps: Frame rate written to the output video.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: If no frames are found or a frame cannot be read.
        ValueError: If a frame's size differs from the first frame's.
        OSError: If the video writer cannot be opened.
    """
    from concurrent.futures import ThreadPoolExecutor

    # Number of frames decoded concurrently before being written out.
    batch_size = 64

    frame_paths = sorted(glob(os.path.join(frames_dir, "frame*.png")))
    if not frame_paths:
        raise FileNotFoundError(f"No frame*.png files found in {frames_dir!r}")

    first_frame = cv2.imread(frame_paths[0])
    if first_frame is None:
        raise FileNotFoundError(f"Could not read frame {frame_paths[0]!r}")
    height, width = first_frame.shape[:2]
    del first_frame  # only needed for its size; it is re-read in order below

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not writer.isOpened():
        raise OSError(f"Could not open video writer for {output_path!r}")

    try:
        with ThreadPoolExecutor() as executor, tqdm(
            total=len(frame_paths), desc="Stitching frames"
        ) as pbar:
            for start in range(0, len(frame_paths), batch_size):
                batch = frame_paths[start:start + batch_size]
                # executor.map preserves input order, so frames stay sorted.
                for path, frame in zip(batch, executor.map(cv2.imread, batch)):
                    if frame is None:
                        raise FileNotFoundError(f"Could not read frame {path!r}")
                    # Check sizes here instead of relying on the writer to
                    # report a mismatch.
                    if frame.shape[:2] != (height, width):
                        raise ValueError(
                            f"Frame {path!r} has size {frame.shape[1]}x{frame.shape[0]}, "
                            f"expected {width}x{height}"
                        )
                    writer.write(frame)
                    pbar.update(1)
    finally:
        # Closes the writer so the container is finalized even on error.
        writer.release()

    return output_path
|
|
|
def count_pos(phrases, text_target):
    """Count how many phrase lists contain ``text_target`` at least once.

    Args:
        phrases: Iterable of phrase sequences; ``None`` entries are skipped.
        text_target: Phrase to look for, compared with ``==``.

    Returns:
        The number of non-``None`` sublists that have at least one entry equal
        to ``text_target``. Each sublist counts at most once.
    """
    # ``is None`` rather than ``== None``: on a NumPy array ``== None`` is
    # elementwise, and using that result in an ``if`` raises ValueError.
    return sum(
        1
        for sublist in phrases
        if sublist is not None
        and any(phrase == text_target for phrase in sublist)
    )