annayding committed on
Commit bfa3aba · 1 Parent(s): 376ad36

first commit

Files changed (4)
  1. app.py +116 -0
  2. owl_core.py +130 -0
  3. requirements.txt +10 -0
  4. utils.py +85 -0
app.py ADDED
@@ -0,0 +1,116 @@
+ import os
+ import sys
+ # set CUDA_HOME for any CUDA-dependent extensions
+ os.environ["CUDA_HOME"] = "/usr/local/cuda-12.3/"
+
+ import gradio as gr
+ from tqdm import tqdm
+ import cv2
+ import numpy as np
+ import pandas as pd
+ import torch
+
+ from typing import Tuple
+ from PIL import Image
+ from owl_core import owl_full_video
+
+
+ def run_owl(input_vid,
+             text_prompt,
+             confidence_threshold,
+             fps_processed,
+             scaling_factor
+             ):
+     # replace spaces in the uploaded file name so downstream shell calls (ffmpeg) don't break
+     new_input_vid = input_vid.replace(" ", "_")
+     os.rename(input_vid, new_input_vid)
+     csv_path, vid_path = owl_full_video(new_input_vid,
+                                         text_prompt,
+                                         confidence_threshold,
+                                         fps_processed=fps_processed,
+                                         scaling_factor=scaling_factor)
+
+     global CSV_PATH
+     CSV_PATH = csv_path
+     global VID_PATH
+     VID_PATH = vid_path
+     return vid_path
+
+
+ def vid_download():
+     """
+     Return the CSV and annotated-video paths from the last run so they can be downloaded.
+     """
+     print(CSV_PATH, VID_PATH)
+     return [CSV_PATH, VID_PATH]
+
+
+ with gr.Blocks() as demo:
+     gr.HTML(
+         """
+         <h1 align="center" style="font-size:xxx-large">🦍 Primate Detection</h1>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             input = gr.Video(label="Input Video", interactive=True)
+             text_prompt = gr.Textbox(label="What do you want to detect? (Multiple species should be separated by commas)")
+             with gr.Accordion("Advanced Options", open=False):
+                 conf_threshold = gr.Slider(
+                     label="Confidence Threshold",
+                     info="Adjust the threshold to change the sensitivity of the model; lower thresholds are more sensitive.",
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.3,
+                     step=0.05
+                 )
+                 fps_processed = gr.Slider(
+                     label="Frame Detection Rate",
+                     info="Adjust how often detection runs: a value of 120 runs detection every 120 frames, a value of 1 runs it on every frame. Note: the lower the number, the slower the processing.",
+                     minimum=1,
+                     maximum=120,
+                     value=30,
+                     step=1
+                 )
+                 scaling_factor = gr.Slider(
+                     label="Downsample Factor",
+                     info="Adjust the downsample factor. Note: the higher the number, the faster the processing but the lower the accuracy.",
+                     minimum=1,
+                     maximum=5,
+                     value=2,
+                     step=1
+                 )
+
+             # TODO: Make button visible only after a file has been uploaded
+             run_btn = gr.Button(value="Run Detection", visible=True)
+         with gr.Column():
+             vid = gr.Video(label="Output Video", height=350, interactive=False, visible=True)
+             # download_btn = gr.Button(value="Generate Download", visible=True)
+             download_file = gr.Files(label="CSV, Video Output", interactive=False)
+
+     run_btn.click(fn=run_owl, inputs=[input, text_prompt, conf_threshold, fps_processed, scaling_factor], outputs=[vid])
+     vid.change(fn=vid_download, outputs=download_file)
+
+     # gr.Examples(
+     #     [["baboon_15s.mp4", "baboon", 0.25, 0.25, 1, 1]],
+     #     inputs = [input, text_prompt, conf_threshold, fps_processed, scaling_factor],
+     #     outputs = [vid],
+     #     fn=run_sam_dino,
+     #     cache_examples=True,
+     #     label='Example'
+     # )
+
+     gr.DuplicateButton()
+
+     gr.Markdown(
+         """
+         ## Frequently Asked Questions
+
+         ##### How can I run the interface on my own computer?
+         By clicking on the three dots in the top-right corner of the interface, you can clone the repository or run it with a Docker image on your local machine. \
+         For local setup instructions, please check the README file.
+         ##### The video is very slow to process, how can I speed it up?
+         You can speed up processing by raising the Frame Detection Rate in the advanced options (the higher the value, the faster the processing) \
+         or by increasing the Downsample Factor. You can also duplicate the Space with the Duplicate button and choose a faster GPU.
+         """
+     )
+
+ demo.launch(share=False)
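For readers new to Gradio event wiring: the Run button's click handler returns the annotated video into `vid`, and the `vid.change` event then fills the download list from the module-level CSV_PATH and VID_PATH globals. Below is a stripped-down sketch of that same pattern with toy callbacks and placeholder file names (none of these names are part of this app):

import gradio as gr

def process(_video_path):
    # stand-in for run_owl: returns the path of an annotated video
    return "annotated.mp4"

def collect_downloads():
    # stand-in for vid_download: returns the files to expose for download
    return ["results.csv", "annotated.mp4"]

with gr.Blocks() as toy:
    inp = gr.Video(label="Input")
    out = gr.Video(label="Output")
    files = gr.Files(label="Downloads")
    run = gr.Button("Run")
    # the button click updates the output video; its change event then populates the downloads
    run.click(fn=process, inputs=[inp], outputs=[out])
    out.change(fn=collect_downloads, outputs=files)
# toy.launch()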
owl_core.py ADDED
@@ -0,0 +1,130 @@
+ import torch
+ from tqdm import tqdm
+ import cv2
+ import os
+ import numpy as np
+ import pandas as pd
+
+ from datetime import datetime
+ from typing import Tuple
+ from PIL import Image
+ from utils import plot_predictions, mp4_to_png, vid_stitcher
+ from transformers import Owlv2Processor, Owlv2ForObjectDetection
+
+
+ def preprocess_text(text_prompt: str, num_prompts: int = 1):
+     """
+     Takes a string of comma-separated text prompts and returns one list of prompts per image.
+     i.e. text_prompt = "a, b, c", num_prompts = 2 -> [["a", "b", "c"], ["a", "b", "c"]]
+     """
+     text_prompt = [s.strip() for s in text_prompt.split(",")]
+     text_queries = [text_prompt] * num_prompts
+     return text_queries
+
+
+ def owl_batch_prediction(
+     images: list,
+     text_queries: list[list[str]],  # every image is queried with the same text prompts
+     threshold: float,
+     processor,
+     model,
+     device: str = 'cuda'
+ ):
+     inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
+     target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)
+     # Convert outputs (bounding boxes and class logits) to COCO API, rescale to original image size and filter by threshold
+     results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=threshold)
+
+     return results
+
+
+ def owl_full_video(
+     vid_path: str,
+     text_prompt: str,
+     threshold: float,
+     fps_processed: int = 1,
+     scaling_factor: float = 0.5,
+     processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble"),
+     model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble"),
+     device: str = 'cuda',
+     batch_size: int = 6,
+ ):
+     """ Same as owl_video, but processes the entire video regardless of detection results.
+     Saves per-frame results to a dataframe and CSV, and returns the CSV and stitched-video paths.
+     """
+     # make sure the model is on the requested device before inputs are moved there
+     model = model.to(device)
+
+     # create new dirs and paths for results
+     filename = os.path.splitext(os.path.basename(vid_path))[0]
+     results_dir = f'../results/{filename}_{datetime.now().strftime("%H%M%S")}'
+     frames_dir = os.path.join(results_dir, "frames")
+
+     # if the frames directory does not exist, create it and extract the frames from the video
+     fps = 30  # fallback frame rate, only used if the frames were already extracted
+     if not os.path.exists(results_dir):
+         os.makedirs(results_dir, exist_ok=True)
+         os.makedirs(frames_dir, exist_ok=True)
+         # process the video into a directory of per-frame pngs
+         fps = mp4_to_png(vid_path, frames_dir, scaling_factor)
+
+     # get all frame paths in frame order
+     frame_filenames = sorted(os.listdir(frames_dir))
+
+     frame_paths = []  # frames to run detection on, sampled every fps_processed frames
+     for i, frame in enumerate(frame_filenames):
+         if i % fps_processed == 0:
+             frame_paths.append(os.path.join(frames_dir, frame))
+
+     # set up df for results
+     df = pd.DataFrame(columns=["frame", "boxes", "scores", "labels"])
+
+     # run owl in batches
+     for i in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
+         batch_paths = frame_paths[i:i+batch_size]  # paths for this batch
+         images = [Image.open(image_path) for image_path in batch_paths]
+
+         # run owl on this batch of frames
+         text_queries = preprocess_text(text_prompt, len(batch_paths))
+         results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)
+
+         # get the label ids per image (None if the image has no detections)
+         label_ids = []
+         for entry in results:
+             if entry['labels'].numel() > 0:
+                 label_ids.append(entry['labels'].tolist())
+             else:
+                 label_ids.append(None)
+
+         text = text_queries[0]  # all images in the batch share the same queries
+         labels = []
+         # convert label ids to phrases; if an image has no detections, append None
+         for ids in label_ids:
+             if ids is not None:
+                 labels.append([text[label_id] for label_id in ids])
+             else:
+                 labels.append(None)
+
+         for j, image in enumerate(batch_paths):
+             boxes = results[j]['boxes'].cpu().numpy()
+             scores = results[j]['scores'].cpu().numpy()
+             row = pd.DataFrame({"frame": [image], "boxes": [boxes], "scores": [scores], "labels": [labels[j]]})
+             df = pd.concat([df, row], ignore_index=True)
+
+             # if there are detections, overwrite the original frame with the annotated frame
+             if labels[j] is not None:
+                 annotated_frame = plot_predictions(image, labels[j], scores, boxes)
+                 cv2.imwrite(image, annotated_frame)
+
+     # save the df to a csv
+     csv_path = f"{results_dir}/{filename}_{threshold}.csv"
+     df.to_csv(csv_path, index=False)
+
+     # stitch the frames back into a video at the original frame rate
+     save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"), fps=fps)
+
+     return csv_path, save_path
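owl_full_video can also be driven outside the Gradio UI, for example for batch jobs. A minimal sketch under stated assumptions: a CUDA GPU is available (device defaults to 'cuda'), "videos/troop.mp4" is a hypothetical clip, and results land in a timestamped folder under ../results relative to the working directory:

from owl_core import owl_full_video

csv_path, vid_path = owl_full_video(
    "videos/troop.mp4",          # hypothetical input clip
    "baboon, vervet monkey",     # comma-separated prompts, split by preprocess_text
    threshold=0.3,               # confidence threshold used in post-processing
    fps_processed=30,            # run detection on every 30th extracted frame
    scaling_factor=0.5,          # downscale frames before detection
)
print(csv_path, vid_path)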
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ gradio==5.8.0
+ numpy==2.2.0
+ opencv_python==4.7.0.68
+ opencv_python_headless==4.8.1.78
+ pandas==1.4.2
+ Pillow==11.0.0
+ supervision==0.25.0
+ torch==2.0.1
+ tqdm==4.65.0
+ transformers==4.36.2
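One hedged note on these pins, not stated anywhere in the repo: torch==2.0.1 generally predates NumPy 2.x support, so the numpy==2.2.0 pin may need to be relaxed to a 1.x release in some environments. A small standard-library check of what actually resolved after pip install -r requirements.txt:

from importlib.metadata import version, PackageNotFoundError

# distribution names follow the pins above
for pkg in ["gradio", "numpy", "opencv-python-headless", "pandas",
            "Pillow", "supervision", "torch", "tqdm", "transformers"]:
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg} is not installed")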
utils.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+ import numpy as np
+ import supervision as sv
+ import cv2
+ import os
+ from glob import glob
+ from tqdm import tqdm
+ from concurrent.futures import ThreadPoolExecutor
+
+
+ def plot_predictions(
+     image: str,
+     labels: list[str],
+     scores: np.ndarray,
+     boxes: np.ndarray,
+ ) -> np.ndarray:
+     """
+     Draws labelled bounding boxes on the image at path `image` and returns the annotated frame (BGR).
+     """
+     image_source = cv2.imread(image)
+
+     detections = sv.Detections(xyxy=np.asarray(boxes))
+
+     labels = [
+         f"{phrase} {score:.2f}"
+         for phrase, score
+         in zip(labels, scores)
+     ]
+
+     bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
+     label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
+     annotated_frame = bbox_annotator.annotate(scene=image_source, detections=detections)
+     annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+
+     return annotated_frame
+
+
+ def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
+     """ Converts an mp4 into one png per frame of the video.
+     Args: input_path is the path to the mp4 file, save_path is the directory in which to save the frames.
+     Returns: fps, the number of frames per second of the input video.
+     """
+     # get frames per second
+     fps = int(cv2.VideoCapture(input_path).get(cv2.CAP_PROP_FPS))
+     # run ffmpeg to convert the mp4 into scaled pngs
+     os.system(f"ffmpeg -i {input_path} -vf 'scale=iw*{scale_factor}:ih*{scale_factor}, fps={fps}' {save_path}/frame%08d.png")
+     return fps
+
+
+ def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
+     """
+     Reads the frame pngs in frames_dir and writes them to a video file at output_path.
+     """
+     # Get the ordered list of frames
+     frame_list = sorted(glob(os.path.join(frames_dir, 'frame*.png')))
+
+     # Prepare the VideoWriter using the first frame's dimensions
+     frame = cv2.imread(frame_list[0])
+     height, width, _ = frame.shape
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+     # Use multithreading to read frames faster
+     with ThreadPoolExecutor() as executor:
+         frames = list(executor.map(cv2.imread, frame_list))
+
+     # Write frames to the video
+     with tqdm(total=len(frame_list), desc='Stitching frames') as pbar:
+         for frame in frames:
+             out.write(frame)
+             pbar.update(1)
+     out.release()
+
+     return output_path
+
+
+ def count_pos(phrases, text_target):
+     """
+     Takes a list of lists of phrases and counts how many sublists contain at least one entry equal to the target phrase.
+     """
+     num_pos = 0
+     for sublist in phrases:
+         if sublist is None:
+             continue
+         for phrase in sublist:
+             if phrase == text_target:
+                 num_pos += 1
+                 break
+     return num_pos
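count_pos is not called anywhere in this commit, but it matches the per-frame label lists that owl_full_video builds (a list of detected phrases per processed frame, or None when nothing was detected). A small usage sketch with made-up data:

from utils import count_pos

per_frame_labels = [["baboon", "baboon"], None, ["vervet monkey"], ["baboon"]]
print(count_pos(per_frame_labels, "baboon"))  # 2 frames contain at least one "baboon"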