Spaces:
Runtime error
Runtime error
""" | |
Using as reference: | |
- https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512 | |
- https://huggingface.co/spaces/chansung/segformer-tf-transformers/blob/main/app.py | |
- https://huggingface.co/facebook/detr-resnet-50-panoptic | |
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ | |
https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_panoptic_segmentation_minimal_example_(with_DetrFeatureExtractor).ipynb | |
https://arxiv.org/abs/2005.12872 | |
https://arxiv.org/pdf/1801.00868.pdf | |
Additions | |
- add shown labels as strings | |
- show only animal masks (ask an nlp model?) | |
For next time | |
- for different 'confidence' thresholds, the high-confidence masks should change....
- colors are not great and should be constant per class? add text?
- I'm getting a core dump (segmentation fault) when loading the Hugging Face model :(
https://github.com/huggingface/transformers/issues/16939 | |
- cap slider to 95? | |
- switch between panoptic and semantic? | |
""" | |
from transformers import DetrFeatureExtractor, DetrForSegmentation | |
from PIL import Image | |
import gradio as gr | |
import numpy as np | |
import torch | |
import torchvision | |
import itertools | |
import seaborn as sns | |
def predict_animal_mask(im,
                        gr_slider_confidence):
    """Overlay the confident DETR panoptic masks on a downsized copy of *im*.

    The image is resized to 200x200 and run through the module-level
    ``feature_extractor`` and ``model``. Every query whose best class
    probability (ignoring the trailing "no-object" class) exceeds the
    threshold keeps its mask. Each pixel is then painted with the colour of
    the kept mask that scores highest at that pixel, and the colour layer is
    blended with the image (25% image, 75% colour).

    NOTE: despite the name, no animal-specific filtering happens here; every
    mask above the threshold is drawn.

    Args:
        im: Image as a numpy array (presumably height x width x 3, as handed
            over by the Gradio image input -- confirm against the caller).
        gr_slider_confidence: Confidence threshold in percent (0-100).

    Returns:
        A ``uint8`` numpy array holding the 200x200 RGB overlay. If no query
        passes the threshold, the resized image is returned unchanged.
    """
    image = Image.fromarray(im)
    image = image.resize((200, 200))  # small input keeps inference fast
    width, height = image.size  # PIL reports (width, height), numpy wants (height, width)
    rgb = np.array(image.convert('RGB'))

    # encoding is a dict with pixel_values and pixel_mask
    encoding = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = model(**encoding)

    # Best class probability per query, excluding the "no-object" class (last).
    prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0]
    keep = prob_per_query > gr_slider_confidence / 100.0

    # Boolean indexing over (batch, query) leaves (num_kept, h, w). Do NOT
    # squeeze: with a single kept query that would drop the query axis and
    # the argmax below would run over image rows instead of over masks.
    kept_masks = outputs.pred_masks[keep]
    if kept_masks.shape[0] == 0:
        # Nothing is confident enough; argmax over an empty axis would raise.
        return rgb.astype(np.uint8)

    if tuple(kept_masks.shape[-2:]) != (height, width):
        # Mask resolution depends on the extractor/model configuration;
        # resample so the label map lines up with the image pixels.
        kept_masks = torch.nn.functional.interpolate(
            kept_masks.unsqueeze(0),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

    # For every pixel, the index (among the kept queries) of the strongest mask.
    label_per_pixel = kept_masks.argmax(dim=0).cpu().numpy()

    color_mask = np.zeros((height, width, 3))
    palette = itertools.cycle(sns.color_palette())  # cycle: may be more masks than colours
    for lbl in np.unique(label_per_pixel):
        color_mask[label_per_pixel == lbl, :] = np.asarray(next(palette)) * 255

    # Show image + mask
    pred_img = rgb * 0.25 + color_mask * 0.75
    return pred_img.astype(np.uint8)
#######################################
# Load the pretrained DETR panoptic-segmentation checkpoint from the
# Hugging Face Hub: the feature extractor prepares the model inputs and the
# model produces per-query class logits and masks.
feature_extractor = DetrFeatureExtractor.from_pretrained(
    'facebook/detr-resnet-50-panoptic')
model = DetrForSegmentation.from_pretrained(
    'facebook/detr-resnet-50-panoptic')

# Gradio input components: the image to segment and a confidence slider
# running from 0 to 100 in steps of 5, starting at 85.
gr_image_input = gr.inputs.Image()
gr_slider_confidence = gr.inputs.Slider(
    minimum=0,
    maximum=100,
    step=5,
    default=85,
    label='Set confidence threshold for masks',
)

# Gradio output component: the image with the coloured masks blended in.
gr_image_output = gr.outputs.Image()

####################################################
# Wire the prediction function to the components and start the web app.
gr.Interface(
    fn=predict_animal_mask,
    inputs=[gr_image_input, gr_slider_confidence],
    outputs=gr_image_output,
    title='Image segmentation with varying confidence',
    description="A panoptic (semantic+instance) segmentation webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone",
).launch()
#################################### | |
# url = "http://images.cocodataset.org/val2017/000000039769.jpg" | |
# image = Image.open(requests.get(url, stream=True).raw) | |
# inputs = feature_extractor(images=image, return_tensors="pt") | |
# outputs = model(**inputs) | |
# logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4) | |