import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from .newcrf_utils import resize, normal_init


class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Each branch adaptively average-pools the input to ``pool_scale`` x
    ``pool_scale`` and projects it with a 1x1 ``ConvModule``; :meth:`forward`
    upsamples every branch output back to the input's spatial size.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        # The caller's config; the per-branch config below may differ from it.
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        branch_norm_cfg = norm_cfg
        for pool_scale in pool_scales:
            if pool_scale == 1:
                # A 1x1 pooled map leaves BatchNorm a single value per channel
                # when the batch size is 1, which BN rejects in training, so
                # GroupNorm is used instead. num_groups=256 presumably
                # requires ``channels`` to be divisible by 256 - TODO confirm.
                #
                # NOTE(review): the override is sticky - every branch built
                # AFTER the scale-1 branch also uses GN. Kept deliberately:
                # changing it alters which norm layers (and which
                # parameters/buffers) those branches own, which would likely
                # break loading existing checkpoints. Confirm before changing.
                branch_norm_cfg = dict(
                    type='GN', requires_grad=True, num_groups=256)
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=branch_norm_cfg,
                        act_cfg=self.act_cfg)))

    def forward(self, x):
        """Run every pooling branch and resize its output to ``x``'s size.

        Args:
            x (Tensor): Input feature map; assumed (N, C, H, W) - the trailing
                two dims are used as the upsampling target size.

        Returns:
            list[Tensor]: One tensor per pooling branch, in ``pool_scales``
            order, each resized to ``x.size()[2:]``.
        """
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = resize(
                ppm_out,
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs


class BaseDecodeHead(nn.Module):
    """Common base for the decode heads in this module.

    Validates and stores the input-selection settings and the shared
    conv/norm/act configs, and optionally builds a ``Dropout2d`` layer.
    Subclasses implement :meth:`forward`.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        loss_decode (dict): Config of decode loss. Accepted for config
            compatibility; loss building is disabled in this module.
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255
        sampler (dict|None): The config of segmentation map sampler. Accepted
            for config compatibility; sampler building is disabled in this
            module. Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False):
        super(BaseDecodeHead, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.ignore_index = ignore_index
        self.align_corners = align_corners

        # NOTE: unlike the upstream mmseg head, the loss (``loss_decode``),
        # pixel sampler (``sampler``) and the final ``conv_seg`` classifier
        # are intentionally not built here; subclasses return features.

        # Built for subclasses; the heads in this module do not apply it.
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        lists/tuples of the same length; with 'resize_concat' the stored
        ``self.in_channels`` is their sum (the channels after concatenation).

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one selected feature map is allowed.

        Raises:
            AssertionError: If the arguments are inconsistent as described.
        """
        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer.

        Currently a no-op: the classification layer (``conv_seg``) is not
        built by this module.
        """

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: A single concatenated tensor for
            'resize_concat', the selected list for 'multiple_select',
            otherwise ``inputs[self.in_index]``.
        """
        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        # NOTE(review): ``losses`` is not defined in this module (loss
        # building is disabled), so this raises AttributeError unless a
        # subclass provides it.
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output of :meth:`forward` (``img_metas`` and
            ``test_cfg`` are ignored).
        """
        return self.forward(inputs)


class UPerHead(BaseDecodeHead):
    """FPN-style decode head (the FPN part of UPerNet, without a PSP branch).

    Every selected input level gets a 1x1 lateral ``ConvModule``. Laterals
    are fused top-down: each level is upsampled to the previous level's size
    and added to it. The fused level 0 is refined by a 3x3 ``ConvModule``
    and returned.

    Args:
        pool_scales (tuple[int]): Accepted for config compatibility but
            unused - this head has no PSP branch. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(UPerHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        # FPN Module: one lateral conv and one output conv per input level.
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for level_channels in self.in_channels:
            l_conv = ConvModule(
                level_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=True)
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=True)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, inputs):
        """Fuse the selected levels top-down and return the level-0 map.

        Args:
            inputs (list[Tensor]): Multi-level features; those selected by
                ``in_index`` are used.

        Returns:
            Tensor: Fused level-0 feature with ``self.channels`` channels,
            refined by ``fpn_convs[0]`` when more than one level is used
            (with a single level the lateral is returned unrefined).
        """
        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            # Out-of-place add on purpose: the laterals come from ConvModules
            # built with inplace=True, so (with an in-place activation such
            # as the default ReLU) autograd saves this very tensor for
            # backward; modifying it with ``+=`` makes backward fail.
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # Only the level-0 output is returned, so only fpn_convs[0] is run.
        if used_backbone_levels > 1:
            return self.fpn_convs[0](laterals[0])
        return laterals[-1]


class PSP(BaseDecodeHead):
    """Pyramid Scene Parsing (PSP) decode head applied on the last feature.

    The last selected input feature is passed through a :class:`PPM`; the
    feature and all PPM branch outputs are concatenated along dim 1 (assumed
    to be the channel dim) and fused by a 3x3 ``ConvModule`` bottleneck into
    ``self.channels`` channels.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSP, self).__init__(
            input_transform='multiple_select', **kwargs)
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        # Input = last feature + one ``channels``-wide map per pool scale.
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def psp_forward(self, inputs):
        """Forward function of PSP module.

        Args:
            inputs (list[Tensor]): Selected features; only the last is used.

        Returns:
            Tensor: Bottleneck output with ``self.channels`` channels.
        """
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, inputs):
        """Select the configured inputs and run the PSP module on them."""
        inputs = self._transform_inputs(inputs)

        return self.psp_forward(inputs)