shivi committed on
Commit
aeaceee
1 Parent(s): c027c15

Add instance seg visualization

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. predict.py +17 -17
app.py CHANGED
@@ -10,14 +10,14 @@ demo = gr.Blocks()
10
 
11
  with demo:
12
 
13
- gr.Markdown("# **<p align='center'>Mask2Former: Masked Attention Transformer for Universal Segmentation</p>**")
14
  gr.Markdown("This space demonstrates the use of Mask2Former. It was introduced in the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) and first released in [this repository](https://github.com/facebookresearch/Mask2Former/). \
15
  Before Mask2Former, you'd have to resort to using a specialized architecture designed for solving a particular kind of image segmentation task (i.e. semantic, instance or panoptic segmentation). On the other hand, in the form of Mask2Former, for the first time, we have a single architecture that is capable of solving any segmentation task and performs on par or better than specialized architectures.")
16
 
17
  with gr.Box():
18
 
19
  with gr.Row():
20
- segmentation_task = gr.Dropdown(["semantic", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
21
  with gr.Box():
22
  with gr.Row():
23
  input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
 
10
 
11
  with demo:
12
 
13
+ gr.Markdown("# **<p align='center'>Mask2Former: Masked Attention Mask Transformer for Universal Segmentation</p>**")
14
  gr.Markdown("This space demonstrates the use of Mask2Former. It was introduced in the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) and first released in [this repository](https://github.com/facebookresearch/Mask2Former/). \
15
  Before Mask2Former, you'd have to resort to using a specialized architecture designed for solving a particular kind of image segmentation task (i.e. semantic, instance or panoptic segmentation). On the other hand, in the form of Mask2Former, for the first time, we have a single architecture that is capable of solving any segmentation task and performs on par or better than specialized architectures.")
16
 
17
  with gr.Box():
18
 
19
  with gr.Row():
20
+ segmentation_task = gr.Dropdown(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
21
  with gr.Box():
22
  with gr.Row():
23
  input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
predict.py CHANGED
@@ -40,6 +40,7 @@ def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
40
  return output_img
41
 
42
  def draw_semantic_segmentation(segmentation_map, image, palette):
 
43
  color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
44
  for label, color in enumerate(palette):
45
  color_segmentation_map[segmentation_map - 1 == label, :] = color
@@ -50,15 +51,20 @@ def draw_semantic_segmentation(segmentation_map, image, palette):
50
  img = img.astype(np.uint8)
51
  return img
52
 
53
- def visualize_instance_seg_mask(mask):
54
- image = np.zeros((mask.shape[0], mask.shape[1], 3))
 
55
  labels = np.unique(mask)
56
- label2color = {label: (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
57
- for i in range(image.shape[0]):
58
- for j in range(image.shape[1]):
59
- image[i, j, :] = label2color[mask[i, j]]
60
- image = image / 255
61
- return image
 
 
 
 
62
 
63
  def predict_masks(input_img_path: str, segmentation_task: str):
64
 
@@ -82,15 +88,9 @@ def predict_masks(input_img_path: str, segmentation_task: str):
82
  output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
83
 
84
  elif segmentation_task == "instance":
85
- pass
86
- # result = image_processor.post_process_segmentation(outputs)[0].cpu().detach()
87
- # predicted_segmentation_map = result["segmentation"]
88
- # # predicted_segmentation_map = torch.argmax(result, dim=0).numpy()
89
- # # results = torch.argmax(predicted_segmentation_map, dim=0).numpy()
90
- # print("predicted_segmentation_map:",predicted_segmentation_map)
91
- # print("type predicted_segmentation_map:", type(predicted_segmentation_map))
92
- # output_result = visualize_instance_seg_mask(predicted_segmentation_map)
93
- # # mask = plot_semantic_map(predicted_segmentation_map, image)
94
 
95
  else:
96
  result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
 
40
  return output_img
41
 
42
  def draw_semantic_segmentation(segmentation_map, image, palette):
43
+
44
  color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
45
  for label, color in enumerate(palette):
46
  color_segmentation_map[segmentation_map - 1 == label, :] = color
 
51
  img = img.astype(np.uint8)
52
  return img
53
 
54
+ def visualize_instance_seg_mask(mask, input_image):
55
+ color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
56
+
57
  labels = np.unique(mask)
58
+ label2color = {int(label): (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
59
+
60
+ for label, color in label2color.items():
61
+ color_segmentation_map[mask - 1 == label, :] = color
62
+
63
+ ground_truth_color_seg = color_segmentation_map[..., ::-1]
64
+
65
+ img = np.array(input_image) * 0.5 + ground_truth_color_seg * 0.5
66
+ img = img.astype(np.uint8)
67
+ return img
68
 
69
  def predict_masks(input_img_path: str, segmentation_task: str):
70
 
 
88
  output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
89
 
90
  elif segmentation_task == "instance":
91
+ result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
92
+ predicted_instance_map = result["segmentation"].cpu().detach().numpy()
93
+ output_result = visualize_instance_seg_mask(predicted_instance_map, image)
 
 
 
 
 
 
94
 
95
  else:
96
  result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]