File size: 2,886 Bytes
bfa3aba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import numpy as np
import supervision as sv
import cv2
import os
from glob import glob
from tqdm import tqdm


def plot_predictions(
        image: str,
        labels: list[str],
        scores: torch.Tensor,
        boxes: torch.Tensor,
    ) -> np.ndarray:
    """Draw labelled bounding boxes on an image loaded from disk.

    Args:
        image: Path of the image file to annotate.
        labels: One text label per box.
        scores: One score per box; each is formatted to two decimals and
            appended to the matching label.
        boxes: Box coordinates, handed to ``sv.Detections`` as ``xyxy``.
            Presumably shape (N, 4) in pixel units -- confirm against caller.

    Returns:
        The annotated image, in the BGR channel order ``cv2.imread`` produces.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image``.
    """
    # cv2.imread signals failure by returning None instead of raising.
    annotated_frame = cv2.imread(image)
    if annotated_frame is None:
        raise FileNotFoundError(f"Could not read image: {image!r}")

    # detach() so tensors that still require grad can be converted to numpy.
    detections = sv.Detections(xyxy=boxes.detach().cpu().numpy())

    captions = [
        f"{phrase} {logit:.2f}"
        for phrase, logit in zip(labels, scores)
    ]

    bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)

    # The frame stays in BGR throughout: converting to RGB and straight back
    # only swapped the channels twice.
    annotated_frame = bbox_annotator.annotate(scene=annotated_frame, detections=detections)
    annotated_frame = label_annotator.annotate(
        scene=annotated_frame, detections=detections, labels=captions
    )

    return annotated_frame

def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
    """Split a video into numbered PNG frames by invoking ffmpeg.

    Args:
        input_path: Path to the source video file.
        save_path: Directory that receives the ``frame%08d.png`` files.
        scale_factor: Multiplier applied to both the width and the height
            of every extracted frame.

    Returns:
        The frame rate reported by OpenCV for ``input_path``, truncated to
        an integer. The same value is used as ffmpeg's output frame rate.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
    """
    import subprocess

    # Release the capture handle explicitly rather than leaving it to GC.
    capture = cv2.VideoCapture(input_path)
    try:
        fps = int(capture.get(cv2.CAP_PROP_FPS))
    finally:
        capture.release()

    # Pass an argument list (no shell) so paths containing spaces, quotes or
    # shell metacharacters reach ffmpeg verbatim.
    command = [
        "ffmpeg",
        "-i", input_path,
        "-vf", f"scale=iw*{scale_factor}:ih*{scale_factor}, fps={fps}",
        os.path.join(save_path, "frame%08d.png"),
    ]
    # check=False: a non-zero ffmpeg exit status is not raised, as before.
    subprocess.run(command, check=False)
    return fps

def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
    """Assemble the ``frame*.png`` images of a directory into a video file.

    Frames are written in sorted filename order, and the output resolution
    is taken from the first frame.

    Args:
        frames_dir: Directory containing the ``frame*.png`` images.
        output_path: Path of the video file to write (``mp4v`` codec).
        fps: Frame rate of the resulting video.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: If ``frames_dir`` contains no ``frame*.png`` file.
    """
    from concurrent.futures import ThreadPoolExecutor

    frame_list = sorted(glob(os.path.join(frames_dir, 'frame*.png')))
    if not frame_list:
        raise FileNotFoundError(f"No frame*.png files found in {frames_dir!r}")

    # The first frame fixes the output resolution.
    first_frame = cv2.imread(frame_list[0])
    height, width = first_frame.shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    try:
        # Read frames on worker threads and write each one as it is yielded,
        # in order, instead of collecting every frame in a list first.
        with ThreadPoolExecutor() as executor:
            frames = executor.map(cv2.imread, frame_list)
            for frame in tqdm(frames, total=len(frame_list), desc='Stitching frames'):
                out.write(frame)
    finally:
        # Release the writer so the output file is finalized even on error.
        out.release()

    return output_path

def count_pos(phrases, text_target) -> int:
    """Count the sub-lists that contain the target phrase at least once.

    Args:
        phrases: Iterable of phrase collections; ``None`` entries are skipped.
        text_target: Phrase to look for.

    Returns:
        Number of entries in ``phrases`` with at least one element equal to
        ``text_target``. An entry is counted once, however many matches it has.
    """
    num_pos = 0
    for sublist in phrases:
        # Identity check: `== None` is element-wise on array-like entries and
        # cannot be used as a plain truth value.
        if sublist is None:
            continue
        if any(phrase == text_target for phrase in sublist):
            num_pos += 1
    return num_pos