|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from PIL import Image, ImageDraw |
|
import requests |
|
from transformers import SamModel, SamProcessor |
|
import cv2 |
|
|
|
# Single checkpoint id shared by the model and its matching preprocessor.
_SAM_CHECKPOINT = "facebook/sam-vit-base"

# Use the GPU when torch reports one, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time and reused by every request handled below.
processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
model = SamModel.from_pretrained(_SAM_CHECKPOINT).to(device)
|
|
|
def mask_2_dots(mask):
    """Turn a sketch mask into a list of point prompts.

    Each connected blob in the mask is reduced to a single ``[x, y]`` point
    located at the blob's centroid.

    Args:
        mask: ``numpy`` image array holding the user's sketch. Multi-channel
            input is converted to grayscale; a 2-D array is used as is.

    Returns:
        A one-element list wrapping the list of ``[x, y]`` integer points,
        i.e. ``[[[x0, y0], [x1, y1], ...]]``. The inner list is empty when
        nothing was drawn.
    """
    # A 2-D array is already single-channel; cvtColor would reject it.
    gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) if mask.ndim == 3 else mask
    _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

    # Close small gaps so one brush stroke does not split into several blobs.
    kernel = np.ones((5, 5), np.uint8)
    closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    contours, _ = cv2.findContours(closed, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    points = []
    for contour in contours:
        moments = cv2.moments(contour)
        area = moments['m00']
        if area:
            cx = int(moments['m10'] / area)
            cy = int(moments['m01'] / area)
        else:
            # Degenerate contour (single pixel or thin line) has zero area,
            # so the moment-based centroid would divide by zero. Fall back
            # to the mean of the contour's own vertices.
            mean_x, mean_y = contour.reshape(-1, 2).mean(axis=0)
            cx, cy = int(mean_x), int(mean_y)
        points.append([cx, cy])

    return [points]
|
|
|
def main_func(inputs):
    """Segment the object marked by the user's sketched dots.

    Args:
        inputs: Value of the sketch-enabled image component: a dict with an
            ``'image'`` array (the picture) and a ``'mask'`` array (the
            sketch layer).

    Returns:
        A list of PIL images: the input picture with the prompt points drawn
        in red, followed by one black/white image per predicted mask.

    Raises:
        gr.Error: If no image was provided or no dots were drawn on it.
    """
    if inputs is None:
        raise gr.Error("Please upload an image and draw at least one dot on it.")

    points = mask_2_dots(inputs['mask'])
    if not points[0]:
        raise gr.Error("No dots found. Draw at least one dot on the object to segment.")

    image = Image.fromarray(inputs['image']).convert("RGB")

    encoded = processor(image, input_points=points, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**encoded)

    # Annotate a copy, after preprocessing, so the red markers are only part
    # of the displayed picture and never of the model input.
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    for x, y in points[0]:
        draw.ellipse((x - 10, y - 10, x + 10, y + 10), fill="red")

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        encoded["original_sizes"].cpu(),
        encoded["reshaped_input_sizes"].cpu(),
    )

    # masks[0] belongs to the single input image; dropping its leading
    # singleton axis leaves one 2-D mask per entry along the first axis.
    candidate_masks = masks[0].squeeze(0).numpy()

    pred_masks = [annotated]
    pred_masks.extend(
        Image.fromarray((candidate * 255).astype(np.uint8))
        for candidate in candidate_masks
    )
    return pred_masks
|
|
|
|
|
# UI layout: a sketch-enabled image on the left, a gallery of results on the
# right, and a button that runs the segmentation callback.
with gr.Blocks() as demo:
    gr.Markdown("# Demo to run Segment Anything base model")
    gr.Markdown(
        "This app uses the [Segment Anything](https://huggingface.co/facebook/sam-vit-base) "
        "model from Meta to get a mask from points in an image.\n"
        "Currently it only works for creating dots for one object. But, I'm planning "
        "to add extra features to make it work for multiple objects.\n"
        "The output shows the image with the dots then the 3 predicted masks.\n"
    )
    # Tab label describes what this tab actually does (it was a leftover
    # "Flip Image" caption unrelated to segmentation).
    with gr.Tab("Segment Image"):
        with gr.Row():
            # The sketch tool lets the user paint dots over the uploaded image.
            image_input = gr.Image(tool='sketch')
            image_output = gr.Gallery()

        image_button = gr.Button("Segment Image")

    image_button.click(main_func, inputs=image_input, outputs=image_output)
|
|