import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
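# NOTE: this import path assumes an older ultralytics release (~8.1.x);
# later versions removed FastSAMPrompt in favor of passing prompts
# (e.g. texts=...) directly to the model call. Pin ultralytics accordingly.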
# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FastSAM model
model = FastSAM("FastSAM-s.pt")  # or FastSAM-x.pt
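# FastSAM-s is the small/fast variant; FastSAM-x trades speed for accuracy.
# The named checkpoint is typically fetched automatically on first run.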

def fig2img(fig):
    """Render a matplotlib figure into a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    return Image.open(buf)

def plot(annotations, prompt_process, mask_random_color=True, better_quality=True, retina=True, with_contours=True):
    for ann in annotations:
        image = ann.orig_img[..., ::-1]  # BGR to RGB
        original_h, original_w = ann.orig_shape
        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)

        if ann.masks is not None:
            masks = ann.masks.data
            if better_quality:
                if isinstance(masks[0], torch.Tensor):
                    masks = np.array(masks.cpu())
                # Smooth each mask: close small holes, then remove small specks
                for i, mask in enumerate(masks):
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
                    masks[i] = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))

            prompt_process.fast_show_mask(
                masks,
                plt.gca(),
                random_color=mask_random_color,
                bbox=None,
                points=None,
                pointlabel=None,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )

            if with_contours:
                contour_all = []
                temp = np.zeros((original_h, original_w, 1))
                for mask in masks:
                    mask = mask.astype(np.uint8)
                    if not retina:
                        mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                    contour_all.extend(iter(contours))
                cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                color = np.array([0 / 255, 0 / 255, 1.0, 0.8])  # semi-transparent blue
                contour_mask = temp / 255 * color.reshape(1, 1, -1)
                plt.imshow(contour_mask)

        plt.axis("off")
        # Render to a PIL image before closing the figure, not after
        img = fig2img(fig)
        plt.close(fig)
        return img  # a text prompt yields a single result, so only the first annotation is drawn
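# NOTE: recent ultralytics Results objects also ship a built-in renderer
# (results[0].plot() returns an annotated BGR array), which could replace
# this matplotlib pipeline when per-mask styling is not required.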

def segment_image(input_image, object_name):
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."
        input_image = Image.fromarray(input_image).convert("RGB")
        # Run FastSAM model
        everything_results = model(input_image, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
        # Prepare a Prompt Process object
        prompt_process = FastSAMPrompt(input_image, everything_results, device=device)
        # Use text prompt to segment the specified object
        results = prompt_process.text_prompt(text=object_name)
        if not results:
            return input_image, f"Could not find '{object_name}' in the image."
        # Plot the results
        result_image = plot(annotations=results, prompt_process=prompt_process)
        return result_image, f"Segmented '{object_name}' in the image."
    except Exception as e:
        return None, f"An error occurred: {str(e)}"

# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status"),
    ],
    title="FastSAM Segmentation with Object Specification",
    description="Upload an image and specify an object to segment using FastSAM.",
)

# Launch the interface
iface.launch()
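# On Hugging Face Spaces the default launch() is sufficient; elsewhere,
# iface.launch(share=True) creates a temporary public URL for quick demos.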