annayding
changed type of boxes
f77e3aa
import torch
import numpy as np
import supervision as sv
import cv2
import os
from glob import glob
from tqdm import tqdm
def plot_predictions(
    image: str,
    labels: list[str],
    scores: list[float],
    boxes: list[float],
) -> np.ndarray:
    """Draw labelled bounding boxes onto an image read from disk.

    Args:
        image: Path to the image file to annotate.
        labels: One text label per box.
        scores: One score per box, rendered with two decimals after its label.
        boxes: Box coordinates in ``xyxy`` order. Accepts anything
            ``np.asarray`` understands: an ``(N, 4)`` array or nested list,
            or a flat sequence whose length is a multiple of 4.

    Returns:
        The annotated image in OpenCV (BGR) channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image``.
    """
    # cv2.imread already yields BGR, which is the order the result is
    # returned in, so no colour-space round trip is needed.
    annotated_frame = cv2.imread(image)
    if annotated_frame is None:
        raise FileNotFoundError(f"Could not read image: {image}")

    # sv.Detections validates that xyxy is a 2-D (N, 4) ndarray; reshaping
    # also keeps the zero-box case valid instead of a 1-D empty array.
    xyxy = np.asarray(boxes, dtype=float).reshape(-1, 4)
    detections = sv.Detections(xyxy=xyxy)

    captions = [f"{phrase} {logit:.2f}" for phrase, logit in zip(labels, scores)]

    bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
    annotated_frame = bbox_annotator.annotate(
        scene=annotated_frame, detections=detections
    )
    annotated_frame = label_annotator.annotate(
        scene=annotated_frame, detections=detections, labels=captions
    )
    return annotated_frame
def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
    """Extract the frames of a video into numbered PNG files using ffmpeg.

    Frames are written to ``save_path/frame%08d.png``. Both dimensions are
    scaled by ``scale_factor``, and the output is resampled to the source
    frame rate (rounded to an integer).

    Args:
        input_path: Path to the input video file (e.g. an ``.mp4``).
        save_path: Directory to write the frames into; created if missing.
        scale_factor: Multiplier applied to both the width and the height.

    Returns:
        The integer frame rate used for extraction, so the frames can be
        re-assembled at the same rate.

    Raises:
        ValueError: If OpenCV cannot determine a positive frame rate.
        FileNotFoundError: If the ``ffmpeg`` executable is not on PATH.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    import subprocess

    capture = cv2.VideoCapture(input_path)
    try:
        # Round instead of truncating so e.g. 29.97 fps becomes 30, not 29.
        fps = round(capture.get(cv2.CAP_PROP_FPS))
    finally:
        capture.release()
    if fps <= 0:
        raise ValueError(f"Could not determine frame rate of {input_path!r}")

    os.makedirs(save_path, exist_ok=True)
    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters are handled safely; check=True surfaces ffmpeg failures.
    subprocess.run(
        [
            "ffmpeg",
            "-i", input_path,
            "-vf", f"scale=iw*{scale_factor}:ih*{scale_factor},fps={fps}",
            os.path.join(save_path, "frame%08d.png"),
        ],
        check=True,
    )
    return fps
def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
    """Assemble the ``frame*.png`` images in a directory into an mp4 video.

    Frames are ordered by sorted filename, which matches zero-padded names
    such as ``frame%08d.png``. The video size is taken from the first frame.

    Args:
        frames_dir: Directory containing the ``frame*.png`` images.
        output_path: Path of the video file to write (``mp4v`` codec).
        fps: Frame rate of the output video.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: If ``frames_dir`` contains no ``frame*.png`` files.
        ValueError: If a frame cannot be read or its size differs from the
            first frame (OpenCV would otherwise silently drop it).
        OSError: If the video writer cannot be opened.
    """
    from concurrent.futures import ThreadPoolExecutor

    # Bounded number of frames decoded ahead of the writer, so memory use
    # does not grow with the length of the video.
    batch_size = 64

    frame_list = sorted(glob(os.path.join(frames_dir, 'frame*.png')))
    if not frame_list:
        raise FileNotFoundError(f"No frame*.png files found in {frames_dir!r}")

    first_frame = cv2.imread(frame_list[0])
    if first_frame is None:
        raise ValueError(f"Could not read frame {frame_list[0]!r}")
    height, width = first_frame.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        raise OSError(f"Could not open video writer for {output_path!r}")

    try:
        with ThreadPoolExecutor() as executor, tqdm(
            total=len(frame_list), desc='Stitching frames'
        ) as pbar:
            for start in range(0, len(frame_list), batch_size):
                batch = frame_list[start:start + batch_size]
                # Decode the batch in parallel; map preserves input order.
                for path, frame in zip(batch, executor.map(cv2.imread, batch)):
                    if frame is None:
                        raise ValueError(f"Could not read frame {path!r}")
                    if frame.shape[:2] != (height, width):
                        raise ValueError(
                            f"Frame {path!r} has size {frame.shape[1]}x{frame.shape[0]}, "
                            f"expected {width}x{height}"
                        )
                    out.write(frame)
                    pbar.update(1)
    finally:
        # Always finalize the container, even if a frame fails mid-way.
        out.release()
    return output_path
def count_pos(phrases, text_target):
    """Count how many sublists of ``phrases`` contain ``text_target``.

    Args:
        phrases: An iterable of phrase sequences; ``None`` entries are skipped.
        text_target: The phrase to look for, compared with ``==``.

    Returns:
        The number of sublists with at least one phrase equal to
        ``text_target``. Each sublist contributes at most 1, however many
        matches it holds; an empty ``phrases`` gives 0.
    """
    return sum(
        1
        for sublist in phrases
        # `is None` rather than `== None`: on a NumPy-array sublist `==` is
        # elementwise and the resulting truth test would raise.
        if sublist is not None
        and any(phrase == text_target for phrase in sublist)
    )