Spaces:
Runtime error
Runtime error
Add mask extension function to inpaint mode.
Browse files
app.py
CHANGED
@@ -250,6 +250,38 @@ def xywh_to_xyxy(box, sizeW, sizeH):
|
|
250 |
box = box.numpy()
|
251 |
return box
|
252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
254 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
|
255 |
|
@@ -372,40 +404,16 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
372 |
if task_type == 'inpainting':
|
373 |
# inpainting pipeline
|
374 |
image_source_for_inpaint = image_pil.resize((512, 512))
|
|
|
|
|
|
|
375 |
image_mask_for_inpaint = mask_pil.resize((512, 512))
|
376 |
image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
377 |
else:
|
378 |
# remove from mask
|
379 |
if mask_source_radio == mask_source_segment:
|
380 |
-
|
381 |
-
|
382 |
-
boxes_filt_ori_array = boxes_filt_ori.numpy()
|
383 |
-
if inpaint_mode == 'merge':
|
384 |
-
extend_shape_0 = masks_shape[0]
|
385 |
-
extend_shape_1 = masks_shape[1]
|
386 |
-
else:
|
387 |
-
extend_shape_0 = 1
|
388 |
-
extend_shape_1 = 1
|
389 |
-
for i in range(extend_shape_0):
|
390 |
-
for j in range(extend_shape_1):
|
391 |
-
mask = masks_ori[i][j].cpu().numpy()
|
392 |
-
mask_pil = Image.fromarray(mask)
|
393 |
-
|
394 |
-
if remove_mode == 'segment':
|
395 |
-
useRectangle = False
|
396 |
-
else:
|
397 |
-
useRectangle = True
|
398 |
-
|
399 |
-
try:
|
400 |
-
remove_mask_extend = int(remove_mask_extend)
|
401 |
-
except:
|
402 |
-
remove_mask_extend = 10
|
403 |
-
mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
|
404 |
-
# box_convert(torch.tensor(boxes_filt_ori_array[i]), in_fmt="cxcywh", out_fmt="xyxy").numpy(),
|
405 |
-
xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), size[0], size[1]),
|
406 |
-
extend_pixels=remove_mask_extend, useRectangle=useRectangle)
|
407 |
-
mask_imgs.append(mask_pil_exp)
|
408 |
-
mask_pil = mix_masks(mask_imgs)
|
409 |
output_images.append(mask_pil.convert("RGB"))
|
410 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
|
411 |
|
@@ -495,7 +503,7 @@ if __name__ == "__main__":
|
|
495 |
with gr.Column(scale=1):
|
496 |
remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
|
497 |
with gr.Column(scale=1):
|
498 |
-
remove_mask_extend = gr.Textbox(label="
|
499 |
|
500 |
with gr.Column():
|
501 |
gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
|
|
|
250 |
box = box.numpy()
|
251 |
return box
|
252 |
|
253 |
+
def to_extend_mask(segment_mask, boxes_filt, size, remove_mask_extend, remove_mode):
|
254 |
+
# remove from mask
|
255 |
+
mask_imgs = []
|
256 |
+
masks_shape = segment_mask.shape
|
257 |
+
boxes_filt_ori_array = boxes_filt.numpy()
|
258 |
+
if inpaint_mode == 'merge':
|
259 |
+
extend_shape_0 = masks_shape[0]
|
260 |
+
extend_shape_1 = masks_shape[1]
|
261 |
+
else:
|
262 |
+
extend_shape_0 = 1
|
263 |
+
extend_shape_1 = 1
|
264 |
+
for i in range(extend_shape_0):
|
265 |
+
for j in range(extend_shape_1):
|
266 |
+
mask = segment_mask[i][j].cpu().numpy()
|
267 |
+
mask_pil = Image.fromarray(mask)
|
268 |
+
|
269 |
+
if remove_mode == 'segment':
|
270 |
+
useRectangle = False
|
271 |
+
else:
|
272 |
+
useRectangle = True
|
273 |
+
|
274 |
+
try:
|
275 |
+
remove_mask_extend = int(remove_mask_extend)
|
276 |
+
except:
|
277 |
+
remove_mask_extend = 10
|
278 |
+
mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
|
279 |
+
xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), size[0], size[1]),
|
280 |
+
extend_pixels=remove_mask_extend, useRectangle=useRectangle)
|
281 |
+
mask_imgs.append(mask_pil_exp)
|
282 |
+
mask_pil = mix_masks(mask_imgs)
|
283 |
+
return mask_pil
|
284 |
+
|
285 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
286 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
|
287 |
|
|
|
404 |
if task_type == 'inpainting':
|
405 |
# inpainting pipeline
|
406 |
image_source_for_inpaint = image_pil.resize((512, 512))
|
407 |
+
if remove_mask_extend:
|
408 |
+
mask_pil = to_extend_mask(masks_ori, boxes_filt_ori, size, remove_mask_extend, remove_mode)
|
409 |
+
output_images.append(mask_pil.convert("RGB"))
|
410 |
image_mask_for_inpaint = mask_pil.resize((512, 512))
|
411 |
image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
412 |
else:
|
413 |
# remove from mask
|
414 |
if mask_source_radio == mask_source_segment:
|
415 |
+
if remove_mask_extend:
|
416 |
+
mask_pil = to_extend_mask(masks_ori, boxes_filt_ori, size, remove_mask_extend, remove_mode)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
output_images.append(mask_pil.convert("RGB"))
|
418 |
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
|
419 |
|
|
|
503 |
with gr.Column(scale=1):
|
504 |
remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
|
505 |
with gr.Column(scale=1):
|
506 |
+
remove_mask_extend = gr.Textbox(label="Enlarge Mask (Empty: no mask extension, default: 10)")
|
507 |
|
508 |
with gr.Column():
|
509 |
gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
|