not-lain committed
Commit 4b54c6a · 1 Parent(s): 7a3185b

finish sam2

Files changed (2)
  1. app.py +31 -3
  2. assets/truck.jpg +0 -0
app.py CHANGED
@@ -7,6 +7,8 @@ from transformers import AutoModelForImageSegmentation
 from diffusers import FluxFillPipeline
 from PIL import Image, ImageOps
 from sam2.sam2_image_predictor import SAM2ImagePredictor
+import numpy as np
+import matplotlib.pyplot as plt
 
 torch.set_float32_matmul_precision(["high", "highest"][0])
 
@@ -122,9 +124,29 @@ def rmbg(image=None, url=None):
     return image
 
 
-def mask_generation(image=None, json=None):
+def mask_generation(image=None, d=None):
     predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
-    return None
+    predictor.set_image(image)
+    input_point = np.array(d["input_points"])
+    input_label = np.array(d["input_labels"])
+    masks, scores, logits = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        multimask_output=True,
+    )
+    sorted_ind = np.argsort(scores)[::-1]
+    masks = masks[sorted_ind]
+    scores = scores[sorted_ind]
+    logits = logits[sorted_ind]
+
+    out = []
+    image = Image.fromarray(image)
+    for i in range(len(masks)):
+        m = Image.fromarray(masks[i] * 255).convert("L")
+        comp = Image.composite(Image.fromarray(image), m, m)
+        out.append((comp, f"image {i}"))
+
+    return out
 
 
 @spaces.GPU
@@ -200,7 +222,13 @@ sam2_tab = gr.Interface(
         gr.Image("image", type="pil"),
         gr.JSON(),
     ],
-    outputs=["image"],
+    outputs=gr.Gallery(),
+    examples=[
+        [
+            "./assets/truck.jpg",
+            {"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]},
+        ]
+    ],
 )
 
 demo = gr.TabbedInterface(
assets/truck.jpg ADDED
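
For reference, the point-prompted SAM2 flow that the new mask_generation implements can be exercised outside Gradio roughly as below. This is a minimal sketch using the example prompt registered in this commit; the uint8 cast before Image.fromarray and the direct use of the PIL image in Image.composite are assumptions of mine (the committed code multiplies the raw mask by 255 and re-wraps the image with Image.fromarray). Everything else mirrors calls that appear in the diff.

# Minimal sketch of the point-prompted SAM2 flow from mask_generation (assumptions noted above).
import numpy as np
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

image = Image.open("./assets/truck.jpg")  # example asset added in this commit
prompt = {"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    point_coords=np.array(prompt["input_points"]),
    point_labels=np.array(prompt["input_labels"]),
    multimask_output=True,
)

# Rank candidate masks from highest to lowest score, as the commit does.
order = np.argsort(scores)[::-1]
masks, scores = masks[order], scores[order]

# Build (image, caption) pairs for a gallery; the uint8 cast is an assumption.
gallery = []
for i, mask in enumerate(masks):
    m = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
    gallery.append((Image.composite(image, m, m), f"image {i}"))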
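
On the interface side, swapping outputs=["image"] for gr.Gallery() lets every candidate mask be shown at once as (image, caption) pairs. In the registered example, input_points are (x, y) pixel coordinates and input_labels follow the usual SAM point-prompt convention of 1 for a foreground point and 0 for a background point. A hypothetical standalone wiring in the same spirit as the diff follows; the fn= hookup, the bare gr.Image(type="pil"), and the launch() call are illustrative rather than taken from the commit, and mask_generation is assumed to be the function defined in app.py above.

import gradio as gr

# Hypothetical standalone tab mirroring the diff's sam2_tab; assumes mask_generation
# from app.py is in scope. launch() is illustrative only.
sam2_tab = gr.Interface(
    fn=mask_generation,
    inputs=[gr.Image(type="pil"), gr.JSON()],
    outputs=gr.Gallery(),
    examples=[
        [
            "./assets/truck.jpg",
            {"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]},
        ]
    ],
)

if __name__ == "__main__":
    sam2_tab.launch()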