# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains the Motion-DeepLab architecture.""" | |
import functools | |
from typing import Any, Dict, Text, Tuple | |
from absl import logging | |
import tensorflow as tf | |
from deeplab2 import common | |
from deeplab2 import config_pb2 | |
from deeplab2.data import dataset | |
from deeplab2.model import builder | |
from deeplab2.model import utils | |
from deeplab2.model.post_processor import motion_deeplab | |
from deeplab2.model.post_processor import post_processor_builder | |


class MotionDeepLab(tf.keras.Model):
  """This class represents the Motion-DeepLab meta architecture.

  This class is the basis of the Motion-DeepLab architecture. This Model can
  be used for Video Panoptic Segmentation or Segmenting and Tracking Every
  Pixel (STEP).
  """

  def __init__(self,
               config: config_pb2.ExperimentOptions,
               dataset_descriptor: dataset.DatasetDescriptor):
    """Initializes a Motion-DeepLab architecture.

    Args:
      config: A config_pb2.ExperimentOptions configuration.
      dataset_descriptor: A dataset.DatasetDescriptor.
    """
    super(MotionDeepLab, self).__init__(name='MotionDeepLab')

    if config.trainer_options.solver_options.use_sync_batchnorm:
      logging.info('Synchronized Batchnorm is used.')
      bn_layer = functools.partial(
          tf.keras.layers.experimental.SyncBatchNormalization,
          momentum=config.trainer_options.solver_options.batchnorm_momentum,
          epsilon=config.trainer_options.solver_options.batchnorm_epsilon)
    else:
      logging.info('Standard (unsynchronized) Batchnorm is used.')
      bn_layer = functools.partial(
          tf.keras.layers.BatchNormalization,
          momentum=config.trainer_options.solver_options.batchnorm_momentum,
          epsilon=config.trainer_options.solver_options.batchnorm_epsilon)
    self._encoder = builder.create_encoder(
        config.model_options.backbone, bn_layer,
        conv_kernel_weight_decay=(
            config.trainer_options.solver_options.weight_decay))
    self._decoder = builder.create_decoder(config.model_options, bn_layer,
                                           dataset_descriptor.ignore_label)
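
    # Tracking state carried between consecutive frames at inference time:
    # the previously predicted center heatmap, the list of previous instance
    # centers (with their track ids), and the next unused tracking id.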
    self._prev_center_prediction = tf.Variable(
        0.0,
        trainable=False,
        validate_shape=False,
        shape=tf.TensorShape(None),
        dtype=tf.float32,
        name='prev_prediction_buffer')
    self._prev_center_list = tf.Variable(
        tf.zeros((0, 5), dtype=tf.int32),
        trainable=False,
        validate_shape=False,
        shape=tf.TensorShape(None),
        name='prev_prediction_list')
    self._next_tracking_id = tf.Variable(
        1,
        trainable=False,
        validate_shape=False,
        dtype=tf.int32,
        name='next_tracking_id')
    self._post_processor = post_processor_builder.get_post_processor(
        config, dataset_descriptor)
    self._render_fn = functools.partial(
        motion_deeplab.render_panoptic_map_as_heatmap,
        sigma=8,
        label_divisor=dataset_descriptor.panoptic_label_divisor,
        void_label=dataset_descriptor.ignore_label)
    self._track_fn = functools.partial(
        motion_deeplab.assign_instances_to_previous_tracks,
        label_divisor=dataset_descriptor.panoptic_label_divisor)
    # The ASPP pooling size is always set to the train crop size, which is
    # found to be experimentally better.
    pool_size = config.train_dataset_options.crop_size
    output_stride = float(config.model_options.backbone.output_stride)
    pool_size = tuple(
        utils.scale_mutable_sequence(pool_size, 1.0 / output_stride))
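    # Worked example (illustrative, assuming utils.scale_mutable_sequence
    # scales each dimension as (dim - 1) * scale + 1): a train crop size of
    # (321, 321) with output stride 16 yields a pooling size of (21, 21).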
    logging.info('Setting pooling size to %s', pool_size)
    self.set_pool_size(pool_size)

  def call(self, input_tensor: tf.Tensor, training=False) -> Dict[Text, Any]:
    """Performs a forward pass.

    Args:
      input_tensor: An input tensor of type tf.Tensor with shape [batch,
        height, width, channels]. The input tensor should contain batches of
        RGB images.
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      A dictionary containing the results of the specified DeepLab
      architecture. The results are bilinearly upsampled to input size before
      returning.
    """
    if not training:
      # During evaluation, we add the previous predicted heatmap as the 7th
      # input channel (cf. during training, we use the groundtruth heatmap).
      input_tensor = self._add_previous_heatmap_to_input(input_tensor)

    # Normalize the input in the same way as Inception. We normalize it
    # outside the encoder so that we can extend encoders to different
    # backbones without copying the normalization to each encoder. We
    # normalize it after data preprocessing because it is faster on TPUs than
    # on host CPUs. The normalization should not increase TPU memory
    # consumption because it does not require gradient.
    input_tensor = input_tensor / 127.5 - 1.0
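    # This maps pixel values from [0, 255] to [-1.0, 1.0]; for example, 0
    # maps to -1.0, 127.5 to 0.0, and 255 to 1.0.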
    # Get the static spatial shape of the input tensor.
    _, input_h, input_w, _ = input_tensor.get_shape().as_list()
    pred = self._decoder(
        self._encoder(input_tensor, training=training), training=training)

    result_dict = dict()
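    # Resize all predictions back to the input resolution. Offset maps are
    # resized with resize_and_rescale_offsets because their values are
    # measured in pixels and must be rescaled along with the spatial
    # dimensions; all other predictions are plain bilinearly resized.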
    for key, value in pred.items():
      if (key == common.PRED_OFFSET_MAP_KEY or
          key == common.PRED_FRAME_OFFSET_MAP_KEY):
        result_dict[key] = utils.resize_and_rescale_offsets(
            value, [input_h, input_w])
      else:
        result_dict[key] = utils.resize_bilinear(
            value, [input_h, input_w])

    # Change the semantic logits to probabilities with softmax.
    result_dict[common.PRED_SEMANTIC_PROBS_KEY] = tf.nn.softmax(
        result_dict[common.PRED_SEMANTIC_LOGITS_KEY])
    if not training:
      result_dict.update(self._post_processor(result_dict))
      next_heatmap, next_centers = self._render_fn(
          result_dict[common.PRED_PANOPTIC_KEY])
      panoptic_map, next_centers, next_id = self._track_fn(
          self._prev_center_list.value(),
          next_centers,
          next_heatmap,
          result_dict[common.PRED_FRAME_OFFSET_MAP_KEY],
          result_dict[common.PRED_PANOPTIC_KEY],
          self._next_tracking_id.value())
      result_dict[common.PRED_PANOPTIC_KEY] = panoptic_map
      self._next_tracking_id.assign(next_id)
      self._prev_center_prediction.assign(
          tf.expand_dims(next_heatmap, axis=3, name='expand_prev_centermap'))
      self._prev_center_list.assign(next_centers)

    if common.PRED_CENTER_HEATMAP_KEY in result_dict:
      result_dict[common.PRED_CENTER_HEATMAP_KEY] = tf.squeeze(
          result_dict[common.PRED_CENTER_HEATMAP_KEY], axis=3)
    return result_dict

  def _add_previous_heatmap_to_input(
      self, input_tensor: tf.Tensor) -> tf.Tensor:
    """Adds the previously predicted center heatmap as a 7th input channel.

    Args:
      input_tensor: A tf.Tensor of shape [batch, height, width, 6] containing
        two consecutive RGB frames stacked along the channel axis.

    Returns:
      A tf.Tensor of shape [batch, height, width, 7] whose last channel holds
      the previous center heatmap prediction, or zeros for the first frame of
      a sequence.
    """
    frame1, frame2 = tf.split(input_tensor, [3, 3], axis=3)
    # We use a simple way to detect if the first frame of a sequence is being
    # processed: for the first frame, frame1 and frame2 are identical.
    if tf.reduce_all(tf.equal(frame1, frame2)):
      h = tf.shape(input_tensor)[1]
      w = tf.shape(input_tensor)[2]
      prev_center = tf.zeros((1, h, w, 1), dtype=tf.float32)
      # Reset the tracking state at the start of a new sequence.
      self._prev_center_list.assign(tf.zeros((0, 5), dtype=tf.int32))
      self._next_tracking_id.assign(1)
    else:
      prev_center = self._prev_center_prediction
    output_tensor = tf.concat([frame1, frame2, prev_center], axis=3)
    output_tensor.set_shape([None, None, None, 7])
    return output_tensor

  def reset_pooling_layer(self):
    """Resets the ASPP pooling layer to global average pooling."""
    self._decoder.reset_pooling_layer()

  def set_pool_size(self, pool_size: Tuple[int, int]):
    """Sets the pooling size of the ASPP pooling layer.

    Args:
      pool_size: A tuple specifying the pooling size of the ASPP pooling
        layer.
    """
    self._decoder.set_pool_size(pool_size)

  @property
  def checkpoint_items(self) -> Dict[Text, Any]:
    """Returns a dictionary of submodules to include in checkpoints."""
    items = dict(encoder=self._encoder)
    items.update(self._decoder.checkpoint_items)
    return items
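

# A minimal usage sketch, assuming a Motion-DeepLab textproto config on disk
# (the path below is a placeholder) and that the KITTI-STEP descriptor is
# exposed as dataset.KITTI_STEP_INFORMATION:
#
#   from google.protobuf import text_format
#
#   with tf.io.gfile.GFile('motion_deeplab_kitti_step.textproto') as f:
#     config = text_format.Parse(f.read(), config_pb2.ExperimentOptions())
#   model = MotionDeepLab(config, dataset.KITTI_STEP_INFORMATION)
#   # The input stacks two consecutive RGB frames along the channel axis.
#   two_frames = tf.zeros([1, 385, 385, 6], dtype=tf.float32)
#   outputs = model(two_frames, training=False)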