import collections
import os
import sys
import tempfile
import urllib.request
from subprocess import call

from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import requests
import tensorflow as tf
import gradio as gr
# Download two example street scenes for the demo.
url1 = 'https://cdn.pixabay.com/photo/2014/09/07/21/52/city-438393_1280.jpg'
r = requests.get(url1, allow_redirects=True)
with open("city1.jpg", 'wb') as f:
  f.write(r.content)
url2 = 'https://cdn.pixabay.com/photo/2016/02/19/11/36/canal-1209808_1280.jpg'
r = requests.get(url2, allow_redirects=True)
with open("city2.jpg", 'wb') as f:
  f.write(r.content)
DatasetInfo = collections.namedtuple(
    'DatasetInfo',
    'num_classes, label_divisor, thing_list, colormap, class_names')
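
# Notes on the DatasetInfo fields, as used by the decoding logic below:
#   - label_divisor separates the semantic class id from the per-instance id
#     packed into a single panoptic label (see color_panoptic_map).
#   - thing_list enumerates the countable "thing" classes that receive
#     per-instance colors; all other classes are treated as "stuff".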
def _cityscapes_label_colormap():
  """Creates a label colormap used in the CITYSCAPES segmentation benchmark.

  See more about the CITYSCAPES dataset at https://www.cityscapes-dataset.com/
  M. Cordts, et al. "The Cityscapes Dataset for Semantic Urban Scene
  Understanding." CVPR. 2016.

  Returns:
    A 2-D numpy array with each row being a mapped RGB color (in uint8 range).
  """
  colormap = np.zeros((256, 3), dtype=np.uint8)
  colormap[0] = [128, 64, 128]   # road
  colormap[1] = [244, 35, 232]   # sidewalk
  colormap[2] = [70, 70, 70]     # building
  colormap[3] = [102, 102, 156]  # wall
  colormap[4] = [190, 153, 153]  # fence
  colormap[5] = [153, 153, 153]  # pole
  colormap[6] = [250, 170, 30]   # traffic light
  colormap[7] = [220, 220, 0]    # traffic sign
  colormap[8] = [107, 142, 35]   # vegetation
  colormap[9] = [152, 251, 152]  # terrain
  colormap[10] = [70, 130, 180]  # sky
  colormap[11] = [220, 20, 60]   # person
  colormap[12] = [255, 0, 0]     # rider
  colormap[13] = [0, 0, 142]     # car
  colormap[14] = [0, 0, 70]      # truck
  colormap[15] = [0, 60, 100]    # bus
  colormap[16] = [0, 80, 100]    # train
  colormap[17] = [0, 0, 230]     # motorcycle
  colormap[18] = [119, 11, 32]   # bicycle
  return colormap
def _cityscapes_class_names():
  return ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
          'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
          'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
          'bicycle')
def cityscapes_dataset_information():
  return DatasetInfo(
      num_classes=19,
      label_divisor=1000,
      thing_list=tuple(range(11, 19)),
      colormap=_cityscapes_label_colormap(),
      class_names=_cityscapes_class_names())
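
# A worked example of the panoptic label encoding implied by
# label_divisor=1000 (the values below are illustrative, not model output):
#   panoptic_id = semantic_id * label_divisor + instance_id
#   13002 // 1000 == 13  -> semantic class 13 ('car', a "thing" class)
#   13002 %  1000 == 2   -> instance 2 of that class
# Stuff classes such as 'road' (0) carry instance_id 0, e.g. panoptic_id 0.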
def perturb_color(color, noise, used_colors, max_trials=50, random_state=None):
  """Perturbs the color with some noise.

  If `used_colors` is not None, we will return a color that has
  not appeared before in it.

  Args:
    color: A numpy array with three elements [R, G, B].
    noise: Integer, specifying the amount of perturbing noise (in uint8 range).
    used_colors: A set, used to keep track of used colors.
    max_trials: An integer, maximum trials to generate a random color.
    random_state: An optional np.random.RandomState. If passed, will be used to
      generate random numbers.

  Returns:
    A perturbed color that has not appeared in used_colors.
  """
  if random_state is None:
    random_state = np.random
  for _ in range(max_trials):
    random_color = color + random_state.randint(
        low=-noise, high=noise + 1, size=3)
    random_color = np.clip(random_color, 0, 255)
    if tuple(random_color) not in used_colors:
      used_colors.add(tuple(random_color))
      return random_color
  print('Max trial reached and duplicate color will be used. Please consider '
        'increasing `noise` in `perturb_color()`.')
  return random_color
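
# Illustrative use of perturb_color (a sketch; exact values depend on the
# RandomState): each call jitters the base 'person' color while recording the
# result, so repeated instances of the same class stay visually distinct.
#   used = set()
#   rs = np.random.RandomState(0)
#   c1 = perturb_color(np.array([220, 20, 60]), 60, used, random_state=rs)
#   c2 = perturb_color(np.array([220, 20, 60]), 60, used, random_state=rs)
#   # c1 != c2; both are within +/-60 of [220, 20, 60], clipped to uint8.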
def color_panoptic_map(panoptic_prediction, dataset_info, perturb_noise):
  """Helper method to colorize an output panoptic map.

  Args:
    panoptic_prediction: A 2D numpy array, panoptic prediction from the
      deeplab model.
    dataset_info: A DatasetInfo object, dataset associated with the model.
    perturb_noise: Integer, the amount of noise (in uint8 range) added to each
      instance of the same semantic class.

  Returns:
    colored_panoptic_map: A 3D numpy array with last dimension of 3, colored
      panoptic prediction map.
    used_colors: A dictionary mapping semantic_ids to a set of colors used
      in `colored_panoptic_map`.
  """
  if panoptic_prediction.ndim != 2:
    raise ValueError('Expect 2-D panoptic prediction. Got {}'.format(
        panoptic_prediction.shape))
  # Decode the panoptic label: semantic class id and per-instance id are
  # packed together as semantic_id * label_divisor + instance_id.
  semantic_map = panoptic_prediction // dataset_info.label_divisor
  instance_map = panoptic_prediction % dataset_info.label_divisor
  height, width = panoptic_prediction.shape
  colored_panoptic_map = np.zeros((height, width, 3), dtype=np.uint8)
  used_colors = collections.defaultdict(set)
  # Use a fixed seed to reproduce the same visualization.
  random_state = np.random.RandomState(0)
  unique_semantic_ids = np.unique(semantic_map)
  for semantic_id in unique_semantic_ids:
    semantic_mask = semantic_map == semantic_id
    if semantic_id in dataset_info.thing_list:
      # For `thing` classes, we add a small amount of random noise to the
      # corresponding predefined semantic colormap entry, so each instance
      # gets its own shade.
      unique_instance_ids = np.unique(instance_map[semantic_mask])
      for instance_id in unique_instance_ids:
        instance_mask = np.logical_and(semantic_mask,
                                       instance_map == instance_id)
        random_color = perturb_color(
            dataset_info.colormap[semantic_id],
            perturb_noise,
            used_colors[semantic_id],
            random_state=random_state)
        colored_panoptic_map[instance_mask] = random_color
    else:
      # For `stuff` classes, we use the predefined semantic color.
      colored_panoptic_map[semantic_mask] = dataset_info.colormap[semantic_id]
      used_colors[semantic_id].add(tuple(dataset_info.colormap[semantic_id]))
  return colored_panoptic_map, used_colors
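
# A minimal standalone sketch (hypothetical input; a real map comes from the
# model loaded below): colorize a tiny 1x2 panoptic map with one 'road' pixel
# and one 'car' instance.
#   toy = np.array([[0, 13001]])
#   colored, used = color_panoptic_map(toy, cityscapes_dataset_information(), 60)
#   # colored[0, 0] == [128, 64, 128] (road); colored[0, 1] is a jittered car blue.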
def vis_segmentation(image,
                     panoptic_prediction,
                     dataset_info,
                     perturb_noise=60):
  """Visualizes the input image, panoptic map, overlay view, and a legend."""
  plt.figure(figsize=(30, 20))
  grid_spec = gridspec.GridSpec(2, 2)

  ax = plt.subplot(grid_spec[0])
  plt.imshow(image)
  plt.axis('off')
  ax.set_title('input image', fontsize=20)

  ax = plt.subplot(grid_spec[1])
  panoptic_map, used_colors = color_panoptic_map(panoptic_prediction,
                                                 dataset_info, perturb_noise)
  plt.imshow(panoptic_map)
  plt.axis('off')
  ax.set_title('panoptic map', fontsize=20)

  ax = plt.subplot(grid_spec[2])
  plt.imshow(image)
  plt.imshow(panoptic_map, alpha=0.7)
  plt.axis('off')
  ax.set_title('panoptic overlay', fontsize=20)

  ax = plt.subplot(grid_spec[3])
  max_num_instances = max(len(colors) for colors in used_colors.values())
  # RGBA image as legend: one row per semantic class, one column per instance
  # color actually used.
  legend = np.zeros((len(used_colors), max_num_instances, 4), dtype=np.uint8)
  class_names = []
  for i, semantic_id in enumerate(sorted(used_colors)):
    legend[i, :len(used_colors[semantic_id]), :3] = np.array(
        list(used_colors[semantic_id]))
    legend[i, :len(used_colors[semantic_id]), 3] = 255
    if semantic_id < dataset_info.num_classes:
      class_names.append(dataset_info.class_names[semantic_id])
    else:
      class_names.append('ignore')
  plt.imshow(legend, interpolation='nearest')
  ax.yaxis.tick_left()
  plt.yticks(range(len(legend)), class_names, fontsize=15)
  plt.xticks([], [])
  ax.tick_params(width=0.0, grid_linewidth=0.0)
  plt.grid(False)
  return plt
def run_cmd(command):
  try:
    print(command)
    call(command, shell=True)
  except KeyboardInterrupt:
    print("Process interrupted")
    sys.exit(1)
MODEL_NAME = 'max_deeplab_l_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model'

_MODELS = ('resnet50_os32_panoptic_deeplab_cityscapes_crowd_trainfine_saved_model',
           'resnet50_beta_os32_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'wide_resnet41_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_1_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_3_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'swidernet_sac_1_1_4.5_os16_panoptic_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_1_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_3_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'axial_swidernet_1_1_4.5_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'max_deeplab_s_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model',
           'max_deeplab_l_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model')
_DOWNLOAD_URL_PATTERN = 'https://storage.googleapis.com/gresearch/tf-deeplab/saved_model/%s.tar.gz'
_MODEL_NAME_TO_URL_AND_DATASET = {
    model: (_DOWNLOAD_URL_PATTERN % model, cityscapes_dataset_information())
    for model in _MODELS
}
MODEL_URL, DATASET_INFO = _MODEL_NAME_TO_URL_AND_DATASET[MODEL_NAME]
# Download and extract the selected SavedModel archive, then load it.
model_dir = tempfile.mkdtemp()
download_path = os.path.join(model_dir, MODEL_NAME + '.tar.gz')
urllib.request.urlretrieve(MODEL_URL, download_path)
run_cmd("tar -xzvf " + download_path + " -C " + model_dir)
LOADED_MODEL = tf.saved_model.load(os.path.join(model_dir, MODEL_NAME))
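
# As used below, the exported DeepLab2 SavedModel is called directly on a
# uint8 H x W x 3 tensor and returns a dict of outputs; this demo only reads
# the 'panoptic_pred' entry (batch dimension first, hence the [0]).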
def inference(image):
  # Resize the input image to a fixed 512x512 before running the model.
  image = image.resize(size=(512, 512))
  im = np.array(image)
  output = LOADED_MODEL(tf.cast(im, tf.uint8))
  return vis_segmentation(im, output['panoptic_pred'][0], DATASET_INFO)
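
# Illustrative direct call (outside Gradio), using one of the example images
# downloaded above:
#   fig = inference(Image.open('city1.jpg'))
#   fig.savefig('city1_panoptic.png')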
title = "DeepLab2 - MaX-DeepLab-L"
description = ("Demo for DeepLab2 panoptic segmentation. To use it, simply "
               "upload an image, or click one of the examples to load it. "
               "Read more at the links below.\n"
               "Model: max_deeplab_l_backbone_os16_axial_deeplab_cityscapes_trainfine_saved_model")
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2106.09748'>DeepLab2: A TensorFlow Library for Deep Labeling</a> | <a href='https://github.com/google-research/deeplab2'>Github Repo</a></p>"
gr.Interface(
    inference,
    [gr.inputs.Image(type="pil", label="Input")],
    gr.outputs.Image(type="plot", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["city1.jpg"],
        ["city2.jpg"]
    ]).launch()