File size: 8,629 Bytes
b213d84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
# Copyright (c) Facebook, Inc. and its affiliates.
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.layers import Conv2d
from .registry import ROI_DENSEPOSE_HEAD_REGISTRY
@ROI_DENSEPOSE_HEAD_REGISTRY.register()
class DensePoseDeepLabHead(nn.Module):
    """
    DensePose head built around the ASPP module of DeepLabV3, see
    "Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>.

    The head runs an ASPP block, an optional non-local block and then a
    stack of convolution + ReLU layers.
    """

    def __init__(self, cfg: CfgNode, input_channels: int):
        """
        Args:
            cfg (CfgNode): configuration; options are read from
                `MODEL.ROI_DENSEPOSE_HEAD` and its `DEEPLAB` sub-node
            input_channels (int): number of channels of the input features
        """
        super().__init__()
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        # fmt: off
        hidden_dim           = head_cfg.CONV_HEAD_DIM
        kernel_size          = head_cfg.CONV_HEAD_KERNEL
        norm                 = head_cfg.DEEPLAB.NORM
        self.n_stacked_convs = head_cfg.NUM_STACKED_CONVS
        self.use_nonlocal    = head_cfg.DEEPLAB.NONLOCAL_ON
        # fmt: on

        # ASPP keeps the channel count unchanged (input_channels -> input_channels).
        self.ASPP = ASPP(input_channels, [6, 12, 56], input_channels)
        self.add_module("ASPP", self.ASPP)

        if self.use_nonlocal:
            self.NLBlock = NONLocalBlock2D(input_channels, bn_layer=True)
            self.add_module("NLBlock", self.NLBlock)

        # Stack of "same"-padded convolutions; only the first one sees input_channels.
        in_dim = input_channels
        for idx in range(self.n_stacked_convs):
            group_norm = nn.GroupNorm(32, hidden_dim) if norm == "GN" else None
            conv = Conv2d(
                in_dim,
                hidden_dim,
                kernel_size,
                stride=1,
                padding=kernel_size // 2,
                bias=not norm,
                norm=group_norm,
            )
            weight_init.c2_msra_fill(conv)
            self.add_module(self._get_layer_name(idx), conv)
            in_dim = hidden_dim

        self.n_out_channels = hidden_dim

    def forward(self, features):
        """
        Apply ASPP, the optional non-local block and the stacked conv + ReLU layers.

        Args:
            features: input feature tensor passed to the ASPP block
                (presumably (N, C, H, W) with C == input_channels -- confirm with caller)

        Returns:
            Output of the last stacked convolution after ReLU, or the ASPP /
            non-local output when there are no stacked convolutions.
        """
        out = self.ASPP(features)
        if self.use_nonlocal:
            out = self.NLBlock(out)
        for idx in range(self.n_stacked_convs):
            conv = getattr(self, self._get_layer_name(idx))
            out = F.relu(conv(out))
        return out

    def _get_layer_name(self, i: int):
        """Return the attribute name of the i-th (0-based) stacked convolution."""
        return "body_conv_fcn{}".format(i + 1)
# Copied from
# https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/deeplabv3.py
# See https://arxiv.org/pdf/1706.05587.pdf for details
class ASPPConv(nn.Sequential):
    """
    One atrous branch of ASPP: a bias-free 3x3 dilated convolution followed
    by GroupNorm (32 groups) and ReLU.

    The padding equals the dilation, so the spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, dilation):
        """
        Args:
            in_channels: number of input channels
            out_channels: number of output channels (must be divisible by 32
                for the GroupNorm layer)
            dilation: dilation rate of the 3x3 convolution, also used as padding
        """
        super().__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                3,
                padding=dilation,
                dilation=dilation,
                bias=False,
            ),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(),
        )
class ASPPPooling(nn.Sequential):
    """
    Image-level branch of ASPP: global average pooling, a bias-free 1x1
    convolution, GroupNorm (32 groups) and ReLU, with the result resized
    back to the spatial size of the input.
    """

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels: number of input channels
            out_channels: number of output channels (must be divisible by 32
                for the GroupNorm layer)
        """
        layers = (
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(),
        )
        super().__init__(*layers)

    def forward(self, x):
        """Run the pooled branch and bilinearly upsample it to the input's last two dims."""
        spatial_size = x.shape[-2:]
        pooled = super().forward(x)
        return F.interpolate(pooled, size=spatial_size, mode="bilinear", align_corners=False)
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling block.

    Parallel branches are applied to the same input: a 1x1 convolution, one
    `ASPPConv` per atrous rate and an `ASPPPooling` image-level branch. The
    branch outputs are concatenated along the channel dimension and fused by
    a bias-free 1x1 convolution followed by ReLU.
    """

    def __init__(self, in_channels, atrous_rates, out_channels):
        """
        Args:
            in_channels: number of input channels
            atrous_rates: iterable of dilation rates, one `ASPPConv` branch is
                created per rate (any number of rates is accepted)
            out_channels: number of channels produced by every branch and by
                the final projection (must be divisible by 32 for GroupNorm)
        """
        super().__init__()
        # Materialize once so generators and other one-shot iterables work.
        rates = tuple(atrous_rates)

        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.GroupNorm(32, out_channels),
                nn.ReLU(),
            )
        ]
        branches.extend(ASPPConv(in_channels, out_channels, rate) for rate in rates)
        branches.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(branches)

        # The projection input width follows the actual number of branches
        # (1x1 conv + one per rate + pooling) instead of a hard-coded 5.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply every branch to `x`, concatenate on dim 1 and project."""
        branch_outputs = [branch(x) for branch in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))
# Copied from
# https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_embedded_gaussian.py
# See https://arxiv.org/abs/1711.07971 for details
class _NonLocalBlockND(nn.Module):
    """
    Embedded-Gaussian non-local block for 1D, 2D or 3D feature maps.

    A softmax-normalized affinity between the `theta` and `phi` embeddings of
    the input is used to aggregate the `g` embedding; the result is mapped
    back to `in_channels` by `W` and added to the input. The last layer of
    `W` is zero-initialized, so a freshly constructed block returns its input.
    """

    def __init__(
        self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True
    ):
        """
        Args:
            in_channels: number of channels of the input and of the output
            inter_channels: width of the theta / phi / g embeddings; when None,
                `in_channels // 2` is used, raised to 1 if that is 0
            dimension: 1, 2 or 3, selects Conv1d / Conv2d / Conv3d and the
                matching max pooling layer
            sub_sample: if True, max-pool the outputs of `g` and `phi`
            bn_layer: if True, `W` is a 1x1 convolution followed by a 32-group
                GroupNorm; otherwise `W` is the convolution alone
        """
        super().__init__()
        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        if inter_channels is None:
            # Half of the input width, but never zero channels.
            inter_channels = in_channels // 2 or 1
        self.inter_channels = inter_channels

        # Per-dimension convolution class, pooling class and pooling kernel;
        # anything that is not 3 or 2 falls through to the 1D variant.
        conv_nd, pool_cls, pool_kernel = {
            3: (nn.Conv3d, nn.MaxPool3d, (1, 2, 2)),
            2: (nn.Conv2d, nn.MaxPool2d, (2, 2)),
        }.get(dimension, (nn.Conv1d, nn.MaxPool1d, 2))
        max_pool_layer = pool_cls(kernel_size=pool_kernel)

        def pointwise(n_in, n_out):
            # 1x1 convolution used for every embedding and for the output map.
            return conv_nd(
                in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0
            )

        self.g = pointwise(self.in_channels, self.inter_channels)

        out_conv = pointwise(self.inter_channels, self.in_channels)
        if bn_layer:
            self.W = nn.Sequential(out_conv, nn.GroupNorm(32, self.in_channels))
            last_layer = self.W[1]
        else:
            self.W = out_conv
            last_layer = self.W
        # Zero the last layer of W so the residual branch contributes nothing at init.
        nn.init.constant_(last_layer.weight, 0)
        nn.init.constant_(last_layer.bias, 0)

        self.theta = pointwise(self.in_channels, self.inter_channels)
        self.phi = pointwise(self.in_channels, self.inter_channels)

        if sub_sample:
            # The same pooling module instance is shared by both branches.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        """
        :param x: tensor of shape (b, c, *spatial), e.g. (b, c, t, h, w) for the 3D block
        :return: tensor of the same shape as x
        """
        n_batch = x.size(0)
        n_inter = self.inter_channels

        def flatten(t):
            # Collapse all trailing (spatial) dimensions into one.
            return t.view(n_batch, n_inter, -1)

        values = flatten(self.g(x)).permute(0, 2, 1)
        queries = flatten(self.theta(x)).permute(0, 2, 1)
        keys = flatten(self.phi(x))

        attention = F.softmax(torch.matmul(queries, keys), dim=-1)
        y = torch.matmul(attention, values).permute(0, 2, 1).contiguous()
        y = y.view(n_batch, n_inter, *x.size()[2:])
        return self.W(y) + x
class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local block specialised to 2D feature maps (`dimension=2`)."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        """
        Args:
            in_channels: number of channels of the input and of the output
            inter_channels: embedding width, forwarded unchanged to the base class
            sub_sample: forwarded unchanged to the base class
            bn_layer: forwarded unchanged to the base class
        """
        options = dict(
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )
        super().__init__(in_channels, **options)
|