import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

"""
Functions for building the BottleneckBlock from Detectron2.
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py
"""


def get_norm(norm, out_channels, num_norm_groups=32):
    """
    Args:
        norm (str or callable): either the string "GN" (the only built-in
            option in this trimmed-down file), or a callable that takes a
            channel number and returns the normalization layer as an nn.Module.
        out_channels (int): number of channels the returned layer normalizes.
        num_norm_groups (int): number of groups to use when norm is "GN".

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "GN": lambda channels: nn.GroupNorm(num_norm_groups, channels),
        }[norm]
    return norm(out_channels)
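
# A minimal usage sketch (an addition, not part of the original Detectron2
# file): "GN" yields an nn.GroupNorm over `num_norm_groups` groups, while
# None or the empty string disables normalization.
if __name__ == "__main__":
    gn = get_norm("GN", 64)  # GroupNorm(num_groups=32, num_channels=64)
    assert isinstance(gn, nn.GroupNorm) and gn.num_groups == 32
    assert get_norm(None, 64) is None and get_norm("", 64) is None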


class Conv2d(nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
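
# A usage sketch (an addition, not in the original file): the wrapper fuses
# conv -> norm -> activation into a single module; `F.relu` is passed as a
# plain callable.
if __name__ == "__main__":
    conv = Conv2d(
        8, 16, kernel_size=3, padding=1, bias=False,
        norm=get_norm("GN", 16, num_norm_groups=4),
        activation=F.relu,
    )
    y = conv(torch.randn(2, 8, 32, 32))
    assert y.shape == (2, 16, 32, 32)
    assert (y >= 0).all()  # ReLU output is non-negative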


class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of the `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attributes:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable, as required by :meth:`ResNet.freeze`.
        (Detectron2 additionally converts BatchNorm to FrozenBatchNorm here;
        this file only uses GroupNorm, so freezing parameters suffices.)

        Returns:
            the block itself
        """
        for p in self.parameters():
            p.requires_grad = False
        return self
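
# A minimal subclass sketch (an addition, not in the original file), showing
# the contract: declare channels and stride, and keep forward() consistent
# with them.
if __name__ == "__main__":
    class _Identity(CNNBlockBase):
        def __init__(self, channels):
            super().__init__(channels, channels, stride=1)

        def forward(self, x):
            return x

    blk = _Identity(8).freeze()  # freeze() returns the block itself
    x = torch.randn(2, 8, 5, 5)
    assert torch.equal(blk(x), x)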


class BottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block used by ResNet-50, 101 and 152
    defined in :paper:`ResNet`. It contains 3 conv layers with kernels
    1x1, 3x3, 1x1, and a projection shortcut if needed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="GN",
        stride_in_1x1=False,
        dilation=1,
        num_norm_groups=32,
    ):
        """
        Args:
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
            num_norm_groups (int): number of groups for GroupNorm, if used.
        """
        super().__init__(in_channels, out_channels, stride)

        if in_channels != out_channels:
            self.shortcut = Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels, num_norm_groups),
            )
        else:
            self.shortcut = None

        # The stride goes in either the first 1x1 conv or the 3x3 conv.
        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels, num_norm_groups),
        )

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels, num_norm_groups),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels, num_norm_groups),
        )

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu_(out)

        out = self.conv2(out)
        out = F.relu_(out)

        out = self.conv3(out)

        if self.shortcut is not None:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        out += shortcut
        out = F.relu_(out)
        return out
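
# A usage sketch (an addition, not in the original file): a stride-1 bottleneck
# block with a 1x1 projection shortcut (needed because 64 != 256 channels),
# followed by a freeze() call inherited from CNNBlockBase.
if __name__ == "__main__":
    block = BottleneckBlock(64, 256, bottleneck_channels=64)
    out = block(torch.randn(2, 64, 56, 56))
    assert out.shape == (2, 256, 56, 56)  # spatial size preserved at stride=1
    block.freeze()
    assert all(not p.requires_grad for p in block.parameters())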


class ResNet(nn.Module):
    """
    Implement :paper:`ResNet`.
    """

    def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
        """
        Args:
            stem (nn.Module): a stem module
            stages (list[list[CNNBlockBase]]): several (typically 4) stages,
                each contains multiple :class:`CNNBlockBase`.
            num_classes (None or int): if None, will not perform classification.
                Otherwise, will create a linear layer.
            out_features (list[str]): names of the layers whose outputs should
                be returned in forward. Can be anything in "stem", "linear", or "res2" ...
                If None, will return the output of the last layer.
            freeze_at (int): The number of stages at the beginning to freeze.
                See :meth:`freeze` for a detailed explanation.
        """
        super().__init__()
        self.stem = stem
        self.num_classes = num_classes

        current_stride = self.stem.stride
        self._out_feature_strides = {"stem": current_stride}
        self._out_feature_channels = {"stem": self.stem.out_channels}

        self.stage_names, self.stages = [], []

        if out_features is not None:
            # Avoid keeping unused stages: only build up to the last requested one.
            num_stages = max(
                [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
            )
            stages = stages[:num_stages]
        for i, blocks in enumerate(stages):
            assert len(blocks) > 0, len(blocks)
            for block in blocks:
                assert isinstance(block, CNNBlockBase), block

            name = "res" + str(i + 2)
            stage = nn.Sequential(*blocks)

            self.add_module(name, stage)
            self.stage_names.append(name)
            self.stages.append(stage)

            self._out_feature_strides[name] = current_stride = int(
                current_stride * np.prod([k.stride for k in blocks])
            )
            self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
        self.stage_names = tuple(self.stage_names)

        if num_classes is not None:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.linear = nn.Linear(curr_channels, num_classes)

            # Initialize the classification layer as in Detectron2
            # (zero-mean Gaussian with a small std).
            nn.init.normal_(self.linear.weight, std=0.01)
            name = "linear"

        if out_features is None:
            out_features = [name]
        self._out_features = out_features
        assert len(self._out_features)
        children = [x[0] for x in self.named_children()]
        for out_feature in self._out_features:
            assert out_feature in children, "Available children: {}".format(", ".join(children))
        self.freeze(freeze_at)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H and W should be divisible by the
                network's total stride.

        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        x = self.stem(x)
        if "stem" in self._out_features:
            outputs["stem"] = x
        for name, stage in zip(self.stage_names, self.stages):
            x = stage(x)
            if name in self._out_features:
                outputs[name] = x
        if self.num_classes is not None:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.linear(x)
            if "linear" in self._out_features:
                outputs["linear"] = x
        return outputs

    def freeze(self, freeze_at=0):
        """
        Freeze the first several stages of the ResNet. Commonly used in
        fine-tuning.

        Layers that produce the same feature map spatial size are defined as one
        "stage" by :paper:`FPN`.

        Args:
            freeze_at (int): number of stages to freeze.
                `1` means freezing the stem. `2` means freezing the stem and
                one residual stage, etc.

        Returns:
            nn.Module: this ResNet itself
        """
        if freeze_at >= 1:
            self.stem.freeze()
        for idx, stage in enumerate(self.stages, start=2):
            if freeze_at >= idx:
                for block in stage.children():
                    block.freeze()
        return self

    @staticmethod
    def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
        """
        Create a list of blocks of the same type that forms one ResNet stage.

        Args:
            block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
                stage. A module of this type must not change spatial resolution of inputs unless its
                stride != 1.
            num_blocks (int): number of blocks in this stage
            in_channels (int): input channels of the entire stage.
            out_channels (int): output channels of **every block** in the stage.
            kwargs: other arguments passed to the constructor of
                `block_class`. If the argument name is "xx_per_block", the
                argument is a list of values to be passed to each block in the
                stage. Otherwise, the same argument is passed to every block
                in the stage.

        Returns:
            list[CNNBlockBase]: a list of block modules.

        Examples:
        ::
            stage = ResNet.make_stage(
                BottleneckBlock, 3, in_channels=16, out_channels=64,
                bottleneck_channels=16, num_groups=1,
                stride_per_block=[2, 1, 1],
                dilation_per_block=[1, 1, 2]
            )

        Usually, layers that produce the same feature map spatial size are defined as one
        "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
        all be 1.
        """
        blocks = []
        for i in range(num_blocks):
            curr_kwargs = {}
            for k, v in kwargs.items():
                if k.endswith("_per_block"):
                    assert len(v) == num_blocks, (
                        f"Argument '{k}' of make_stage should have the "
                        f"same length as num_blocks={num_blocks}."
                    )
                    newk = k[: -len("_per_block")]
                    assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
                    curr_kwargs[newk] = v[i]
                else:
                    curr_kwargs[k] = v

            blocks.append(
                block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
            )
            in_channels = out_channels
        return blocks

    @staticmethod
    def make_default_stages(depth, block_class=None, **kwargs):
        """
        Create a list of ResNet stages from a pre-defined depth (one of 18, 34, 50, 101, 152).
        If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
        instead for fine-grained customization.

        Args:
            depth (int): depth of ResNet
            block_class (type): the CNN block class. Has to accept
                `bottleneck_channels` argument for depth >= 50.
                By default it is BasicBlock or BottleneckBlock, based on the
                depth. Note that only BottleneckBlock is defined in this file,
                so for depth < 50 a block class (e.g. Detectron2's BasicBlock)
                must be passed explicitly.
            kwargs:
                other arguments to pass to `make_stage`. Should not contain
                stride and channels, as they are predefined for each depth.

        Returns:
            list[list[CNNBlockBase]]: modules in all stages; see arguments of
                :class:`ResNet.__init__`.
        """
        num_blocks_per_stage = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[depth]
        if block_class is None:
            # BasicBlock is not defined in this file; for depth < 50 the caller
            # must supply block_class.
            block_class = BasicBlock if depth < 50 else BottleneckBlock
        if depth < 50:
            in_channels = [64, 64, 128, 256]
            out_channels = [64, 128, 256, 512]
        else:
            in_channels = [64, 256, 512, 1024]
            out_channels = [256, 512, 1024, 2048]
        ret = []
        for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
            if depth >= 50:
                kwargs["bottleneck_channels"] = o // 4
            ret.append(
                ResNet.make_stage(
                    block_class=block_class,
                    num_blocks=n,
                    stride_per_block=[s] + [1] * (n - 1),
                    in_channels=i,
                    out_channels=o,
                    **kwargs,
                )
            )
        return ret
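

# End-to-end sketch (an addition, not in the original file). The stem below is
# a hypothetical stand-in: Detectron2's real BasicStem (conv + max-pool) is not
# part of this file, so a single stride-4 conv plays its role here.
if __name__ == "__main__":
    class _DemoStem(CNNBlockBase):
        def __init__(self):
            super().__init__(in_channels=3, out_channels=64, stride=4)
            self.conv = Conv2d(
                3, 64, kernel_size=7, stride=4, padding=3, bias=False,
                norm=get_norm("GN", 64), activation=F.relu,
            )

        def forward(self, x):
            return self.conv(x)

    # One hand-built stage: the first block halves the spatial size;
    # num_norm_groups=8 is chosen so it divides both 16 and 64 channels.
    stage = ResNet.make_stage(
        BottleneckBlock, 3, in_channels=16, out_channels=64,
        bottleneck_channels=16, stride_per_block=[2, 1, 1], num_norm_groups=8,
    )
    y = nn.Sequential(*stage)(torch.randn(1, 16, 32, 32))
    assert y.shape == (1, 64, 16, 16)

    # A ResNet-50 feature extractor built from the predefined stage recipe.
    model = ResNet(_DemoStem(), ResNet.make_default_stages(50), out_features=["res2", "res5"])
    feats = model(torch.randn(1, 3, 64, 64))
    assert feats["res2"].shape == (1, 256, 16, 16)
    assert feats["res5"].shape == (1, 2048, 2, 2)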