import os
import subprocess
from concurrent.futures import ThreadPoolExecutor
from glob import glob

import cv2
import numpy as np
import supervision as sv
import torch
from tqdm import tqdm


def plot_predictions(
    image: str,
    labels: list[str],
    scores: list[float],
    boxes: list[float],
) -> np.ndarray:
    """Draw labelled bounding boxes onto an image loaded from disk.

    Args:
        image: Path to the image file to annotate.
        labels: One text label per box.
        scores: One confidence score per box; rendered with two decimals
            after the matching label.
        boxes: Box coordinates in ``xyxy`` order. Accepts nested lists, a
            NumPy array or a ``torch.Tensor``; the values are converted to a
            float array and reshaped to ``(N, 4)``.

    Returns:
        The annotated image as an ``np.ndarray`` in OpenCV's BGR channel
        order (as produced by ``cv2.imread``).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image``.
    """
    # cv2.imread yields BGR, which is also the order this function returns,
    # so no RGB round-trip is needed before annotating.
    frame = cv2.imread(image)
    if frame is None:
        raise FileNotFoundError(f"Could not read image: {image}")

    # sv.Detections requires an (N, 4) ndarray; plain lists / tensors are not
    # accepted directly, so normalise here.
    if isinstance(boxes, torch.Tensor):
        boxes = boxes.detach().cpu().numpy()
    detections = sv.Detections(xyxy=np.asarray(boxes, dtype=float).reshape(-1, 4))

    captions = [f"{phrase} {logit:.2f}" for phrase, logit in zip(labels, scores)]

    # ColorLookup.INDEX assigns a colour per box index.
    box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
    frame = box_annotator.annotate(scene=frame, detections=detections)
    frame = label_annotator.annotate(
        scene=frame, detections=detections, labels=captions
    )
    return frame


def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
    """Extract every frame of a video into numbered PNG files using ffmpeg.

    Frames are written to ``save_path/frame%08d.png`` (``frame00000001.png``,
    ...), with width and height each multiplied by ``scale_factor``.

    Args:
        input_path: Path to the input video file.
        save_path: Directory to write the frames into; created if missing.
        scale_factor: Multiplier applied to both frame dimensions.

    Returns:
        The video's frames-per-second as reported by OpenCV, truncated to int.

    Raises:
        ValueError: If OpenCV cannot determine a positive frame rate
            (e.g. the file cannot be opened).
        FileNotFoundError: If the ``ffmpeg`` executable is not on PATH.
        subprocess.CalledProcessError: If ffmpeg exits with an error.
    """
    capture = cv2.VideoCapture(input_path)
    try:
        fps = int(capture.get(cv2.CAP_PROP_FPS))
    finally:
        capture.release()
    if fps <= 0:
        raise ValueError(f"Could not determine frame rate of video: {input_path}")

    os.makedirs(save_path, exist_ok=True)

    # Pass arguments as a list (no shell) so paths containing spaces or shell
    # metacharacters are handled safely; check=True surfaces ffmpeg failures
    # instead of silently returning with no frames written.
    subprocess.run(
        [
            "ffmpeg",
            "-i",
            input_path,
            "-vf",
            f"scale=iw*{scale_factor}:ih*{scale_factor},fps={fps}",
            os.path.join(save_path, "frame%08d.png"),
        ],
        check=True,
    )
    return fps


def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
    """Encode the ``frame*.png`` images in a directory into an MP4 video.

    Frames are taken in lexicographic filename order (matching the
    zero-padded names produced by ``mp4_to_png``). The output frame size is
    taken from the first frame.

    Args:
        frames_dir: Directory containing ``frame*.png`` files.
        output_path: Path of the video file to write.
        fps: Frame rate of the output video.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: If no ``frame*.png`` files exist in ``frames_dir``
            or any frame cannot be decoded.
    """
    frame_paths = sorted(glob(os.path.join(frames_dir, "frame*.png")))
    if not frame_paths:
        raise FileNotFoundError(f"No frame*.png files found in {frames_dir}")

    # Decode frames on a thread pool to speed up reading. map() preserves the
    # input order. NOTE: all decoded frames are held in memory at once.
    with ThreadPoolExecutor() as executor:
        frames = list(executor.map(cv2.imread, frame_paths))

    # Validate before opening the writer so a bad frame never leaves a
    # partially written video behind.
    for path, frame in zip(frame_paths, frames):
        if frame is None:
            raise FileNotFoundError(f"Could not read frame: {path}")

    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        for frame in tqdm(frames, total=len(frames), desc="Stitching frames"):
            writer.write(frame)
    finally:
        # release() flushes and finalises the container; without it the
        # output file may be truncated or unplayable.
        writer.release()
    return output_path


def count_pos(phrases, text_target):
    """Count how many sublists of ``phrases`` contain ``text_target``.

    Args:
        phrases: An iterable of phrase lists; ``None`` entries are skipped.
        text_target: The phrase to look for (compared with ``==``).

    Returns:
        The number of non-``None`` sublists with at least one element equal
        to ``text_target``. Each sublist contributes at most 1.
    """
    return sum(
        1
        for sublist in phrases
        if sublist is not None
        and any(phrase == text_target for phrase in sublist)
    )