atwang's picture
[NOT TESTED] initial implementation of app
01664b3
raw
history blame
9.43 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union
import fvcore.nn.weight_init as weight_init
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from ..transformer_decoder.maskformer_transformer_decoder import StandardTransformerDecoder
from ..pixel_decoder.fpn import build_pixel_decoder
@SEM_SEG_HEADS_REGISTRY.register()
class PerPixelBaselineHead(nn.Module):
    """Semantic segmentation head: a pixel decoder followed by a 1x1 conv classifier.

    The pixel decoder produces per-pixel features (``forward_features``), the
    predictor maps them to ``num_classes`` logits, and the logits are upsampled
    by ``common_stride`` before the loss / before being returned at inference.
    """

    # Weight-layout version. State dicts whose metadata carries no version, or a
    # version below 2, are converted in ``_load_from_state_dict``.
    _version = 2

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Rename legacy keys in ``state_dict`` so they live under ``pixel_decoder.``.

        Only runs when ``local_metadata`` has no ``"version"`` or a version < 2.
        Every key that belongs to this module (starts with ``prefix`` and
        contains ``"sem_seg_head"``), except the ``predictor`` entries, is moved
        from ``<prefix><rest>`` to ``<prefix>pixel_decoder.<rest>``. The dict is
        modified in place and a warning is logged if anything was renamed.

        NOTE(review): as in the original code, the ``nn.Module`` base
        implementation is not invoked here; ``__init__`` registers only
        sub-modules on this class, no direct parameters or buffers.
        """
        version = local_metadata.get("version", None)
        if version is not None and version >= 2:
            return

        logger = logging.getLogger(__name__)
        predictor_prefix = prefix + "predictor"
        decoder_prefix = prefix + "pixel_decoder."

        # Stays False when nothing needed renaming (e.g. training from scratch),
        # in which case no warning is emitted.
        converted = False
        for key in list(state_dict.keys()):
            # Only touch keys owned by this module. Anchoring on the prefix
            # (instead of ``str.replace``) avoids rewriting a later occurrence of
            # the prefix and avoids the empty-prefix case, where ``replace``
            # would insert text between every character.
            if "sem_seg_head" not in key or not key.startswith(prefix):
                continue
            # Predictor weights keep their name. Keys already below
            # ``pixel_decoder.`` are skipped so that a new-format state dict
            # whose metadata was lost is not renamed a second time.
            # NOTE(review): assumes legacy checkpoints never had a direct child
            # named ``pixel_decoder`` -- confirm.
            if key.startswith((predictor_prefix, decoder_prefix)):
                continue
            state_dict[decoder_prefix + key[len(prefix):]] = state_dict.pop(key)
            converted = True

        if converted:
            logger.warning(
                f"Weight format of {self.__class__.__name__} have changed! "
                "Please upgrade your models. Applying automatic conversion now ..."
            )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module; must expose ``mask_dim``
                and ``forward_features``
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
        """
        super().__init__()
        # Feature names ordered by increasing stride.
        shapes_by_stride = sorted(input_shape.items(), key=lambda item: item[1].stride)
        self.in_features = [name for name, _ in shapes_by_stride]

        self.ignore_value = ignore_value
        # Factor by which logits are upsampled in ``forward`` and ``losses``.
        self.common_stride = 4
        self.loss_weight = loss_weight

        self.pixel_decoder = pixel_decoder
        # 1x1 convolution from decoder features to per-class logits.
        self.predictor = Conv2d(
            self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0
        )
        weight_init.c2_msra_fill(self.predictor)

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Translate ``cfg`` into the keyword arguments of ``__init__``.

        ``input_shape`` is filtered down to the features listed in
        ``cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES``; the pixel decoder is built from
        the unfiltered ``input_shape``.
        """
        head_cfg = cfg.MODEL.SEM_SEG_HEAD
        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in head_cfg.IN_FEATURES
            },
            "ignore_value": head_cfg.IGNORE_VALUE,
            "num_classes": head_cfg.NUM_CLASSES,
            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
            "loss_weight": head_cfg.LOSS_WEIGHT,
        }

    def forward(self, features, targets=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        x = self.layers(features)
        if self.training:
            return None, self.losses(x, targets)
        x = F.interpolate(
            x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        return x, {}

    def layers(self, features):
        """Run the pixel decoder and the predictor; returns the un-upsampled logits."""
        # Only the first of the three decoder outputs is used here.
        x, _, _ = self.pixel_decoder.forward_features(features)
        return self.predictor(x)

    def losses(self, predictions, targets):
        """Upsample ``predictions`` by ``common_stride`` and compute cross-entropy.

        Returns:
            ``{"loss_sem_seg": loss * loss_weight}``
        """
        predictions = predictions.float()  # https://github.com/pytorch/pytorch/issues/48163
        predictions = F.interpolate(
            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
        )
        loss = F.cross_entropy(
            predictions, targets, reduction="mean", ignore_index=self.ignore_value
        )
        return {"loss_sem_seg": loss * self.loss_weight}
@SEM_SEG_HEADS_REGISTRY.register()
class PerPixelBaselinePlusHead(PerPixelBaselineHead):
    """Per-pixel head whose predictor is a transformer decoder instead of a 1x1 conv.

    The transformer predictor receives either the pixel decoder's transformer
    encoder features or one of the input features (selected by
    ``transformer_in_feature``) together with the decoder's mask features, and
    returns a dict with ``"pred_masks"`` and, for deep supervision,
    ``"aux_outputs"``.
    """

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Rename legacy keys in ``state_dict`` so they live under ``pixel_decoder.``.

        Only runs when ``local_metadata`` has no ``"version"`` or a version < 2.
        Every key that belongs to this module (starts with ``prefix`` and
        contains ``"sem_seg_head"``), except the ``predictor`` entries, is moved
        from ``<prefix><rest>`` to ``<prefix>pixel_decoder.<rest>``. Each rename
        is logged at debug level; a single warning is logged if anything was
        renamed. The dict is modified in place.

        NOTE(review): as in the original code, the ``nn.Module`` base
        implementation is not invoked here.
        """
        version = local_metadata.get("version", None)
        if version is not None and version >= 2:
            return

        logger = logging.getLogger(__name__)
        predictor_prefix = prefix + "predictor"
        decoder_prefix = prefix + "pixel_decoder."

        # Stays False when nothing needed renaming (e.g. training from scratch),
        # in which case no warning is emitted.
        converted = False
        for key in list(state_dict.keys()):
            # Only touch keys owned by this module. Anchoring on the prefix
            # (instead of ``str.replace``) avoids rewriting a later occurrence of
            # the prefix and avoids the empty-prefix case, where ``replace``
            # would insert text between every character.
            if "sem_seg_head" not in key or not key.startswith(prefix):
                continue
            # Predictor weights keep their name. Keys already below
            # ``pixel_decoder.`` are skipped so that a new-format state dict
            # whose metadata was lost is not renamed a second time.
            # NOTE(review): assumes legacy checkpoints never had a direct child
            # named ``pixel_decoder`` -- confirm.
            if key.startswith((predictor_prefix, decoder_prefix)):
                continue
            new_key = decoder_prefix + key[len(prefix):]
            logger.debug(f"{key} ==> {new_key}")
            state_dict[new_key] = state_dict.pop(key)
            converted = True

        if converted:
            logger.warning(
                f"Weight format of {self.__class__.__name__} have changed! "
                "Please upgrade your models. Applying automatic conversion now ..."
            )

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        # extra parameters
        transformer_predictor: nn.Module,
        transformer_in_feature: str,
        deep_supervision: bool,
        # inherit parameters
        num_classes: int,
        pixel_decoder: nn.Module,
        loss_weight: float = 1.0,
        ignore_value: int = -1,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shapes (channels and stride) of the input features
            transformer_predictor: the transformer decoder that makes prediction
            transformer_in_feature: input feature name to the transformer_predictor
            deep_supervision: whether or not to add supervision to the output of
                every transformer decoder layer
            num_classes: number of classes to predict
            pixel_decoder: the pixel decoder module
            loss_weight: loss weight
            ignore_value: category id to be ignored during training.
        """
        super().__init__(
            input_shape,
            num_classes=num_classes,
            pixel_decoder=pixel_decoder,
            loss_weight=loss_weight,
            ignore_value=ignore_value,
        )

        # Drop the 1x1 conv predictor created by the parent constructor and use
        # the transformer decoder in its place.
        del self.predictor

        self.predictor = transformer_predictor
        self.transformer_in_feature = transformer_in_feature
        self.deep_supervision = deep_supervision

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Extend the parent's constructor kwargs with the transformer settings.

        The transformer decoder's input channel count is
        ``cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM`` when it consumes the
        ``"transformer_encoder"`` feature, otherwise the channel count of the
        selected entry in ``input_shape``.
        """
        ret = super().from_config(cfg, input_shape)
        in_feature = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE
        ret["transformer_in_feature"] = in_feature
        if in_feature == "transformer_encoder":
            in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        else:
            in_channels = input_shape[in_feature].channels
        ret["transformer_predictor"] = StandardTransformerDecoder(
            cfg, in_channels, mask_classification=False
        )
        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        return ret

    def forward(self, features, targets=None):
        """
        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})
        """
        x, aux_outputs = self.layers(features)
        if not self.training:
            x = F.interpolate(
                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return x, {}

        losses = self.losses(x, targets)
        if self.deep_supervision:
            # One extra loss per auxiliary output, keyed "loss_sem_seg_<i>".
            for i, aux_output in enumerate(aux_outputs):
                aux_losses = self.losses(aux_output["pred_masks"], targets)
                losses[f"loss_sem_seg_{i}"] = aux_losses["loss_sem_seg"]
        return None, losses

    def layers(self, features):
        """Run the pixel decoder and the transformer predictor.

        Returns:
            ``(pred_masks, aux_outputs)``; ``aux_outputs`` is ``None`` unless
            ``deep_supervision`` is enabled.
        """
        mask_features, transformer_encoder_features, _ = self.pixel_decoder.forward_features(
            features
        )
        if self.transformer_in_feature == "transformer_encoder":
            assert (
                transformer_encoder_features is not None
            ), "Please use the TransformerEncoderPixelDecoder."
            predictions = self.predictor(transformer_encoder_features, mask_features)
        else:
            predictions = self.predictor(features[self.transformer_in_feature], mask_features)

        aux_outputs = predictions["aux_outputs"] if self.deep_supervision else None
        return predictions["pred_masks"], aux_outputs