mattmdjaga committed
Commit cc7fbfd • Parent(s): 7dbfa30
App init
app.py ADDED
@@ -0,0 +1,74 @@
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import SamModel, SamProcessor
import cv2

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

def mask_2_dots(mask):
    # Binarize the sketched mask, then close small gaps so each dot is one blob
    gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 127, 255, 0)
    kernel = np.ones((5, 5), np.uint8)
    closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    # Each remaining contour is one drawn dot; its centroid becomes a point prompt
    contours, _ = cv2.findContours(closed, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    points = []
    for contour in contours:
        moments = cv2.moments(contour)
        if moments['m00'] == 0:
            continue  # skip degenerate contours to avoid division by zero
        cx = int(moments['m10'] / moments['m00'])
        cy = int(moments['m01'] / moments['m00'])
        points.append([cx, cy])
    return [points]

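# Note: the outer list above nests the prompts per image, so the return value
# looks like [[[x1, y1], [x2, y2], ...]]: one image whose points all describe
# the same object, matching the single-object limitation stated in the app text.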
def main_func(inputs):
    # The sketch tool delivers a dict with the raw image and the drawn mask
    dots = inputs['mask']
    points = mask_2_dots(dots)

    image_input = inputs['image']
    image_input = Image.fromarray(image_input)

    inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
    # Forward pass
    with torch.no_grad():
        outputs = model(**inputs)
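    # outputs.pred_masks holds low-resolution mask logits; by default the model
    # returns 3 candidate masks per prompt (multimask_output=True in SamModel).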
    # Postprocess outputs: draw the prompt points onto the input image
    draw = ImageDraw.Draw(image_input)
    for point in points[0]:
        draw.ellipse((point[0] - 10, point[1] - 10, point[0] + 10, point[1] + 10), fill="red")

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    #scores = outputs.iou_scores

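    # Assumed shape note: masks is a list with one tensor per image, shaped
    # (prompt_batch, num_masks, H, W) at the original resolution; squeezing the
    # prompt dim below leaves the 3 candidate masks.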
    mask = masks[0].squeeze(0).numpy().transpose(1, 2, 0)

    pred_masks = [image_input]
    for i in range(mask.shape[2]):
        #mask[:,:,i] = mask[:,:,i] * scores[0][i].item()
        pred_masks.append(Image.fromarray((mask[:,:,i] * 255).astype(np.uint8)))

    return pred_masks

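# Build the Gradio UI: a sketchable input image, a gallery for the outputs,
# and a button wiring the two together through main_func.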
with gr.Blocks() as demo:
    gr.Markdown("# Demo to run Segment Anything base model")
    gr.Markdown("""This app uses the [Segment Anything](https://huggingface.co/facebook/sam-vit-base) model from Meta to get a mask from points in an image.
    Currently it only supports dots for a single object, but I'm planning to add extra features to make it work for multiple objects.
    The output shows the image with the dots, then the 3 predicted masks.
    """)
    with gr.Tab("Segment Image"):
        with gr.Row():
            image_input = gr.Image(tool='sketch')
            image_output = gr.Gallery()

        image_button = gr.Button("Segment Image")

        image_button.click(main_func, inputs=image_input, outputs=image_output)
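
# The committed file stops here without calling demo.launch(). A minimal,
# assumed entry point for running the app locally (a hosted Space may serve
# the `demo` object on its own):
if __name__ == "__main__":
    demo.launch()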