# OWLSAM_DSA2024 / app.py
from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np
import gradio as gr
import spaces
from PIL import Image, ImageDraw
# Load the models: OWL-ViT for zero-shot object detection and SAM for mask generation
checkpoint = "google/owlvit-base-patch16"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
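# For reference, the zero-shot-object-detection pipeline yields one dict per detection,
# roughly of this shape (values here are purely illustrative, not real output):
#   {"score": 0.27, "label": "a cat", "box": {"xmin": 12, "ymin": 34, "xmax": 220, "ymax": 180}}
# query() below turns each box dict into the [xmin, ymin, xmax, ymax] list that both
# ImageDraw.rectangle and the SAM processor accept.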
@spaces.GPU
def query(image, texts, threshold):
    # Parse the comma separated candidate labels, dropping surrounding whitespace
    texts = [text.strip() for text in texts.split(",")]

    # --- Zero-shot object detection with OWL-ViT ---
    predictions = detector(
        image,
        candidate_labels=texts,
        threshold=threshold
    )

    result_labels = []
    draw = ImageDraw.Draw(image)  # Drawing object for annotating the input image

    for pred in predictions:
        box = pred["box"]
        score = pred["score"]
        label = pred["label"]

        # Round the box coordinates for display and for the SAM box prompt
        box = [round(coord, 2) for coord in list(box.values())]

        # --- Segmentation with SAM, prompted by the detected box ---
        inputs = sam_processor(
            image,
            input_boxes=[[[box]]],  # Note: the SAM processor expects boxes as a nested list
            return_tensors="pt"
        ).to("cuda")

        with torch.no_grad():
            outputs = sam_model(**inputs)

        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0][0].numpy()
        mask = mask[np.newaxis, ...]
        result_labels.append((mask, label))

        # --- Draw the bounding box and label on the image ---
        draw.rectangle(box, outline="red", width=3)  # Red outline around the detection
        draw.text((box[0], box[1] - 10), label, fill="red")  # Label just above the box

    return image, result_labels  # Annotated image plus (mask, label) pairs for gr.AnnotatedImage
description = "This DSA2024 Demo Space combines OWL-ViT, a zero-shot object detection model, with SAM, a state-of-the-art mask generation model. SAM does not accept text prompts on its own; pairing it with OWL-ViT makes it text promptable: OWL-ViT detects boxes for the candidate labels, and those boxes prompt SAM to produce masks. Try the example, or input an image and comma separated candidate labels to segment."
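# A minimal sketch of exercising query() outside the Gradio UI, assuming the bundled
# ./cats.png example image is present; it returns the annotated PIL image together with
# the (mask, label) pairs that gr.AnnotatedImage renders:
#
#   annotated, masks = query(Image.open("./cats.png"), "cat", 0.1)
#   annotated.save("cats_annotated.png")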
demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil", label="Image Input"),
        gr.Textbox(label="Candidate Labels"),
        gr.Slider(0, 1, value=0.05, label="Confidence Threshold"),
    ],
    outputs="annotatedimage",
    title="OWL 🤝 SAM",
    description=description,
    examples=[
        ["./cats.png", "cat", 0.1],
    ],
    cache_examples=True
)
demo.launch(debug=True)