# Source-page residue (not code): annayding — "first commit" (bfa3aba) · raw · history · blame · 2.89 kB
import torch
import numpy as np
import supervision as sv
import cv2
import os
from glob import glob
from tqdm import tqdm
def plot_predictions(
    image: str,
    labels: list[str],
    scores: torch.Tensor,
    boxes: torch.Tensor,
) -> np.ndarray:
    """Draw labelled bounding boxes onto an image read from disk.

    Args:
        image: Path to the image file to annotate.
        labels: One text label per box; paired with ``scores`` via ``zip``.
        scores: Per-box scores, each rendered with two decimals next to its label.
        boxes: Box coordinates in ``xyxy`` order, as consumed by ``sv.Detections``.

    Returns:
        The annotated image in OpenCV's BGR channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image``.
    """
    # cv2.imread already returns BGR, which is the order the annotators draw on
    # and the order this function returns, so no colour conversion is needed.
    frame = cv2.imread(image)
    if frame is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {image}")

    detections = sv.Detections(xyxy=boxes.cpu().numpy())
    captions = [f"{label} {score:.2f}" for label, score in zip(labels, scores)]

    # Colour each box and its label by detection index.
    color_lookup = sv.ColorLookup.INDEX
    box_annotator = sv.BoxAnnotator(color_lookup=color_lookup)
    label_annotator = sv.LabelAnnotator(color_lookup=color_lookup)

    frame = box_annotator.annotate(scene=frame, detections=detections)
    frame = label_annotator.annotate(scene=frame, detections=detections, labels=captions)
    return frame
def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
    """Extract every frame of a video into numbered PNG files using ffmpeg.

    Args:
        input_path: Path to the input video (e.g. an .mp4 file).
        save_path: Directory that receives ``frame%08d.png`` files; created if missing.
        scale_factor: Multiplier applied to both the width and height of each frame.

    Returns:
        The video's frame rate as reported by OpenCV, truncated to an int.
        The same value is passed to ffmpeg's ``fps`` filter, so it can be
        reused when stitching the frames back into a video.

    Raises:
        ValueError: If OpenCV reports a non-positive frame rate (e.g. the
            video could not be opened).
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    import subprocess

    # Query the frame rate, making sure the capture handle is released.
    capture = cv2.VideoCapture(input_path)
    try:
        fps = int(capture.get(cv2.CAP_PROP_FPS))
    finally:
        capture.release()
    if fps <= 0:
        raise ValueError(f"Could not determine frame rate of {input_path!r}")

    # ffmpeg does not create the output directory itself.
    os.makedirs(save_path, exist_ok=True)

    video_filter = f"scale=iw*{scale_factor}:ih*{scale_factor},fps={fps}"
    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters are handled safely; check=True surfaces ffmpeg failures.
    subprocess.run(
        [
            "ffmpeg",
            "-i", input_path,
            "-vf", video_filter,
            os.path.join(save_path, "frame%08d.png"),
        ],
        check=True,
    )
    return fps
def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
    """Encode the ``frame*.png`` images in a directory into an MP4 video.

    Frames are written in sorted filename order, which is chronological for
    zero-padded names such as ``frame00000001.png``. The output resolution is
    taken from the first frame.

    Args:
        frames_dir: Directory containing the ``frame*.png`` files.
        output_path: Destination video file (encoded with the ``mp4v`` fourcc).
        fps: Frame rate of the output video.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: If ``frames_dir`` contains no ``frame*.png`` files.
        RuntimeError: If OpenCV cannot open a video writer for ``output_path``.
    """
    from concurrent.futures import ThreadPoolExecutor

    frame_list = sorted(glob(os.path.join(frames_dir, 'frame*.png')))
    if not frame_list:
        raise FileNotFoundError(f"No frame*.png files found in {frames_dir!r}")

    # The writer needs the frame size up front; use the first frame's.
    height, width = cv2.imread(frame_list[0]).shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        raise RuntimeError(f"Could not open video writer for {output_path!r}")

    # Decode frames in parallel, but in bounded batches so that only one
    # batch is held in memory at a time instead of the whole video.
    batch_size = 64
    try:
        with ThreadPoolExecutor() as executor, \
                tqdm(total=len(frame_list), desc='Stitching frames') as pbar:
            for start in range(0, len(frame_list), batch_size):
                batch = frame_list[start:start + batch_size]
                for frame in executor.map(cv2.imread, batch):
                    out.write(frame)
                    pbar.update(1)
    finally:
        # Close the writer explicitly, even on error, instead of relying on
        # garbage collection to do it.
        out.release()
    return output_path
def count_pos(phrases, text_target):
    """Count the sublists of ``phrases`` that contain ``text_target``.

    Args:
        phrases: Iterable of phrase sequences; ``None`` entries are skipped.
        text_target: Phrase to look for, compared with ``==``.

    Returns:
        The number of non-``None`` sublists with at least one entry equal to
        ``text_target``. Each sublist counts at most once.
    """
    # Use ``is not None`` rather than ``== None``: for array-like sublists,
    # ``==`` compares element-wise and the truth value of the result is
    # ambiguous (raises). ``any`` stops at the first match, like the
    # ``break`` in an explicit loop.
    return sum(
        1
        for sublist in phrases
        if sublist is not None
        and any(phrase == text_target for phrase in sublist)
    )