File size: 2,886 Bytes
bfa3aba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import torch
import numpy as np
import supervision as sv
import cv2
import os
from glob import glob
from tqdm import tqdm
def plot_predictions(
    image: str,
    labels: list[str],
    scores: torch.Tensor,
    boxes: torch.Tensor,
) -> np.ndarray:
    """Draw labelled bounding boxes on an image loaded from disk.

    Args:
        image: Path to the image file to annotate.
        labels: One text label per box.
        scores: One score per box; each is rendered with two decimals after
            its label.
        boxes: Box coordinates, passed to ``sv.Detections`` as ``xyxy``.

    Returns:
        The annotated image in OpenCV's BGR channel order (as loaded by
        ``cv2.imread``).

    Raises:
        FileNotFoundError: If ``image`` cannot be read by OpenCV.
    """
    # cv2.imread signals failure by returning None instead of raising.
    annotated_frame = cv2.imread(image)
    if annotated_frame is None:
        raise FileNotFoundError(f"Could not read image: {image}")

    # detach() lets tensors that still track gradients be converted to NumPy.
    detections = sv.Detections(xyxy=boxes.detach().cpu().numpy())
    captions = [
        f"{phrase} {logit:.2f}"
        for phrase, logit in zip(labels, scores)
    ]

    bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)

    # The frame stays in BGR throughout; no colour-space round trip is needed.
    annotated_frame = bbox_annotator.annotate(
        scene=annotated_frame, detections=detections
    )
    annotated_frame = label_annotator.annotate(
        scene=annotated_frame, detections=detections, labels=captions
    )
    return annotated_frame
def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
    """Convert an mp4 video into one PNG per frame using ffmpeg.

    Frames are written to ``save_path`` as ``frame%08d.png``.

    Args:
        input_path: Path to the mp4 file.
        save_path: Directory that receives the frames; created if missing.
        scale_factor: Multiplier applied to both the width and the height of
            every frame.

    Returns:
        The frame rate reported by OpenCV, truncated to an integer.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    import subprocess

    # Read the frame rate, then release the capture handle right away.
    capture = cv2.VideoCapture(input_path)
    try:
        fps = int(capture.get(cv2.CAP_PROP_FPS))
    finally:
        capture.release()

    os.makedirs(save_path, exist_ok=True)

    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact and avoids command injection.
    subprocess.run(
        [
            "ffmpeg",
            "-i", input_path,
            "-vf", f"scale=iw*{scale_factor}:ih*{scale_factor}, fps={fps}",
            os.path.join(save_path, "frame%08d.png"),
        ],
        check=True,
    )
    return fps
def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
    """Stitch the ``frame*.png`` images in a directory into an mp4 video.

    Frames are written in sorted filename order, and the video dimensions
    are taken from the first frame.

    Args:
        frames_dir: Directory containing the ``frame*.png`` files.
        output_path: Path of the video file to write.
        fps: Frame rate of the output video.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: If no frames match, or the first frame cannot be
            read.
    """
    from concurrent.futures import ThreadPoolExecutor

    # Get the list of frames
    frame_list = sorted(glob(os.path.join(frames_dir, 'frame*.png')))
    if not frame_list:
        raise FileNotFoundError(f"No frame*.png files found in {frames_dir}")

    # Prepare the VideoWriter
    first_frame = cv2.imread(frame_list[0])
    if first_frame is None:
        raise FileNotFoundError(f"Could not read frame: {frame_list[0]}")
    height, width = first_frame.shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    # Decode in bounded batches so only a limited number of frames is held
    # in memory at once, while reads still run in parallel threads.
    batch_size = 64
    try:
        with ThreadPoolExecutor() as executor:
            with tqdm(total=len(frame_list), desc='Stitching frames') as pbar:
                for start in range(0, len(frame_list), batch_size):
                    batch = frame_list[start:start + batch_size]
                    for frame in executor.map(cv2.imread, batch):
                        out.write(frame)
                        pbar.update(1)
    finally:
        # Release explicitly so the container is finalised even on error.
        out.release()
    return output_path
def count_pos(phrases, text_target):
    """Count the sublists that contain the target phrase at least once.

    Args:
        phrases: Iterable of phrase collections; ``None`` entries are skipped.
        text_target: Phrase to look for, compared with ``==``.

    Returns:
        The number of non-``None`` sublists with at least one entry equal to
        ``text_target``. Each sublist counts at most once.
    """
    # "is not None" is an identity check: it never invokes a custom __eq__,
    # which would be ambiguous for array-like sublists.
    return sum(
        1
        for sublist in phrases
        if sublist is not None
        and any(phrase == text_target for phrase in sublist)
    )