from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np
import spaces
from PIL import Image, ImageDraw

# Load models
checkpoint = "google/owlvit-base-patch16"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
@spaces.GPU  # allocate a GPU for this call when running on a ZeroGPU Space
def query(image, texts, threshold):
    texts = [t.strip() for t in texts.split(",")]

    # --- Object detection with OWL-ViT ---
    predictions = detector(
        image,
        candidate_labels=texts,
        threshold=threshold
    )

    result_labels = []
    draw = ImageDraw.Draw(image)  # drawing object for boxes and labels

    for pred in predictions:
        box = pred["box"]
        score = pred["score"]
        label = pred["label"]

        # Round box coordinates for display and for the SAM prompt
        box = [round(coord, 2) for coord in list(box.values())]

        # --- Segmentation with SAM, prompted by the detected box ---
        inputs = sam_processor(
            image,
            input_boxes=[[[box]]],  # note: SAM expects a nested list of boxes
            return_tensors="pt"
        ).to("cuda")
        with torch.no_grad():
            outputs = sam_model(**inputs)
        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0][0].numpy()
        mask = mask[np.newaxis, ...]
        result_labels.append((mask, label))

        # --- Draw bounding box and label ---
        draw.rectangle(box, outline="red", width=3)   # red box around the detection
        draw.text((box[0], box[1] - 10), label, fill="red")  # label above the box

    return image, result_labels  # annotated image plus (mask, label) pairs
import gradio as gr

description = (
    "This DSA2024 demo Space combines OWL-ViT, a state-of-the-art zero-shot object "
    "detection model, with SAM, a state-of-the-art mask generation model. SAM does not "
    "accept text prompts on its own; chaining it with OWL-ViT makes it text-promptable. "
    "Try the example, or upload an image and enter comma-separated candidate labels to segment."
)

demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil", label="Image Input"),
        gr.Textbox(label="Candidate Labels"),
        gr.Slider(0, 1, value=0.05, label="Confidence Threshold"),
    ],
    outputs="annotatedimage",
    title="OWL 🤝 SAM",
    description=description,
    examples=[
        ["./cats.png", "cat", 0.1],
    ],
    cache_examples=True
)
demo.launch(debug=True)
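
# ---------------------------------------------------------------------------
# Sketch (not part of the app above): calling query() directly to smoke-test
# the detection + segmentation pipeline without the Gradio UI. This is an
# illustrative assumption, not the Space's own code: it assumes ./cats.png
# (from the examples list) exists and that a CUDA GPU is available. It is left
# commented out so it does not run when the Space starts.
#
#   test_image = Image.open("./cats.png").convert("RGB")
#   annotated, masks_and_labels = query(test_image, "cat", 0.1)
#   for mask, label in masks_and_labels:
#       print(label, mask.shape)
# ---------------------------------------------------------------------------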