# mask2former-demo / predict.py
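"""Mask2Former inference helpers for semantic, instance, and panoptic segmentation.

Loads a pretrained Mask2Former checkpoint from the Hugging Face Hub, runs it on an
input image, and returns a blended visualization of the predicted masks.
`color_palette.ade_palette` is a local module in this Space that provides the
ADE20K color palette.
"""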
import random

import numpy as np
import torch
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from PIL import Image
from transformers import MaskFormerImageProcessor, Mask2FormerForUniversalSegmentation

from color_palette import ade_palette

def load_model_and_processor(model_ckpt: str):
    """Load a Mask2Former checkpoint and its image processor, using GPU when available."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
    model.eval()
    image_preprocessor = MaskFormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_preprocessor
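
# Note: `Mask2FormerImageProcessor` (or `AutoImageProcessor`) is the processor class
# that matches these checkpoints; `MaskFormerImageProcessor` shares the same
# preprocessing and also loads them, so it is kept here as in the original code.
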
def load_default_ckpt(segmentation_task: str):
    """Pick a default pretrained checkpoint for the requested segmentation task."""
    if segmentation_task == "semantic":
        default_pretrained_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
    elif segmentation_task == "instance":
        default_pretrained_ckpt = "facebook/mask2former-swin-small-coco-instance"
    else:
        default_pretrained_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
    return default_pretrained_ckpt

def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    """Overlay panoptic segment predictions on the image with Detectron2's Visualizer."""
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    for res in seg_info:
        # Detectron2 expects a `category_id` key and an `isthing` flag per segment.
        res['category_id'] = res.pop('label_id')
        pred_class = res['category_id']
        isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
        res['isthing'] = bool(isthing)

    # Detectron2's Visualizer expects an RGB array, which PIL provides directly.
    visualizer = Visualizer(np.array(image), metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    return output_img

def draw_semantic_segmentation(segmentation_map, image, palette):
    """Color each predicted semantic class with the palette and blend with the image."""
    color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8)  # height, width, 3
    for label, color in enumerate(palette):
        # Post-processed class ids are 0-indexed, matching the palette order.
        color_segmentation_map[segmentation_map == label, :] = color

    # Blend the color map with the input image at 50% opacity.
    img = np.array(image) * 0.5 + color_segmentation_map * 0.5
    img = img.astype(np.uint8)
    return img

def visualize_instance_seg_mask(mask, input_image):
    """Give each instance id a random color and blend the map with the image."""
    color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    labels = np.unique(mask)
    # One random RGB color per instance id (id -1 marks pixels with no instance).
    label2color = {
        int(label): (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for label in labels
    }
    for label, color in label2color.items():
        color_segmentation_map[mask == label, :] = color

    img = np.array(input_image) * 0.5 + color_segmentation_map * 0.5
    img = img.astype(np.uint8)
    return img
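
# Optional: seed Python's RNG (e.g. `random.seed(0)`) before calling
# visualize_instance_seg_mask to get the same instance colors on every run.
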
def predict_masks(input_img_path: str, segmentation_task: str):
    # Load the default model and image processor for the requested task.
    default_pretrained_ckpt = load_default_ckpt(segmentation_task)
    model, image_processor = load_model_and_processor(default_pretrained_ckpt)

    # Pass the input image through the image processor, moving the pixel
    # tensors to the same device as the model.
    image = Image.open(input_img_path).convert("RGB")
    inputs = image_processor(images=image, return_tensors="pt").to(model.device)

    # Run the model without tracking gradients.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process the raw outputs for the requested task and visualize the result.
    # `image.size` is (width, height), so it is reversed to (height, width).
    if segmentation_task == "semantic":
        result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result.cpu().numpy()
        palette = ade_palette()
        output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
    elif segmentation_task == "instance":
        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_instance_map = result["segmentation"].cpu().numpy()
        output_result = visualize_instance_seg_mask(predicted_instance_map, image)
    else:
        result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result["segmentation"]
        seg_info = result['segments_info']
        output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
    return output_result
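
# A minimal usage sketch, not part of the original file (the Space drives
# predict_masks from its demo app); "example.jpg" is a hypothetical path.
# Note that the panoptic task returns a PIL.Image while the semantic and
# instance tasks return a NumPy array.
if __name__ == "__main__":
    panoptic_overlay = predict_masks("example.jpg", "panoptic")
    panoptic_overlay.save("panoptic_overlay.png")

    semantic_overlay = predict_masks("example.jpg", "semantic")
    Image.fromarray(semantic_overlay).save("semantic_overlay.png")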