|
import torch |
|
from tqdm import tqdm |
|
import cv2 |
|
import os |
|
import numpy as np |
|
import pandas as pd |
|
|
|
from datetime import datetime |
|
from typing import Tuple |
|
from PIL import Image |
|
from utils import plot_predictions, mp4_to_png, vid_stitcher |
|
from transformers import Owlv2Processor, Owlv2ForObjectDetection |
|
|
|
def preprocess_text(text_prompt: str, num_prompts: int = 1) -> list[list[str]]:
    """Split a comma-separated prompt string and repeat it once per image.

    Each comma-separated fragment is stripped of surrounding whitespace. The
    resulting prompt list is then repeated ``num_prompts`` times so that every
    image in a batch gets its own list of text queries.

    Example (``num_prompts=2``)::

        "a, b, c" -> [["a", "b", "c"], ["a", "b", "c"]]

    Args:
        text_prompt: Comma-separated text prompts, e.g. ``"cat, dog"``.
        num_prompts: Number of per-image copies to return (typically the
            batch size). Values <= 0 yield an empty list.

    Returns:
        A list with ``num_prompts`` independent copies of the prompt list.
    """
    prompts = [fragment.strip() for fragment in text_prompt.split(",")]

    # Build a fresh list per image: `[prompts] * n` would repeat the SAME list
    # object n times, so mutating one entry would silently change all of them.
    return [list(prompts) for _ in range(num_prompts)]
|
def owl_batch_prediction(
    images: list["Image.Image"],
    text_queries: list[list[str]],
    threshold: float,
    processor,
    model,
    device: str = 'cuda'
):
    """Run one batch of images through the detector and post-process the output.

    Args:
        images: Batch of images. Each element must expose a PIL-style
            ``.size`` attribute ``(width, height)``, which is reversed below
            to build the ``(height, width)`` target sizes.
        text_queries: One list of text prompts per image.
        threshold: Score threshold forwarded to
            ``processor.post_process_object_detection``.
        processor: Callable that builds the model inputs from text and images
            and provides ``post_process_object_detection``.
        model: Detection model, called as ``model(**inputs)``. It is expected
            to already live on ``device``; this function only moves the inputs.
        device: Device the inputs and target sizes are placed on.

    Returns:
        Whatever ``processor.post_process_object_detection`` returns for the
        batch (one entry per image), or an empty list for an empty batch.
    """
    # Nothing to predict; avoid handing an empty batch to the processor.
    if len(images) == 0:
        return []

    inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # `.size` is (width, height); reversing gives (height, width) per image.
    # Created directly on `device` as float32, matching what the legacy
    # `torch.Tensor(...)` constructor followed by `.to(device)` produced.
    target_sizes = torch.tensor(
        [img.size[::-1] for img in images],
        dtype=torch.float32,
        device=device,
    )

    return processor.post_process_object_detection(
        outputs=outputs,
        target_sizes=target_sizes,
        threshold=threshold,
    )
|
|
|
def _natural_sort_key(name: str) -> list:
    """Sort key ordering embedded numbers numerically ("frame_2" < "frame_10")."""
    import re  # local import: `re` is not among this module's top-level imports

    # Splitting on a captured digit group alternates text / digits / text ...,
    # so odd positions are always the numeric runs.
    parts = re.split(r"([0-9]+)", name)
    return [int(part) if pos % 2 else part.lower() for pos, part in enumerate(parts)]


def _label_names(label_ids, prompts):
    """Map predicted label indices to prompt strings; None if nothing was detected."""
    if label_ids.numel() == 0:
        return None
    return [prompts[index] for index in label_ids.tolist()]


def owl_full_video(
    vid_path: str,
    text_prompt: str,
    threshold: float,
    fps_processed: int = 1,
    scaling_factor: float = 0.5,
    processor=None,
    model=None,
    device: str = 'cuda',
    batch_size: int = 6,
):
    """Run text-prompted detection over a whole video and save per-frame results.

    The video is dumped to PNG frames, every ``fps_processed``-th frame is run
    through the detector in batches, each processed frame is overwritten with
    its annotated version, the detections are written to a CSV, and the frames
    are stitched back into a video.

    Args:
        vid_path: Path to the input video.
        text_prompt: Comma-separated text prompts (see ``preprocess_text``).
        threshold: Detection score threshold; also embedded in the CSV name.
        fps_processed: Frame stride; frames whose index is a multiple of this
            value are processed.
        scaling_factor: Forwarded to ``mp4_to_png``.
        processor: Processor to use. When ``None``, the
            ``google/owlv2-base-patch16-ensemble`` processor is loaded.
        model: Detection model to use. When ``None``, the
            ``google/owlv2-base-patch16-ensemble`` model is loaded.
        device: Device the model and inputs are placed on.
        batch_size: Number of frames per forward pass.

    Returns:
        ``(csv_path, save_path)``: the path of the written CSV and whatever
        ``vid_stitcher`` returns for the stitched video.
    """
    # Load lazily instead of in the signature: default-argument expressions
    # run once at definition time, which made importing this module load the
    # checkpoints even if the function was never called.
    checkpoint = "google/owlv2-base-patch16-ensemble"
    if processor is None:
        processor = Owlv2Processor.from_pretrained(checkpoint)
    if model is None:
        model = Owlv2ForObjectDetection.from_pretrained(checkpoint)

    # The inputs are moved to `device` for every batch, so the model must live
    # there too; otherwise the forward pass fails with a device mismatch.
    model = model.to(device)
    model.eval()

    # Results folder: ../results/<video name>_<HHMMSS>/frames
    filename = os.path.splitext(os.path.basename(vid_path))[0]
    results_dir = f'../results/{filename}_{datetime.now().strftime("%H%M%S")}'
    frames_dir = os.path.join(results_dir, "frames")
    # Creates results_dir as well; exist_ok makes a prior existence check moot.
    os.makedirs(frames_dir, exist_ok=True)

    # Dump the video to PNG frames. The return value (bound to `fps` in the
    # previous version) is not used anywhere in this function.
    mp4_to_png(vid_path, frames_dir, scaling_factor)

    # os.listdir gives no ordering guarantee; sort so the frame stride and the
    # CSV rows follow the frame sequence deterministically.
    frame_filenames = sorted(os.listdir(frames_dir), key=_natural_sort_key)
    frame_paths = [
        os.path.join(frames_dir, frame_name)
        for position, frame_name in enumerate(frame_filenames)
        if position % fps_processed == 0
    ]

    # Collect rows and build the DataFrame once; concatenating inside the loop
    # copies the whole frame on every iteration.
    rows = []

    for start in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
        batch_paths = frame_paths[start:start + batch_size]

        # Copy pixel data into memory and close the file right away: the same
        # path is overwritten with the annotated frame further down.
        images = []
        for frame_path in batch_paths:
            with Image.open(frame_path) as handle:
                images.append(handle.copy())

        text_queries = preprocess_text(text_prompt, len(batch_paths))
        results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)

        # Every image in the batch shares the same prompt list.
        prompts = text_queries[0]

        for frame_path, detections in zip(batch_paths, results):
            boxes = detections['boxes'].cpu().numpy()
            scores = detections['scores'].cpu().numpy()
            labels = _label_names(detections['labels'], prompts)

            rows.append({"frame": frame_path, "boxes": boxes, "scores": scores, "labels": labels})

            # Overwrite the extracted frame with its annotated version so the
            # stitched video shows the predictions.
            annotated_frame = plot_predictions(frame_path, labels, scores, boxes)
            cv2.imwrite(frame_path, annotated_frame)

    df = pd.DataFrame(rows, columns=["frame", "boxes", "scores", "labels"])

    csv_path = f"{results_dir}/{filename}_{threshold}.csv"
    df.to_csv(csv_path, index=False)

    # Stitch all frames in frames_dir back into a video.
    save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"))

    return csv_path, save_path