annayding committed
Commit bfa3aba · Parent(s): 376ad36
first commit

Files changed:
- app.py +116 -0
- owl_core.py +130 -0
- requirements.txt +10 -0
- utils.py +85 -0
app.py
ADDED
@@ -0,0 +1,116 @@
import os
import sys

# set CUDA_HOME
os.environ["CUDA_HOME"] = "/usr/local/cuda-12.3/"

import gradio as gr
from tqdm import tqdm
import cv2
import numpy as np
import pandas as pd
import torch

from typing import Tuple
from PIL import Image
from owl_core import owl_full_video


def run_owl(input_vid,
            text_prompt,
            confidence_threshold,
            fps_processed,
            scaling_factor
            ):
    # spaces in the filename break downstream processing, so rename the upload first
    new_input_vid = input_vid.replace(" ", "_")
    os.rename(input_vid, new_input_vid)
    csv_path, vid_path = owl_full_video(new_input_vid,
                                        text_prompt,
                                        confidence_threshold,
                                        fps_processed=fps_processed,
                                        scaling_factor=scaling_factor)

    # stash the result paths so vid_download can hand them to the download component
    global CSV_PATH
    CSV_PATH = csv_path
    global VID_PATH
    VID_PATH = vid_path
    return vid_path


def vid_download():
    """Return the CSV and annotated video paths for the download component."""
    print(CSV_PATH, VID_PATH)
    return [CSV_PATH, VID_PATH]


with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 align="center" style="font-size:xxx-large">🦍 Primate Detection</h1>
        """
    )

    with gr.Row():
        with gr.Column():
            input = gr.Video(label="Input Video", interactive=True)
            text_prompt = gr.Textbox(label="What do you want to detect? (Multiple species should be separated by commas)")
            with gr.Accordion("Advanced Options", open=False):
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    info="Adjust the threshold to change the sensitivity of the model; lower thresholds are more sensitive.",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.05
                )
                fps_processed = gr.Slider(
                    label="Frame Detection Rate",
                    info="Adjust the frame detection rate: a value of 120 runs detection every 120 frames, a value of 1 runs detection on every frame. Note: the lower the number, the slower the processing time.",
                    minimum=1,
                    maximum=120,
                    value=30,
                    step=1
                )
                scaling_factor = gr.Slider(
                    label="Downsample Factor",
                    info="Adjust the downsample factor. Note: the higher the number, the faster the processing time but the lower the accuracy.",
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1
                )

            # TODO: Make button visible only after a file has been uploaded
            run_btn = gr.Button(value="Run Detection", visible=True)
        with gr.Column():
            vid = gr.Video(label="Output Video", height=350, interactive=False, visible=True)
            # download_btn = gr.Button(value="Generate Download", visible=True)
            download_file = gr.Files(label="CSV, Video Output", interactive=False)

    run_btn.click(fn=run_owl, inputs=[input, text_prompt, conf_threshold, fps_processed, scaling_factor], outputs=[vid])
    vid.change(fn=vid_download, outputs=download_file)

    # gr.Examples(
    #     [["baboon_15s.mp4", "baboon", 0.25, 0.25, 1, 1]],
    #     inputs=[input, text_prompt, conf_threshold, fps_processed, scaling_factor],
    #     outputs=[vid],
    #     fn=run_sam_dino,
    #     cache_examples=True,
    #     label='Example'
    # )

    gr.DuplicateButton()

    gr.Markdown(
        """
        ## Frequently Asked Questions

        ##### How can I run the interface on my own computer?
        By clicking on the three dots in the top right corner of the interface, you can clone the repository or run it with a Docker image on your local machine. \
        For local machine setup instructions, please check the README file.
        ##### The video is very slow to process, how can I speed it up?
        You can speed up processing by adjusting the frame detection rate in the advanced options: the higher the number, the faster the processing time. \
        You can also duplicate the Space using the Duplicate button and choose a faster GPU, which will speed up processing.
        """
    )

demo.launch(share=False)
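Because CSV_PATH and VID_PATH in app.py are module-level globals, concurrent users of the Space could overwrite each other's results. The snippet below is a minimal sketch, not part of the commit, of an alternative wiring that returns the file paths straight from the click handler; run_owl_direct is a hypothetical name, and the components referenced are the ones defined in app.py above.

# Sketch only: an alternative to the CSV_PATH/VID_PATH globals in app.py.
def run_owl_direct(input_vid, text_prompt, confidence_threshold, fps_processed, scaling_factor):
    new_input_vid = input_vid.replace(" ", "_")
    os.rename(input_vid, new_input_vid)
    csv_path, vid_path = owl_full_video(new_input_vid,
                                        text_prompt,
                                        confidence_threshold,
                                        fps_processed=fps_processed,
                                        scaling_factor=scaling_factor)
    # first return value feeds the output gr.Video, second feeds the gr.Files component
    return vid_path, [csv_path, vid_path]

# Inside the gr.Blocks() context, this single event would replace the
# run_btn.click / vid.change pair used in app.py:
# run_btn.click(fn=run_owl_direct,
#               inputs=[input, text_prompt, conf_threshold, fps_processed, scaling_factor],
#               outputs=[vid, download_file])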
owl_core.py
ADDED
@@ -0,0 +1,130 @@
import torch
from tqdm import tqdm
import cv2
import os
import numpy as np
import pandas as pd

from datetime import datetime
from typing import Tuple
from PIL import Image
from utils import plot_predictions, mp4_to_png, vid_stitcher
from transformers import Owlv2Processor, Owlv2ForObjectDetection


def preprocess_text(text_prompt: str, num_prompts: int = 1):
    """
    Takes a string of text prompts and returns a list of lists of text prompts, one list per image.
    i.e. text_prompt = "a, b, c" -> [["a", "b", "c"], ["a", "b", "c"]]
    """
    text_prompt = [s.strip() for s in text_prompt.split(",")]
    text_queries = [text_prompt] * num_prompts
    # print("text_queries:", text_queries)
    return text_queries


def owl_batch_prediction(
        images: torch.Tensor,
        text_queries: list[list[str]],  # assuming that every image is queried with the same text prompt
        threshold: float,
        processor,
        model,
        device: str = 'cuda'
    ):
    inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
    target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)
    # Convert outputs (bounding boxes and class logits) to COCO API, resize to the original image size and filter by threshold
    results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=threshold)

    return results


def owl_full_video(
        vid_path: str,
        text_prompt: str,
        threshold: float,
        fps_processed: int = 1,
        scaling_factor: float = 0.5,
        processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble"),
        model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble"),
        device: str = 'cuda',
        batch_size: int = 6,
    ):
    """ Same as owl_video, but processes the entire video regardless of detection bool.
    Saves per-frame results to a DataFrame and a CSV.
    """
    # create new dirs and paths for results
    filename = os.path.splitext(os.path.basename(vid_path))[0]
    results_dir = f'../results/{filename}_{datetime.now().strftime("%H%M%S")}'
    frames_dir = os.path.join(results_dir, "frames")

    # if the frames directory does not exist, create it and get the frames from the video
    if not os.path.exists(results_dir):
        os.makedirs(results_dir, exist_ok=True)
        os.makedirs(frames_dir, exist_ok=True)
        # process video and create a directory of video frames
        fps = mp4_to_png(vid_path, frames_dir, scaling_factor)

    # get all frame paths
    frame_filenames = os.listdir(frames_dir)

    frame_paths = []  # list of frame paths to process based on fps_processed
    # keep every fps_processed-th frame (sorted so frames stay in temporal order)
    for i, frame in enumerate(sorted(frame_filenames)):
        if i % fps_processed == 0:
            frame_paths.append(os.path.join(frames_dir, frame))

    # set up df for results
    df = pd.DataFrame(columns=["frame", "boxes", "scores", "labels"])

    # whether a directory for positive detection frames has been created (unused in this version)
    dir_created = False

    # run owl in batches
    for i in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
        batch_paths = frame_paths[i:i+batch_size]  # paths for this batch
        images = [Image.open(image_path) for image_path in batch_paths]

        # run owl on this batch of frames
        text_queries = preprocess_text(text_prompt, len(batch_paths))
        results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)

        # get the label ids; None marks frames with no detections
        label_ids = []
        for entry in results:
            if entry['labels'].numel() > 0:
                label_ids.append(entry['labels'].tolist())
            else:
                label_ids.append(None)

        text = text_queries[0]  # assuming that all texts in the query are the same
        labels = []
        # convert label_ids to phrases; if there are no phrases, append None
        for idx in label_ids:
            if idx is not None:
                idx = [text[id] for id in idx]
                labels.append(idx)
            else:
                labels.append(None)

        for j, image in enumerate(batch_paths):
            boxes = results[j]['boxes'].cpu().numpy()
            scores = results[j]['scores'].cpu().numpy()
            row = pd.DataFrame({"frame": [image], "boxes": [boxes], "scores": [scores], "labels": [labels[j]]})
            df = pd.concat([df, row], ignore_index=True)

            # if there are detections, save the annotated frame over the original frame
            if labels[j] is not None:
                annotated_frame = plot_predictions(image, labels[j], scores, boxes)
                cv2.imwrite(image, annotated_frame)

    # save the df to a csv
    csv_path = f"{results_dir}/{filename}_{threshold}.csv"
    df.to_csv(csv_path, index=False)

    # stitch the frames back into a video at the source frame rate
    save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"), fps=fps)

    return csv_path, save_path
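For reference, a short usage sketch (not part of the commit) showing owl_full_video called from a plain script rather than through the Gradio app; the clip path and prompt are hypothetical, and the keyword arguments mirror the defaults above.

# Usage sketch (hypothetical paths/prompt), assuming a CUDA device is available.
from owl_core import owl_full_video

csv_path, vid_path = owl_full_video(
    "clips/baboons.mp4",        # hypothetical input video
    "baboon, vervet monkey",    # comma-separated text prompt, split by preprocess_text
    threshold=0.3,
    fps_processed=30,           # run detection on every 30th extracted frame
    scaling_factor=0.5,
    batch_size=6,
)
print(csv_path)   # per-frame boxes, scores and labels
print(vid_path)   # annotated video stitched from the frames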
requirements.txt
ADDED
@@ -0,0 +1,10 @@
gradio==5.8.0
numpy==2.2.0
opencv_python==4.7.0.68
opencv_python_headless==4.8.1.78
pandas==1.4.2
Pillow==11.0.0
supervision==0.25.0
torch==2.0.1
tqdm==4.65.0
transformers==4.36.2
utils.py
ADDED
@@ -0,0 +1,85 @@
import torch
import numpy as np
import supervision as sv
import cv2
import os
from glob import glob
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor


def plot_predictions(
        image: str,
        labels: list[str],
        scores: np.ndarray,
        boxes: np.ndarray,
    ) -> np.ndarray:
    """Draw the predicted boxes and '<label> <score>' tags onto the frame stored at `image`."""
    image_source = cv2.imread(image)
    image_source = cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB)

    # boxes/scores arrive as numpy arrays from owl_full_video; convert tensors just in case
    if isinstance(boxes, torch.Tensor):
        boxes = boxes.cpu().numpy()
    detections = sv.Detections(xyxy=boxes)

    labels = [
        f"{phrase} {logit:.2f}"
        for phrase, logit
        in zip(labels, scores)
    ]

    bbox_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
    annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
    annotated_frame = bbox_annotator.annotate(scene=annotated_frame, detections=detections)
    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)

    return annotated_frame


def mp4_to_png(input_path: str, save_path: str, scale_factor: float) -> int:
    """ Converts an mp4 into one png per frame of the video.
    Args: input_path is the path to the mp4 file, save_path is the directory in which to save the frames.
    Returns: fps, the number of frames per second of the source video.
    """
    # get frames per second
    fps = int(cv2.VideoCapture(input_path).get(cv2.CAP_PROP_FPS))
    # run ffmpeg to convert the mp4 into scaled pngs
    os.system(f"ffmpeg -i {input_path} -vf 'scale=iw*{scale_factor}:ih*{scale_factor}, fps={fps}' {save_path}/frame%08d.png")
    return fps


def vid_stitcher(frames_dir: str, output_path: str, fps: int = 30) -> str:
    """
    Reads the frames in frames_dir in order and writes them to a video file at output_path.
    """
    # Get the sorted list of frame paths
    frame_list = sorted(glob(os.path.join(frames_dir, 'frame*.png')))

    # Prepare the VideoWriter using the first frame's dimensions
    frame = cv2.imread(frame_list[0])
    height, width, _ = frame.shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    # Use multithreading to read frames faster
    with ThreadPoolExecutor() as executor:
        frames = list(executor.map(cv2.imread, frame_list))

    # Write frames to the video
    with tqdm(total=len(frame_list), desc='Stitching frames') as pbar:
        for frame in frames:
            out.write(frame)
            pbar.update(1)
    out.release()

    return output_path


def count_pos(phrases, text_target):
    """
    Takes a list of lists of phrases and counts how many of the lists contain at least one entry equal to the target phrase.
    """
    num_pos = 0
    for sublist in phrases:
        if sublist is None:
            continue
        for phrase in sublist:
            if phrase == text_target:
                num_pos += 1
                break
    return num_pos
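count_pos is not called elsewhere in this commit; below is a small usage sketch with hypothetical data showing the intended shape of its input (one list of detected phrases per processed frame, or None for frames without detections, matching the labels written by owl_core).

# Usage sketch for count_pos (hypothetical data, not from the commit).
frame_phrases = [["baboon", "baboon"], None, ["vervet monkey"], ["baboon"]]
print(count_pos(frame_phrases, "baboon"))  # -> 2 frames contain at least one "baboon"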