Spaces:
Sleeping
Sleeping
""" | |
Source url: https://github.com/lukemelas/EfficientNet-PyTorch | |
Modified by Min Seok Lee, Wooseok Shin, Nikita Selin | |
License: Apache License 2.0 | |
Changes: | |
- Added support for extracting edge features | |
- Added support for extracting object features at different levels | |
- Refactored the code | |
""" | |
from typing import Any, List | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from carvekit.ml.arch.tracerb7.effi_utils import ( | |
get_same_padding_conv2d, | |
calculate_output_image_size, | |
MemoryEfficientSwish, | |
drop_connect, | |
round_filters, | |
round_repeats, | |
Swish, | |
create_block_args, | |
) | |
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Pipeline: optional 1x1 expansion conv -> depthwise conv -> optional
    squeeze-and-excitation -> 1x1 projection conv, with an identity skip
    connection (plus optional drop connect) when input and output shapes match.

    Args:
        block_args (namedtuple): BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParam, defined in utils.py.
        image_size (tuple or list, optional): [image_height, image_width] of the
            block input; passed to get_same_padding_conv2d to choose how the
            "same" padding convolutions are built.

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = (
            1 - global_params.batch_norm_momentum
        )  # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (
            0 < self._block_args.se_ratio <= 1
        )
        self.id_skip = (
            block_args.id_skip
        )  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = (
            self._block_args.input_filters * self._block_args.expand_ratio
        )  # number of output channels
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(
                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
            )
            self._bn0 = nn.BatchNorm2d(
                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
            )
            # The 1x1 expansion conv uses the default stride, so image_size is
            # left unchanged for the depthwise conv below.

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup,
            out_channels=oup,
            groups=oup,  # groups makes it depthwise
            kernel_size=k,
            stride=s,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        # Spatial size after the (possibly strided) depthwise conv; the
        # projection conv below operates at this size.
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            # SE convs run on the globally pooled 1x1 map (see forward).
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(
                1, int(self._block_args.input_filters * self._block_args.se_ratio)
            )
            self._se_reduce = Conv2d(
                in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
            )
            self._se_expand = Conv2d(
                in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
            )

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(
            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
        )
        self._bn2 = nn.BatchNorm2d(
            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (torch.Tensor): Input tensor.
            drop_connect_rate (float or None): Drop connect rate, between 0 and 1.
                Only applied when the identity skip connection is used; a falsy
                value disables drop connect.

        Returns:
            torch.Tensor: Output of this block after processing.
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)
        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation: channel-wise gating from globally pooled features
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect (only when input and output shapes match)
        input_filters, output_filters = (
            self._block_args.input_filters,
            self._block_args.output_filters,
        )
        if (
            self.id_skip
            and self._block_args.stride == 1
            and input_filters == output_filters
        ):
            # The combination of skip connection and drop connect brings about stochastic depth.
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
class EfficientNet(nn.Module):
    """EfficientNet backbone: a stem convolution followed by a stack of MBConvBlocks.

    Unlike the upstream EfficientNet-PyTorch model, this variant builds no
    classification head: ``__init__`` creates only the stem (``_conv_stem``,
    ``_bn0``), the block stack (``_blocks``) and the activation (``_swish``).

    Args:
        blocks_args (list[namedtuple]): Per-stage BlockArgs; must be a non-empty list.
        global_params (namedtuple): GlobalParams (batch-norm settings, image size,
            drop-connect rate and width/depth scaling used by round_filters /
            round_repeats).
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "block args must be greater than 0"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters (PyTorch momentum is 1 - TensorFlow momentum)
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(
            32, self._global_params
        )  # number of output channels
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
            )
            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(
                MBConvBlock(block_args, self._global_params, image_size=image_size)
            )
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            # Repeated blocks use stride 1, so image_size stays the same for them.
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    MBConvBlock(block_args, self._global_params, image_size=image_size)
                )
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Collect feature maps at every spatial-resolution reduction.

        Runs the stem and every block. Whenever a block makes dimension 2 of
        the feature map smaller, the feature map from *before* that block is
        stored. The output of the last block is always stored as the final
        endpoint.

        Args:
            inputs (torch.Tensor): Input batch; dimension 2 is compared as the
                spatial height (presumably (B, C, H, W) -- confirm with callers).

        Returns:
            dict[str, torch.Tensor]: Keys ``"reduction_1"``, ``"reduction_2"``, ...
            in the order the reductions occur, ending with the last block output.
        """
        endpoints = dict()
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x
        # Blocks
        num_blocks = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / num_blocks  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
            prev_x = x
        # This model builds no head (_conv_head / _bn1 never exist), so the
        # former head call always raised AttributeError. The last block output
        # is used as the final endpoint instead.
        endpoints["reduction_{}".format(len(endpoints) + 1)] = x
        return endpoints

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        The replacement stem conv is freshly constructed; ``_bn0`` is left as is.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )
class EfficientEncoderB7(EfficientNet):
    """EfficientNet-B7 encoder that returns the outputs of selected blocks.

    Built with width_coefficient=2.0, depth_coefficient=3.1, dropout_rate=0.5 and
    image_size=600. ``forward`` returns one feature map per entry of
    ``self.block_idx``.
    """

    def __init__(self):
        super().__init__(
            *create_block_args(
                width_coefficient=2.0,
                depth_coefficient=3.1,
                dropout_rate=0.5,
                image_size=600,
            )
        )
        # No-op for 3 input channels (see EfficientNet._change_in_channels).
        self._change_in_channels(3)
        # Indices into self._blocks whose outputs forward() returns.
        self.block_idx = [10, 17, 37, 54]
        # Channel count of each returned feature map -- presumably the output
        # filters of the blocks listed in block_idx; verify against create_block_args.
        self.channels = [48, 80, 224, 640]

    def initial_conv(self, inputs):
        """Apply the stem: conv -> batch norm -> swish."""
        return self._swish(self._bn0(self._conv_stem(inputs)))

    def get_blocks(self, x, H, W, block_idx):
        """Run all blocks on ``x`` and collect the outputs of the requested blocks.

        Args:
            x (torch.Tensor): Stem output.
            H (int): Input height (currently unused; kept for interface compatibility).
            W (int): Input width (currently unused; kept for interface compatibility).
            block_idx (Sequence[int]): Block indices whose outputs are collected.
                Any length is accepted. An index listed k times yields k copies.

        Returns:
            list[torch.Tensor]: Cloned block outputs, in block execution order.
        """
        features = []
        num_blocks = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / num_blocks  # scale drop connect_rate
            x = block(x, drop_connect_rate=drop_connect_rate)
            # One copy per occurrence so duplicated indices behave as with the
            # former fixed four-way comparison.
            for _ in range(block_idx.count(idx)):
                features.append(x.clone())
        return features

    def forward(self, inputs: torch.Tensor) -> List[torch.Tensor]:
        """Return backbone features from the blocks listed in ``self.block_idx``.

        Args:
            inputs (torch.Tensor): 4-D input batch; the last two sizes are
                unpacked as H and W.

        Returns:
            list[torch.Tensor]: One feature map per entry of ``self.block_idx``.
        """
        _, _, H, W = inputs.size()
        x = self.initial_conv(inputs)  # Prepare input for the backbone
        return self.get_blocks(x, H, W, block_idx=self.block_idx)  # Get backbone features