"""Gradio demo: segment a user-named object in an uploaded image with FastSAM."""

import io
import logging

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt

logger = logging.getLogger(__name__)

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FastSAM model
model = FastSAM("FastSAM-s.pt")  # or FastSAM-x.pt


def fig2img(fig):
    """Render a matplotlib figure into a PIL image via an in-memory PNG.

    Args:
        fig: The matplotlib figure to render.

    Returns:
        A ``PIL.Image.Image`` opened on the in-memory buffer.
    """
    buf = io.BytesIO()
    # The buffer has no file extension to infer a format from, so be explicit.
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)


def plot(
    annotations,
    prompt_process,
    mask_random_color=True,
    better_quality=True,
    retina=True,
    with_contours=True,
):
    """Draw segmentation masks over the original image and return the picture.

    Only the first entry of ``annotations`` is rendered: the function returns
    as soon as one figure has been produced.

    Args:
        annotations: Iterable of result objects exposing ``orig_img``,
            ``orig_shape`` and ``masks`` (``masks`` may be ``None``).
        prompt_process: Object whose ``fast_show_mask`` method paints the
            masks onto a matplotlib axes.
        mask_random_color: Forwarded to ``fast_show_mask`` as ``random_color``.
        better_quality: When true, clean each mask with a morphological
            close followed by an open before drawing.
        retina: Forwarded to ``fast_show_mask`` as ``retinamask``; when false,
            masks are resized to the original image size before contour
            extraction.
        with_contours: When true, overlay the mask outlines.

    Returns:
        The rendered ``PIL.Image.Image``, or ``None`` if ``annotations`` is
        empty.
    """
    for ann in annotations:
        image = ann.orig_img[..., ::-1]  # BGR to RGB
        original_h, original_w = ann.orig_shape
        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
        try:
            # Draw through explicit figure/axes handles rather than pyplot's
            # implicit "current figure" state.
            ax = fig.gca()
            fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            ax.margins(0, 0)
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())
            ax.imshow(image)

            if ann.masks is not None:
                masks = ann.masks.data
                # Convert regardless of `better_quality`: the contour code
                # below calls the NumPy-only `.astype`. np.array() copies, so
                # the in-place clean-up never writes into the result object.
                if isinstance(masks, torch.Tensor):
                    masks = np.array(masks.cpu())

                if better_quality:
                    close_kernel = np.ones((3, 3), np.uint8)
                    open_kernel = np.ones((8, 8), np.uint8)
                    for i, mask in enumerate(masks):
                        closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel)
                        masks[i] = cv2.morphologyEx(closed.astype(np.uint8), cv2.MORPH_OPEN, open_kernel)

                prompt_process.fast_show_mask(
                    masks,
                    ax,
                    random_color=mask_random_color,
                    bbox=None,
                    points=None,
                    pointlabel=None,
                    retinamask=retina,
                    target_height=original_h,
                    target_width=original_w,
                )

                if with_contours:
                    contour_all = []
                    temp = np.zeros((original_h, original_w, 1))
                    for mask in masks:
                        mask = mask.astype(np.uint8)
                        if not retina:
                            mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                        contour_all.extend(contours)
                    cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                    # RGBA colour for the outlines; broadcast over the
                    # single-channel contour canvas.
                    color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                    contour_mask = temp / 255 * color.reshape(1, 1, -1)
                    ax.imshow(contour_mask)

            ax.axis("off")
            return fig2img(fig)
        finally:
            # Always release the figure, including when rendering raised, so
            # a long-running server does not accumulate open figures.
            plt.close(fig)
    return None


def segment_image(input_image, object_name):
    """Gradio callback: segment ``object_name`` in ``input_image``.

    Args:
        input_image: The uploaded image as a NumPy array, or ``None`` when
            nothing was uploaded.
        object_name: Text describing the object to segment.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the annotated picture on
        success, the RGB-converted input when the prompt is blank or nothing
        was found, and ``None`` when no image was supplied or an error
        occurred. ``status`` is a human-readable message.
    """
    if input_image is None:
        return None, "Please upload an image before submitting."

    query = (object_name or "").strip()
    try:
        input_image = Image.fromarray(input_image).convert("RGB")

        # A blank prompt cannot name an object; skip the model run entirely.
        if not query:
            return input_image, "Please specify an object to segment."

        # Run FastSAM model
        everything_results = model(input_image, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

        # Prepare a Prompt Process object
        prompt_process = FastSAMPrompt(input_image, everything_results, device=device)

        # Use text prompt to segment the specified object
        results = prompt_process.text_prompt(text=query)
        if not results:
            return input_image, f"Could not find '{query}' in the image."

        # Plot the results
        result_image = plot(annotations=results, prompt_process=prompt_process)
        return result_image, f"Segmented '{query}' in the image."
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the
        # app, but keep the traceback in the server log.
        logger.exception("Segmentation failed")
        return None, f"An error occurred: {e}"


# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status"),
    ],
    title="FastSAM Segmentation with Object Specification",
    description="Upload an image and specify an object to segment using FastSAM.",
)

# Launch the interface
iface.launch()