import functools
import os
import re
from datetime import datetime
from typing import Tuple

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm
from transformers import Owlv2Processor, Owlv2ForObjectDetection

from utils import plot_predictions, mp4_to_png, vid_stitcher

# Hugging Face checkpoint loaded when the caller supplies no processor/model.
_OWL_CHECKPOINT = "google/owlv2-base-patch16-ensemble"

# File extensions treated as video frames when scanning a frames directory.
_FRAME_EXTENSIONS = (".png", ".jpg", ".jpeg")


def preprocess_text(text_prompt: str, num_prompts: int = 1) -> list[list[str]]:
    """Split a comma-separated prompt into one query list per image.

    Entries are stripped of surrounding whitespace; empty entries (e.g. from a
    trailing comma) are dropped so they are not sent to the model as blank
    queries.

    Example: ``preprocess_text("a, b, c", 2)`` ->
    ``[["a", "b", "c"], ["a", "b", "c"]]``.

    Args:
        text_prompt: Comma-separated text queries.
        num_prompts: Number of images the same queries are repeated for.

    Returns:
        ``num_prompts`` independent copies of the query list.

    Raises:
        ValueError: If ``text_prompt`` contains no non-empty query.
    """
    queries = [part.strip() for part in text_prompt.split(",")]
    queries = [query for query in queries if query]
    if not queries:
        raise ValueError("text_prompt must contain at least one non-empty query")
    # Separate list objects so mutating one image's queries cannot affect another's.
    return [list(queries) for _ in range(num_prompts)]


def owl_batch_prediction(
    images: list,
    text_queries: list[list[str]],  # assuming that every image is queried with the same text prompt
    threshold: float,
    processor,
    model,
    device: str = 'cuda'
):
    """Run OWLv2 zero-shot detection on a batch of images.

    Args:
        images: PIL images; ``img.size`` is read as ``(width, height)``.
        text_queries: One list of text queries per image (see ``preprocess_text``).
        threshold: Minimum score for a detection to be kept.
        processor: Processor that builds model inputs and post-processes outputs.
        model: Detection model, already placed on ``device``.
        device: Device that the inputs and target sizes are moved to.

    Returns:
        The list from ``processor.post_process_object_detection``: one dict per
        image with ``"boxes"``, ``"scores"`` and ``"labels"`` tensors, where
        label values index into that image's query list.
    """
    inputs = processor(text=text_queries, images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Target sizes as (height, width) per image; PIL's .size is (width, height).
    # NOTE(review): OWLv2 pads inputs to a square; confirm the installed
    # transformers version accounts for that when rescaling to (height, width).
    target_sizes = torch.tensor(
        [img.size[::-1] for img in images], dtype=torch.float32, device=device
    )

    # Convert raw outputs to boxes/scores/labels in original-image pixels, keeping
    # only detections whose score reaches `threshold`.
    return processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=threshold
    )


@functools.lru_cache(maxsize=1)
def _default_processor():
    """Load the default OWLv2 processor once, on first use."""
    return Owlv2Processor.from_pretrained(_OWL_CHECKPOINT)


@functools.lru_cache(maxsize=4)
def _default_model(device: str):
    """Load the default OWLv2 detection model once per device, on first use."""
    return Owlv2ForObjectDetection.from_pretrained(_OWL_CHECKPOINT).to(device)


def _natural_sort_key(name: str) -> list:
    """Sort key that orders embedded integers numerically ("f2" before "f10")."""
    # re.split with a capturing group alternates text, digits, text, ...
    parts = re.split(r"(\d+)", name)
    parts[1::2] = [int(digits) for digits in parts[1::2]]
    return parts


def _list_frame_paths(frames_dir: str, stride: int) -> list[str]:
    """Return every ``stride``-th image path in ``frames_dir``, in frame order."""
    names = sorted(
        (name for name in os.listdir(frames_dir) if name.lower().endswith(_FRAME_EXTENSIONS)),
        key=_natural_sort_key,
    )
    return [os.path.join(frames_dir, name) for name in names[::stride]]


def _load_images(paths: list[str]) -> list:
    """Read each image fully into memory and close its file handle."""
    images = []
    for path in paths:
        with Image.open(path) as img:
            images.append(img.copy())  # copy() loads the pixels before the file closes
    return images


def owl_full_video(
    vid_path: str,
    text_prompt: str,
    threshold: float,
    fps_processed: int = 1,
    scaling_factor: float = 0.5,
    processor=None,
    model=None,
    device: str = 'cuda',
    batch_size: int = 6,
    results_root: str = '../results',
):
    """Run OWLv2 on every sampled frame of a video and save per-frame results.

    Same as ``owl_video`` (defined elsewhere), but processes the entire video
    regardless of whether detections occur. Frames are extracted to
    ``<results_root>/<video name>_<HHMMSS>/frames``. Each processed frame is
    overwritten with its annotated version, per-frame results are written to a
    CSV, and the frames directory is stitched into ``output.mp4``.

    Args:
        vid_path: Path to the input video.
        text_prompt: Comma-separated text queries (see ``preprocess_text``).
        threshold: Minimum detection score to keep.
        fps_processed: Frame stride; every ``fps_processed``-th extracted frame
            is processed (1 = every frame). NOTE(review): despite the name this
            is a stride, not a frames-per-second rate.
        scaling_factor: Passed to ``mp4_to_png`` (presumably a resize factor
            for the extracted frames; confirm in ``utils``).
        processor: OWLv2 processor; the default checkpoint is loaded lazily if None.
        model: OWLv2 detection model; the default checkpoint is loaded lazily
            onto ``device`` if None.
        device: Torch device for inference.
        batch_size: Number of frames per model call.
        results_root: Directory under which the per-run results directory is created.

    Returns:
        ``(csv_path, save_path)``: the per-frame results CSV path and the value
        returned by ``vid_stitcher`` (presumably the stitched video path).

    Raises:
        ValueError: If ``fps_processed`` or ``batch_size`` is below 1, or
            ``text_prompt`` has no non-empty query.
    """
    if fps_processed < 1:
        raise ValueError(f"fps_processed must be >= 1, got {fps_processed}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Validate the prompt before doing any expensive frame extraction.
    queries = preprocess_text(text_prompt)[0]

    # Load default weights only when needed, and onto the requested device.
    if processor is None:
        processor = _default_processor()
    if model is None:
        model = _default_model(device)

    # Create new dirs and paths for results.
    filename = os.path.splitext(os.path.basename(vid_path))[0]
    results_dir = os.path.join(results_root, f'{filename}_{datetime.now().strftime("%H%M%S")}')
    frames_dir = os.path.join(results_dir, "frames")

    # Extract frames only for a fresh results dir.
    # NOTE(review): a pre-existing dir (same name) is reused as-is, so its frames
    # may already be annotated by an earlier run.
    if not os.path.exists(results_dir):
        os.makedirs(frames_dir, exist_ok=True)
        mp4_to_png(vid_path, frames_dir, scaling_factor)  # return value not needed here

    frame_paths = _list_frame_paths(frames_dir, fps_processed)

    rows = []
    for start in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
        batch_paths = frame_paths[start:start + batch_size]
        images = _load_images(batch_paths)

        # Every image in the batch is queried with the same prompt.
        text_queries = [list(queries) for _ in batch_paths]
        results = owl_batch_prediction(images, text_queries, threshold, processor, model, device)

        for path, entry in zip(batch_paths, results):
            boxes = entry['boxes'].cpu().numpy()
            scores = entry['scores'].cpu().numpy()
            # Map label ids to query phrases; None marks a frame with no detections.
            label_ids = entry['labels'].tolist()
            labels = [queries[label_id] for label_id in label_ids] if label_ids else None

            rows.append({"frame": path, "boxes": boxes, "scores": scores, "labels": labels})

            # Overwrite every processed frame with its annotated version.
            annotated_frame = plot_predictions(path, labels, scores, boxes)
            cv2.imwrite(path, annotated_frame)

    # Build the DataFrame once (repeated pd.concat in the loop is quadratic).
    df = pd.DataFrame(rows, columns=["frame", "boxes", "scores", "labels"])
    csv_path = os.path.join(results_dir, f"{filename}_{threshold}.csv")
    df.to_csv(csv_path, index=False)

    # Stitch the frames into a video.
    save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"))
    return csv_path, save_path