Spaces:
Running
Running
dragonSwing
committed on
Commit
•
67d4d3e
1
Parent(s):
8b0abc0
Upload files
Browse files- .gitignore +12 -0
- app.py +214 -0
- assets/.DS_Store +0 -0
- assets/dogs.jpg +0 -0
- assets/picture1.jpg +0 -0
- assets/picture2.jpg +0 -0
- assets/picture3.jpg +0 -0
- assets/picture4.jpg +0 -0
- assets/picture5.jpg +0 -0
- assets/picture6.jpg +0 -0
- requirements.txt +5 -0
- utils.py +178 -0
.gitignore
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pyc
|
2 |
+
*.pyo
|
3 |
+
*.pyd
|
4 |
+
__py
|
5 |
+
**/__pycache__/
|
6 |
+
data
|
7 |
+
onnx
|
8 |
+
results
|
9 |
+
**.egg-info
|
10 |
+
*.log
|
11 |
+
*.onnx
|
12 |
+
.hypothesis
|
app.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
|
4 |
+
from nanosam import Predictor
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import time
|
8 |
+
from PIL import ImageDraw
|
9 |
+
from utils import download_file_from_url, fast_process, format_results, point_prompt
|
10 |
+
|
11 |
+
# Most of our demo code is from [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM). Huge thanks to An-619.
|
12 |
+
|
13 |
+
# Fetch each ONNX checkpoint on first launch; a file that is already on disk
# is reused and not downloaded again.
for _onnx_name in (
    "sam_hgv2_b4_ln_nonorm_image_encoder.onnx",
    "efficientvit_l0_mask_decoder.onnx",
):
    if os.path.exists(f"onnx/{_onnx_name}"):
        continue
    download_file_from_url(
        f"https://huggingface.co/dragonSwing/nanosam/resolve/main/{_onnx_name}",
        f"onnx/{_onnx_name}",
    )

# Load the pre-trained model
image_encoder_cfg = dict(
    path="onnx/sam_hgv2_b4_ln_nonorm_image_encoder.onnx",
    provider="cpu",
    normalize_input=False,
)
mask_decoder_cfg = dict(
    path="onnx/efficientvit_l0_mask_decoder.onnx",
    provider="cpu",
)
# Single module-level predictor used by every request handler below.
predictor = Predictor(image_encoder_cfg, mask_decoder_cfg)
|
36 |
+
|
37 |
+
# Description
# Static text and assets shown on the demo page.

# Page heading rendered through gr.Markdown; the <font> element is closed
# with a proper </font> tag so the markup is balanced.
title = "<center><strong><font size='8'>Faster Segment Anything(NanoSAM)</font></strong></center>"

# Markdown help text for the point-mode tab. The Model Card link targets the
# same Hugging Face repository the ONNX weights are downloaded from.
description_p = """ ## This is a demo of [Faster Segment Anything(NanoSAM) Model](https://github.com/binh234/nanosam).
# Instructions for point mode
0. Restart by click the Restart button
1. Select a point with Add Mask for the foreground (Must)
2. Select a point with Remove Area for the background (Optional)
3. Click the Start Segmenting.
- Github [link](https://github.com/binh234/nanosam)
- Model Card [link](https://huggingface.co/dragonSwing/nanosam)
We will provide box mode soon.
Enjoy!
"""

# One single-element row per example image, in display order.
examples = [
    ["assets/picture3.jpg"],
    ["assets/picture4.jpg"],
    ["assets/picture5.jpg"],
    ["assets/picture6.jpg"],
    ["assets/picture1.jpg"],
    ["assets/picture2.jpg"],
    ["assets/dogs.jpg"],
]

# Custom CSS handed to gr.Blocks.
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
|
63 |
+
|
64 |
+
|
65 |
+
def get_empty_state():
    """Build a fresh session state: no points, no labels, no cached features.

    A new dict with new lists is created on every call, so states never share
    their mutable containers.
    """
    fresh_state = dict(points=[], point_labels=[], features=None)
    return fresh_state
|
67 |
+
|
68 |
+
|
69 |
+
def clear():
    """Return the values that reset the demo.

    Returns:
        A 3-tuple ``(None, None, state)`` where ``state`` is a brand-new empty
        state from ``get_empty_state()``.
    """
    blank_input = blank_output = None
    return blank_input, blank_output, get_empty_state()
|
71 |
+
|
72 |
+
|
73 |
+
def set_image(image):
    """Encode ``image`` with the shared predictor and return a fresh state.

    The returned state has empty point lists and carries the predictor's
    ``features`` attribute as read right after ``predictor.set_image``.
    The elapsed encoder time is printed to stdout.
    """
    t0 = time.perf_counter()
    predictor.set_image(image)
    elapsed = time.perf_counter() - t0
    print(f"Encoder time: {elapsed: .3f}s")

    new_state = get_empty_state()
    new_state["features"] = predictor.features
    return new_state
|
81 |
+
|
82 |
+
|
83 |
+
def segment_with_points(
    image,
    state,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Run the mask decoder on the stored prompt points and render the result.

    Args:
        image: Image whose ``size`` attribute yields ``(width, height)``; it is
            also the canvas passed to ``fast_process``.
        state: Session dict with ``"points"``, ``"point_labels"`` and
            ``"features"`` entries.
        better_quality, withContours, use_retina, mask_random_color: Forwarded
            unchanged to ``fast_process``.

    Returns:
        Whatever ``fast_process`` returns for the highest-scoring mask.

    Raises:
        gr.Error: If no points are stored, if the points and labels differ in
            length, or if the state holds no encoder features.
    """
    points = np.asarray(state["points"])
    point_labels = np.asarray(state["point_labels"])
    n_points, n_labels = len(points), len(point_labels)
    if n_points == 0 and n_labels == 0:
        raise gr.Error("No points selected")
    if n_points != n_labels:
        raise gr.Error("Mismatch length between points and point labels")

    cached_features = state["features"]
    if cached_features is None:
        raise gr.Error(
            "Image was not set correctly, please wait for a moment after uploading image before drawing points!"
        )

    # Put this session's encoder output and image size back on the shared
    # predictor before decoding.
    predictor.features = cached_features
    width, height = image.size
    predictor.original_size = (height, width)

    t0 = time.perf_counter()
    masks, scores, _logits = predictor.predict(
        points=points,
        point_labels=point_labels,
    )
    elapsed = time.perf_counter() - t0
    print(f"Decoder time: {elapsed: .3f}s")

    # Keep only the candidate with the highest score and binarise it at 0.
    # NOTE(review): indexing assumes masks is (batch, candidates, H, W) with a
    # single batch entry -- confirm against nanosam's Predictor.predict.
    best_mask = masks[0, scores.argmax()] > 0

    return fast_process(
        annotations=[best_mask],
        image=image,
        scale=1,
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
|
133 |
+
|
134 |
+
|
135 |
+
def get_points_with_draw(image, label, evt: gr.SelectData, state):
    """Record the clicked point in ``state`` and paint a marker on ``image``.

    Args:
        image: Image to draw on; it is modified in place.
        label: Radio value; ``"Add Mask"`` stores label 1, anything else 0.
        evt: Selection event whose ``index`` supplies the ``x, y`` position.
        state: Session dict; its ``"points"`` and ``"point_labels"`` lists are
            appended to in place.

    Returns:
        The ``(image, state)`` pair after the update.
    """
    x, y = evt.index[0], evt.index[1]
    is_foreground = label == "Add Mask"
    point_radius = 15
    # (255, 255, 0) marks "Add Mask" points, (255, 0, 255) everything else.
    point_color = (255, 255, 0) if is_foreground else (255, 0, 255)

    state["points"].append([x, y])
    state["point_labels"].append(1 if is_foreground else 0)

    print(x, y, is_foreground)

    top_left = (x - point_radius, y - point_radius)
    bottom_right = (x + point_radius, y + point_radius)
    ImageDraw.Draw(image).ellipse([top_left, bottom_right], fill=point_color)
    return image, state
|
157 |
+
|
158 |
+
|
159 |
+
# Image components are created up front and placed into the layout later via
# .render(), so the event wiring at the bottom can refer to them by name.
cond_img_p = gr.Image(label="Input with points", type="pil", interactive=True)

segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type="pil")

# Not referenced anywhere else in this file; the per-session point lists live
# in the gr.State dict instead.
global_points = []
global_point_labels = []

with gr.Blocks(css=css, title="Faster Segment Anything(NanoSAM)") as demo:
    # Session state: prompt points, their labels and the cached features.
    state = gr.State(value=get_empty_state())

    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

    with gr.Tab("Point mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()

            with gr.Column(scale=1):
                segm_img_p.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    add_or_remove = gr.Radio(
                        choices=["Add Mask", "Remove Area"],
                        value="Add Mask",
                    )

                    with gr.Column():
                        segment_btn_p = gr.Button("Start segmenting!", variant="primary")
                        restart_btn_p = gr.Button("Restart", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                # Clicking an example runs set_image on it and stores the
                # result in the session state.
                gr.Examples(
                    examples=examples,
                    inputs=[cond_img_p],
                    outputs=[state],
                    fn=set_image,
                    run_on_click=True,
                    examples_per_page=4,
                )

            with gr.Column():
                # Description
                gr.Markdown(description_p)

    # Event wiring: upload -> encode, click on image -> add point,
    # buttons -> segment / reset.
    cond_img_p.upload(fn=set_image, inputs=[cond_img_p], outputs=[state])
    cond_img_p.select(
        fn=get_points_with_draw,
        inputs=[cond_img_p, add_or_remove, state],
        outputs=[cond_img_p, state],
    )
    segment_btn_p.click(fn=segment_with_points, inputs=[cond_img_p, state], outputs=[segm_img_p])
    restart_btn_p.click(fn=clear, outputs=[cond_img_p, segm_img_p, state])

demo.queue().launch()
|
assets/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
assets/dogs.jpg
ADDED
assets/picture1.jpg
ADDED
assets/picture2.jpg
ADDED
assets/picture3.jpg
ADDED
assets/picture4.jpg
ADDED
assets/picture5.jpg
ADDED
assets/picture6.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib
|
2 |
+
onnx>=1.14.0
|
3 |
+
onnxruntime>=1.14.0
|
4 |
+
opencv-python-headless
|
5 |
+
git+https://github.com/binh234/nanosam.git
|
utils.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import requests
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
|
10 |
+
def point_prompt(masks, points, point_label, target_height, target_width):
    """Combine the masks hit by the prompt points into one boolean mask.

    Args:
        masks: Sequence of masks. Each entry is either a dict with a
            ``"segmentation"`` array or the array itself; both forms are
            accepted for every entry, including the first one.
        points: Sequence of ``[x, y]`` positions given in the
            ``target_width`` x ``target_height`` coordinate frame.
        point_label: Per-point labels; ``0`` subtracts the hit mask, any other
            value adds it.
        target_height: Height of the frame the points refer to.
        target_width: Width of the frame the points refer to.

    Returns:
        ``(onemask, 0)`` where ``onemask`` is a boolean array with the shape of
        the first mask, true wherever added masks outweigh subtracted ones.
    """

    def _segmentation(annotation):
        # Accept both the dict form and a bare array.
        if isinstance(annotation, dict):
            return annotation["segmentation"]
        return annotation

    first_mask = _segmentation(masks[0])
    h = first_mask.shape[0]
    w = first_mask.shape[1]

    # Rescale the points into mask coordinates when the frames differ.
    if h != target_height or w != target_width:
        points = [
            [int(point[0] * w / target_width), int(point[1] * h / target_height)]
            for point in points
        ]

    onemask = np.zeros((h, w))
    for annotation in masks:
        mask = _segmentation(annotation)
        for point_idx, point in enumerate(points):
            if mask[point[1], point[0]] == 1:
                if point_label[point_idx] == 0:
                    onemask -= mask
                else:
                    onemask += mask
                # Only the first point that lands on this mask decides its sign.
                break
    onemask = onemask > 0
    return onemask, 0
|
33 |
+
|
34 |
+
|
35 |
+
def format_results(masks, scores, logits, filter=0):
    """Turn raw decoder outputs into a list of annotation dicts.

    Args:
        masks: Indexable of mask arrays; values ``> 0`` count as foreground.
        scores: One score per mask; its length sets how many masks are used.
        logits: Accepted for signature compatibility; not used.
        filter: Accepted for signature compatibility; not used.

    Returns:
        A list with one dict per score, holding ``"id"``, ``"segmentation"``
        (boolean mask), ``"bbox"``, ``"score"`` and ``"area"`` (foreground
        pixel count). A mask without any foreground pixel gets the bbox
        ``[0, 0, 0, 0]`` instead of raising.
    """
    annotations = []
    for i, score in enumerate(scores):
        mask = masks[i] > 0
        tmp = np.where(mask)
        if tmp[0].size:
            # NOTE(review): the order is [min axis0, min axis1, max axis1,
            # max axis0], kept exactly as before. That is not a plain
            # x1, y1, x2, y2 box -- confirm what consumers expect.
            bbox = [
                np.min(tmp[0]),
                np.min(tmp[1]),
                np.max(tmp[1]),
                np.max(tmp[0]),
            ]
        else:
            # np.min / np.max raise ValueError on empty input.
            bbox = [0, 0, 0, 0]
        annotations.append(
            {
                "id": i,
                "segmentation": mask,
                "bbox": bbox,
                "score": score,
                "area": mask.sum(),
            }
        )
    return annotations
|
55 |
+
|
56 |
+
|
57 |
+
def fast_process(
    annotations,
    image,
    scale,
    better_quality=False,
    mask_random_color=True,
    bbox=None,
    use_retina=True,
    withContours=True,
):
    """Overlay segmentation masks (and optionally their contours) on an image.

    Args:
        annotations: Sequence of masks, each either a dict with a
            ``"segmentation"`` array or the array itself. The caller's
            sequence is never modified.
        image: PIL image to draw over; a converted RGBA copy is returned.
        scale: Divisor applied to the contour thickness (``2 // scale``).
        better_quality: When true, each mask is cleaned with a morphological
            close (3x3 kernel) followed by an open (8x8 kernel).
        mask_random_color: Forwarded to ``fast_show_mask`` as ``random_color``.
        bbox: Forwarded to ``fast_show_mask``.
        use_retina: Forwarded to ``fast_show_mask`` as ``retinamask``; when
            false, masks are resized to the image size before contour tracing.
        withContours: When true, mask outlines are drawn on top.

    Returns:
        The RGBA image with the overlays pasted on. With no annotations the
        plain RGBA conversion of ``image`` is returned.
    """
    original_h = image.height
    original_w = image.width
    image = image.convert("RGBA")
    if len(annotations) == 0:
        # Nothing to draw; indexing into an empty sequence would raise.
        return image

    # Build a private list so the in-place clean-up below cannot leak into
    # the caller's sequence.
    annotations = [
        annotation["segmentation"] if isinstance(annotation, dict) else annotation
        for annotation in annotations
    ]

    if better_quality:
        close_kernel = np.ones((3, 3), np.uint8)
        open_kernel = np.ones((8, 8), np.uint8)
        annotations = [
            cv2.morphologyEx(
                cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, close_kernel),
                cv2.MORPH_OPEN,
                open_kernel,
            )
            for mask in annotations
        ]

    annotations = np.asarray(annotations)
    inner_mask = fast_show_mask(
        annotations,
        plt.gca(),
        random_color=mask_random_color,
        bbox=bbox,
        retinamask=use_retina,
        target_height=original_h,
        target_width=original_w,
    )

    # The overlay doubles as its own paste mask, so its alpha channel controls
    # where it is applied.
    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), "RGBA")
    image.paste(overlay_inner, (0, 0), overlay_inner)

    if withContours:
        contour_all = []
        for mask in annotations:
            annotation = mask.astype(np.uint8)
            if not use_retina:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            contour_all.extend(contours)

        temp = np.zeros((original_h, original_w, 1))
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
        # Spread the single-channel outline over four channels with
        # weights (0, 0, 1, 0.9).
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)

        overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), "RGBA")
        image.paste(overlay_contour, (0, 0), overlay_contour)

    return image
|
120 |
+
|
121 |
+
|
122 |
+
# CPU post process
|
123 |
+
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Flatten a stack of masks into a single 4-channel overlay.

    Args:
        annotation: Array indexed as ``(mask, row, column)``; non-zero entries
            are foreground.
        ax: Matplotlib axes; only used to add a rectangle when ``bbox`` is
            given.
        random_color: When true each mask gets random colour values,
            otherwise all masks share ``(30, 144, 255) / 255``.
        bbox: Optional ``(x1, y1, x2, y2)`` drawn on ``ax`` as a rectangle.
        retinamask: When false, the overlay is resized to
            ``(target_width, target_height)`` with nearest-neighbour sampling.
        target_height: Output height used when ``retinamask`` is false.
        target_width: Output width used when ``retinamask`` is false.

    Returns:
        Float array of shape ``(height, width, 4)``. Covered pixels carry the
        colour of the smallest-area mask covering them with 0.6 in the last
        channel; uncovered pixels are all zero.
    """
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    width = annotation.shape[2]

    # Order masks by ascending area so that, per pixel, the first non-zero
    # mask found below is the smallest one covering it.
    areas = np.sum(annotation, axis=(1, 2))
    annotation = annotation[np.argsort(areas)]

    # For every pixel: index of the first mask that is non-zero there
    # (0 when no mask covers the pixel).
    index = (annotation != 0).argmax(axis=0)

    if random_color:
        color = np.random.random((mask_sum, 1, 1, 3))
    else:
        color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
    transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
    visual = np.concatenate([color, transparency], axis=-1)
    mask_image = np.expand_dims(annotation, -1) * visual

    # Pick, per pixel, the 4-channel value of the selected mask. Uncovered
    # pixels select mask 0, whose value there is zero.
    mask = np.take_along_axis(
        mask_image, index[np.newaxis, :, :, np.newaxis], axis=0
    )[0]
    assert mask.shape == (height, width, 4)

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(
            plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1)
        )

    if not retinamask:
        mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)

    return mask
|
164 |
+
|
165 |
+
|
166 |
+
def download_file_from_url(url, output_file, chunk_size=8192, timeout=60):
    """Download ``url`` to ``output_file``, best effort.

    The body is streamed into ``<output_file>.part`` and moved into place only
    after the whole response has been written, so a failed or interrupted
    download never leaves a truncated file at ``output_file``.

    Args:
        url: Address to fetch with an HTTP GET.
        output_file: Destination path; its parent directory is created when
            the path has one.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``.

    Returns:
        None. Network and file errors are reported on stdout, not raised.
    """
    output_dir = os.path.dirname(output_file)
    # A bare filename has no directory part, and os.makedirs("") would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    partial_file = f"{output_file}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download file. Status code: {response.status_code}")
                return
            with open(partial_file, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Publish the finished file in one step.
        os.replace(partial_file, output_file)
    except (requests.RequestException, OSError) as e:
        print(f"An error occurred: {e}")
        # Remove the leftover partial download, if any.
        if os.path.exists(partial_file):
            os.remove(partial_file)
|