SeyedAli committed on
Commit 831852e
1 Parent(s): f953b2f

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -25,7 +25,6 @@
  *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
  *.wasm filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
- title: Image Segmentation
- emoji: 🏢
- colorFrom: pink
- colorTo: gray
+ title: Mask2former Demo
+ emoji: 🚀
+ colorFrom: green
+ colorTo: blue
  sdk: gradio
- sdk_version: 3.45.2
+ sdk_version: 3.16.2
  app_file: app.py
  pinned: false
- license: mit
+ license: apache-2.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,47 @@
+ import gradio as gr
+ from predict import predict_masks
+ import glob
+
+ ##Create list of examples to be loaded
+ example_list = glob.glob("examples/*")
+ example_list = list(map(lambda el:[el], example_list))
+
+ demo = gr.Blocks()
+
+ with demo:
+
+     gr.Markdown("# **<p align='center'>Mask2Former: Masked Attention Mask Transformer for Universal Segmentation</p>**")
+     gr.Markdown("This space demonstrates the use of Mask2Former. It was introduced in the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) and first released in [this repository](https://github.com/facebookresearch/Mask2Former/). \
+     Before Mask2Former, you'd have to resort to using a specialized architecture designed for solving a particular kind of image segmentation task (i.e. semantic, instance or panoptic segmentation). On the other hand, in the form of Mask2Former, for the first time, we have a single architecture that is capable of solving any segmentation task and performs on par or better than specialized architectures.")
+
+     with gr.Box():
+
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("**Inputs**")
+                 segmentation_task = gr.Dropdown(["semantic", "instance", "panoptic"], value="panoptic", label="Segmentation Task", show_label=True)
+                 input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
+
+             with gr.Column():
+                 gr.Markdown("**Outputs**")
+                 output_heading = gr.Textbox(label="Output Type", show_label=True)
+                 output_mask = gr.Image(label="Predicted Masks", show_label=True)
+
+         gr.Markdown("**Predict**")
+
+         with gr.Box():
+             with gr.Row():
+                 submit_button = gr.Button("Submit")
+
+     gr.Markdown("**Examples:**")
+
+     with gr.Column():
+         gr.Examples(example_list, [input_image, segmentation_task], [output_mask,output_heading], predict_masks)
+
+
+     submit_button.click(predict_masks, inputs=[input_image, segmentation_task], outputs=[output_mask,output_heading])
+
+     gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/shivalika-singh/\">Shivalika Singh</a>')
+
+ demo.launch(debug=True)
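
For reference, a minimal sketch (assuming the files added in this commit, i.e. predict.py and the examples/ folder) of calling predict_masks directly, without the Gradio UI:

```python
# Minimal sketch: call the predictor from predict.py on one of the bundled examples.
from predict import predict_masks

mask_image, heading = predict_masks("examples/cat-dog.jpg", "panoptic")
print(heading)                        # "Panoptic Segmentation Output"
mask_image.save("panoptic_mask.png")  # the panoptic branch returns a PIL Image
```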
color_palette.py ADDED
@@ -0,0 +1,89 @@
+ # color palettes for COCO, cityscapes and ADE datasets
+
+ def coco_panoptic_palette():
+     return [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
+             (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
+             (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
+             (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
+             (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
+             (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
+             (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
+             (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
+             (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
+             (134, 134, 103), (145, 148, 174), (255, 208, 186),
+             (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
+             (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
+             (166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149),
+             (179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205),
+             (147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0),
+             (119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88),
+             (95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118),
+             (219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15),
+             (127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0),
+             (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
+             (191, 162, 208), (255, 255, 128), (147, 211, 203),
+             (150, 100, 100), (168, 171, 172), (146, 112, 198),
+             (210, 170, 100), (92, 136, 89), (218, 88, 184), (241, 129, 0),
+             (217, 17, 255), (124, 74, 181), (70, 70, 70), (255, 228, 255),
+             (154, 208, 0), (193, 0, 92), (76, 91, 113), (255, 180, 195),
+             (106, 154, 176),
+             (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
+             (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
+             (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
+             (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
+             (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
+             (146, 139, 141),
+             (70, 130, 180), (134, 199, 156), (209, 226, 140), (96, 36, 108),
+             (96, 96, 96), (64, 170, 64), (152, 251, 152), (208, 229, 228),
+             (206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143),
+             (102, 102, 156), (250, 141, 255)]
+
+ def cityscapes_palette():
+     return [[128, 64, 128],[244, 35, 232],[70, 70, 70],[102, 102, 156],[190, 153, 153],
+             [153, 153, 153],[250, 170, 30],[220, 220, 0],[107, 142, 35],[152, 251, 152],
+             [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
+             [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]
+
+ def ade_palette():
+     """Color palette that maps each class to RGB values.
+
+     This one is actually taken from ADE20k.
+     """
+     return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+             [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+             [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+             [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+             [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+             [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+             [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+             [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+             [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+             [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+             [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+             [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+             [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+             [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+             [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+             [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+             [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+             [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+             [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+             [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+             [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+             [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+             [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+             [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+             [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+             [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+             [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+             [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+             [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+             [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+             [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+             [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+             [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+             [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+             [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+             [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+             [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+             [102, 255, 0], [92, 0, 255]]
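
Each palette above is simply a list of RGB triplets indexed by class id, so colorizing a predicted label map reduces to an array lookup; a minimal sketch with a hypothetical 2x2 label map:

```python
import numpy as np
from color_palette import ade_palette

# Hypothetical 2x2 map of ADE20K class ids; real maps come from the model output.
label_map = np.array([[0, 1],
                      [2, 3]])

palette = np.array(ade_palette(), dtype=np.uint8)  # shape (num_classes, 3)
color_map = palette[label_map]                     # shape (2, 2, 3), one RGB value per pixel
```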
examples/armchair.jpg ADDED
examples/cat-dog.jpg ADDED
examples/person-bike.jpg ADDED
predict.py ADDED
@@ -0,0 +1,110 @@
+ import torch
+ import random
+ import numpy as np
+ from PIL import Image
+ from collections import defaultdict
+ import os
+ # Mentioning detectron2 as a dependency directly in requirements.txt tries to install detectron2 before torch and results in an error even if torch is listed as a dependency before detectron2.
+ # Hence, installing detectron2 this way when using Gradio HF spaces.
+ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
+
+ from detectron2.data import MetadataCatalog
+ from detectron2.utils.visualizer import ColorMode, Visualizer
+ from color_palette import ade_palette
+ from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
+
+ def load_model_and_processor(model_ckpt: str):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
+     model.eval()
+     image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
+     return model, image_preprocessor
+
+ def load_default_ckpt(segmentation_task: str):
+     if segmentation_task == "semantic":
+         default_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
+     elif segmentation_task == "instance":
+         default_ckpt = "facebook/mask2former-swin-small-coco-instance"
+     else:
+         default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
+     return default_ckpt
+
+ def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
+     metadata = MetadataCatalog.get("coco_2017_val_panoptic")
+     for res in seg_info:
+         res['category_id'] = res.pop('label_id')
+         pred_class = res['category_id']
+         isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
+         res['isthing'] = bool(isthing)
+
+     visualizer = Visualizer(np.array(image)[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
+     out = visualizer.draw_panoptic_seg_predictions(
+         predicted_segmentation_map.cpu(), seg_info, alpha=0.5
+     )
+     output_img = Image.fromarray(out.get_image())
+     return output_img
+
+ def draw_semantic_segmentation(segmentation_map, image, palette):
+
+     color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
+     for label, color in enumerate(palette):
+         color_segmentation_map[segmentation_map - 1 == label, :] = color
+     # Convert to BGR
+     ground_truth_color_seg = color_segmentation_map[..., ::-1]
+
+     img = np.array(image) * 0.5 + ground_truth_color_seg * 0.5
+     img = img.astype(np.uint8)
+     return img
+
+ def visualize_instance_seg_mask(mask, input_image):
+     color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
+
+     labels = np.unique(mask)
+     label2color = {int(label): (random.randint(0, 1), random.randint(0, 255), random.randint(0, 255)) for label in labels}
+
+     for label, color in label2color.items():
+         color_segmentation_map[mask - 1 == label, :] = color
+
+     ground_truth_color_seg = color_segmentation_map[..., ::-1]
+
+     img = np.array(input_image) * 0.5 + ground_truth_color_seg * 0.5
+     img = img.astype(np.uint8)
+     return img
+
+ def predict_masks(input_img_path: str, segmentation_task: str):
+
+     #load model and image processor
+     default_ckpt = load_default_ckpt(segmentation_task)
+     model, image_processor = load_model_and_processor(default_ckpt)
+
+     ## pass input image through image processor
+     image = Image.open(input_img_path)
+     inputs = image_processor(images=image, return_tensors="pt")
+
+     ## pass inputs to model for prediction
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # pass outputs to processor for postprocessing
+     if segmentation_task == "semantic":
+         result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+         predicted_segmentation_map = result.cpu().numpy()
+         palette = ade_palette()
+         output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
+         output_heading = "Semantic Segmentation Output"
+
+     elif segmentation_task == "instance":
+         result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+         predicted_instance_map = result["segmentation"].cpu().detach().numpy()
+         output_result = visualize_instance_seg_mask(predicted_instance_map, image)
+         output_heading = "Instance Segmentation Output"
+
+     else:
+         result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+         predicted_segmentation_map = result["segmentation"]
+         seg_info = result['segments_info']
+         output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
+         output_heading = "Panoptic Segmentation Output"
+
+
+     return output_result, output_heading
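
For reference, the semantic branch above can be reproduced with transformers alone (no Gradio or detectron2); a minimal sketch using the same default checkpoint returned by load_default_ckpt:

```python
import torch
from PIL import Image
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation

ckpt = "facebook/mask2former-swin-tiny-ade-semantic"   # default semantic checkpoint above
processor = Mask2FormerImageProcessor.from_pretrained(ckpt)
model = Mask2FormerForUniversalSegmentation.from_pretrained(ckpt).eval()

image = Image.open("examples/armchair.jpg")            # any RGB image path works here
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# (H, W) tensor of class ids, resized back to the input resolution
seg_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
```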
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ gradio
+ opencv-python
+ git+https://github.com/huggingface/transformers.git
+ pillow
+ scipy
+ torchvision