import random

import numpy as np
import torch
from PIL import Image

from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from transformers import MaskFormerImageProcessor, Mask2FormerForUniversalSegmentation

from color_palette import ade_palette


def load_model_and_processor(model_ckpt: str):
    # Load a Mask2Former checkpoint and its image processor, on GPU if available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
    model.eval()
    image_processor = MaskFormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_processor


def load_default_ckpt(segmentation_task: str):
    # Pick a default checkpoint per task; anything other than "semantic" or
    # "instance" falls back to the panoptic checkpoint.
    if segmentation_task == "semantic":
        default_pretrained_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
    elif segmentation_task == "instance":
        default_pretrained_ckpt = "facebook/mask2former-swin-small-coco-instance"
    else:
        default_pretrained_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
    return default_pretrained_ckpt


def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    # detectron2's visualizer expects detectron2-style segment dicts: rename
    # "label_id" to "category_id" and flag "thing" classes (countable objects)
    # versus "stuff" classes (amorphous regions).
    for res in seg_info:
        res["category_id"] = res.pop("label_id")
        pred_class = res["category_id"]
        isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
        res["isthing"] = bool(isthing)

    # Visualizer expects an RGB array; PIL images are already RGB, so no
    # channel reordering is needed.
    visualizer = Visualizer(np.array(image), metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    return output_img


def draw_semantic_segmentation(segmentation_map, image, palette):
    # Build an RGB color map. post_process_semantic_segmentation returns
    # 0-indexed class ids, so each id indexes the palette directly.
    color_segmentation_map = np.zeros(
        (segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8
    )  # shape: (height, width, 3)
    for label, color in enumerate(palette):
        color_segmentation_map[segmentation_map == label, :] = color

    # Blend the color map with the original image (both RGB) at 50% opacity.
    img = np.array(image) * 0.5 + color_segmentation_map * 0.5
    img = img.astype(np.uint8)
    return img


def visualize_instance_seg_mask(mask, input_image):
    # Assign a random RGB color to each instance id; pixels the post-processor
    # left unassigned (id -1) stay black.
    color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    labels = np.unique(mask)
    label2color = {
        int(label): (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for label in labels
        if label != -1
    }
    for label, color in label2color.items():
        color_segmentation_map[mask == label, :] = color

    # Blend the instance color map with the original image at 50% opacity.
    img = np.array(input_image) * 0.5 + color_segmentation_map * 0.5
    img = img.astype(np.uint8)
    return img


def predict_masks(input_img_path: str, segmentation_task: str):
    # Load model and image processor for the requested task.
    default_pretrained_ckpt = load_default_ckpt(segmentation_task)
    model, image_processor = load_model_and_processor(default_pretrained_ckpt)

    # Pass the input image through the image processor; convert to RGB in case
    # the file is grayscale or RGBA.
    image = Image.open(input_img_path).convert("RGB")
    inputs = image_processor(images=image, return_tensors="pt")
    # Move the input tensors to the same device as the model (e.g. GPU).
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Pass inputs to the model for prediction, without tracking gradients.
    with torch.no_grad():
        outputs = model(**inputs)

    # Pass outputs to the processor for postprocessing. PIL's image.size is
    # (width, height), so reverse it to get the (height, width) target size.
    target_size = [image.size[::-1]]
    if segmentation_task == "semantic":
        result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=target_size)[0]
        predicted_segmentation_map = result.cpu().numpy()
        palette = ade_palette()
        output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
    elif segmentation_task == "instance":
        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=target_size)[0]
        predicted_instance_map = result["segmentation"].cpu().numpy()
        output_result = visualize_instance_seg_mask(predicted_instance_map, image)
    else:
        result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=target_size)[0]
        predicted_segmentation_map = result["segmentation"]
        seg_info = result["segments_info"]
        output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)

    return output_result
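

# Illustrative usage sketch: "example.jpg" and the output filename below are
# placeholders, not part of the module. draw_semantic_segmentation and
# visualize_instance_seg_mask return NumPy arrays, while
# draw_panoptic_segmentation returns a PIL Image, so normalize before saving.
if __name__ == "__main__":
    result = predict_masks("example.jpg", "panoptic")
    if not isinstance(result, Image.Image):
        result = Image.fromarray(result)
    result.save("segmentation_output.png")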