Add instance seg visualization
app.py +2 -2
predict.py +17 -17
app.py
CHANGED
@@ -10,14 +10,14 @@ demo = gr.Blocks()
 
 with demo:
 
-    gr.Markdown("# **<p align='center'>Mask2Former: Masked Attention Transformer for Universal Segmentation</p>**")
+    gr.Markdown("# **<p align='center'>Mask2Former: Masked Attention Mask Transformer for Universal Segmentation</p>**")
     gr.Markdown("This space demonstrates the use of Mask2Former. It was introduced in the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) and first released in [this repository](https://github.com/facebookresearch/Mask2Former/). \
                 Before Mask2Former, you'd have to resort to using a specialized architecture designed for solving a particular kind of image segmentation task (i.e. semantic, instance or panoptic segmentation). On the other hand, in the form of Mask2Former, for the first time, we have a single architecture that is capable of solving any segmentation task and performs on par or better than specialized architectures.")
 
     with gr.Box():
 
         with gr.Row():
-            segmentation_task = gr.Dropdown(["semantic", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
+            segmentation_task = gr.Dropdown(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
     with gr.Box():
         with gr.Row():
             input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
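For reference, the new "instance" option feeds straight into the matching branch of predict_masks in predict.py below. As a minimal standalone sketch, this is how the three task names map onto the Mask2Former post-processing calls in transformers; the checkpoint name is an assumption for illustration, not taken from this Space:

from PIL import Image
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

checkpoint = "facebook/mask2former-swin-large-coco-panoptic"  # assumed checkpoint
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)

def segment(image: Image.Image, task: str):
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height); the processor expects (height, width)
    target_sizes = [image.size[::-1]]
    if task == "semantic":
        # (height, width) tensor of per-pixel class ids
        return image_processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
    elif task == "instance":
        # dict with "segmentation" (per-pixel instance ids) and "segments_info"
        return image_processor.post_process_instance_segmentation(outputs, target_sizes=target_sizes)[0]
    else:  # "panoptic"
        # dict with "segmentation" (per-pixel segment ids) and "segments_info"
        return image_processor.post_process_panoptic_segmentation(outputs, target_sizes=target_sizes)[0]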
predict.py
CHANGED
@@ -40,6 +40,7 @@ def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
     return output_img
 
 def draw_semantic_segmentation(segmentation_map, image, palette):
+
     color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
     for label, color in enumerate(palette):
         color_segmentation_map[segmentation_map - 1 == label, :] = color
@@ -50,15 +51,20 @@ def draw_semantic_segmentation(segmentation_map, image, palette):
     img = img.astype(np.uint8)
     return img
 
-def visualize_instance_seg_mask(mask):
-
+def visualize_instance_seg_mask(mask, input_image):
+    color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
+
     labels = np.unique(mask)
-    label2color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
-
-
-
-
-
+    label2color = {int(label): (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
+
+    for label, color in label2color.items():
+        color_segmentation_map[mask - 1 == label, :] = color
+
+    ground_truth_color_seg = color_segmentation_map[..., ::-1]
+
+    img = np.array(input_image) * 0.5 + ground_truth_color_seg * 0.5
+    img = img.astype(np.uint8)
+    return img
 
 def predict_masks(input_img_path: str, segmentation_task: str):
 
@@ -82,15 +88,9 @@ def predict_masks(input_img_path: str, segmentation_task: str):
         output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
 
     elif segmentation_task == "instance":
-
-
-
-        # # predicted_segmentation_map = torch.argmax(result, dim=0).numpy()
-        # # results = torch.argmax(predicted_segmentation_map, dim=0).numpy()
-        # print("predicted_segmentation_map:",predicted_segmentation_map)
-        # print("type predicted_segmentation_map:", type(predicted_segmentation_map))
-        # output_result = visualize_instance_seg_mask(predicted_segmentation_map)
-        # # mask = plot_semantic_map(predicted_segmentation_map, image)
+        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+        predicted_instance_map = result["segmentation"].cpu().detach().numpy()
+        output_result = visualize_instance_seg_mask(predicted_instance_map, image)
 
     else:
         result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
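A note on the new helper: visualize_instance_seg_mask draws each instance id in a random color and alpha-blends the color map over the input at 50% opacity. As written, random.randint(0, 1) pins one of the three color channels to 0 or 1, so colors are drawn from a reduced palette, and the mask - 1 == label indexing, carried over from draw_semantic_segmentation, leaves the lowest instance id unpainted. A hedged end-to-end sketch of the instance branch, reusing the helper defined in the diff above; the checkpoint name and image path are assumptions for illustration:

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

checkpoint = "facebook/mask2former-swin-large-coco-instance"  # assumed checkpoint
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint)

image = Image.open("example.jpg").convert("RGB")  # assumed input path
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Same post-processing call as the new elif branch above
result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
predicted_instance_map = result["segmentation"].cpu().detach().numpy()

# result["segments_info"] holds one entry per detected instance (id, label_id, score)
overlay = visualize_instance_seg_mask(predicted_instance_map, image)
Image.fromarray(overlay).save("instance_overlay.png")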