# Snapshot metadata: file size 9,049 bytes, revision bdbd148.
'''
EfficientNet in PyTorch.
Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks"
Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py
主要特点:
1. 使用MBConv作为基本模块,包含SE注意力机制
2. 通过复合缩放方法(compound scaling)同时调整网络的宽度、深度和分辨率
3. 使用Swish激活函数和DropConnect正则化
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def swish(x):
    """Apply the Swish activation, ``x * sigmoid(x)``, element-wise.

    Args:
        x: Input tensor.

    Returns:
        The activated tensor, same shape as ``x``.
    """
    gate = x.sigmoid()
    return x * gate
def drop_connect(x, drop_ratio):
    """Apply DropConnect (per-sample stochastic depth) to a batch.

    A single Bernoulli draw is made for every sample in the batch; a dropped
    sample is zeroed entirely and the surviving samples are rescaled by
    ``1 / keep_ratio`` so the expected value is preserved.

    The input tensor is never modified; a new tensor is returned.

    Args:
        x: Input tensor whose first dimension is the batch dimension. The
            mask is created with shape ``(batch, 1, 1, 1)`` and broadcast
            against ``x``.
        drop_ratio: Probability of dropping each sample.

    Returns:
        A tensor with the same shape as ``x``.
    """
    keep_ratio = 1.0 - drop_ratio
    if keep_ratio <= 0.0:
        # Everything is dropped; dividing by keep_ratio would produce NaNs.
        return torch.zeros_like(x)

    # One keep/drop decision per sample, broadcast over the other dimensions.
    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    mask.bernoulli_(keep_ratio)

    # Out-of-place arithmetic so the caller's tensor is left untouched.
    return x / keep_ratio * mask
class SE(nn.Module):
    """Squeeze-and-Excitation channel attention module.

    Args:
        in_channels: Number of channels of the incoming feature map.
        se_channels: Number of channels of the intermediate (bottleneck) layer.
    """

    def __init__(self, in_channels, se_channels):
        super().__init__()
        # 1x1 convolutions applied to the pooled (1x1) descriptor.
        self.se1 = nn.Conv2d(in_channels, se_channels, kernel_size=1, bias=True)
        self.se2 = nn.Conv2d(se_channels, in_channels, kernel_size=1, bias=True)

    def forward(self, x):
        # Squeeze: global average pooling down to one value per channel.
        squeezed = F.adaptive_avg_pool2d(x, (1, 1))
        # Excite: bottleneck with Swish, then a sigmoid gate per channel.
        hidden = self.se1(squeezed)
        hidden = hidden * hidden.sigmoid()
        gate = self.se2(hidden).sigmoid()
        # Recalibrate the input features channel by channel.
        return x * gate
class MBConv(nn.Module):
    """Mobile Inverted Bottleneck Convolution block with Squeeze-and-Excitation.

    The block runs: 1x1 expansion -> depthwise conv -> SE -> 1x1 projection,
    with a residual connection when the stride is 1 and the input and output
    channel counts match.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        expand_ratio: Channel multiplier of the expansion phase.
        se_ratio: Ratio (relative to ``in_channels``) used to size the SE
            bottleneck.
        drop_rate: DropConnect probability applied on the residual branch
            during training.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 expand_ratio=1,
                 se_ratio=0.25,
                 drop_rate=0.):
        super(MBConv, self).__init__()
        self.stride = stride
        self.drop_rate = drop_rate
        self.expand_ratio = expand_ratio

        # Expansion phase. conv1/bn1 are always created (forward skips them
        # when expand_ratio == 1) so the module's parameter names stay the same.
        channels = expand_ratio * in_channels
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)

        # Depthwise convolution. kernel_size // 2 gives 1 for a 3x3 kernel and
        # 2 for a 5x5 kernel, and stays correct for other odd kernel sizes.
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=channels, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

        # SE layers. Clamp to at least one channel so a small in_channels
        # cannot produce an empty bottleneck.
        se_channels = max(1, int(in_channels * se_ratio))
        self.se = SE(channels, se_channels)

        # Output (projection) phase; no activation after bn3.
        self.conv3 = nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # The shortcut is only valid when the tensor shape is preserved.
        self.has_skip = (stride == 1) and (in_channels == out_channels)

    def forward(self, x):
        # Expansion (skipped when there is nothing to expand).
        out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x)))
        # Depthwise convolution.
        out = swish(self.bn2(self.conv2(out)))
        # Squeeze-and-excitation.
        out = self.se(out)
        # Pointwise projection.
        out = self.bn3(self.conv3(out))
        # Shortcut, with DropConnect on the residual branch while training.
        if self.has_skip:
            if self.training and self.drop_rate > 0:
                out = drop_connect(out, self.drop_rate)
            out = out + x
        return out
class EfficientNet(nn.Module):
    """EfficientNet classifier built from MBConv stages.

    Args:
        width_coefficient: Multiplier applied to every stage's channel count.
        depth_coefficient: Multiplier applied to every stage's block count
            (rounded up).
        dropout_rate: Dropout probability applied before the classifier.
        num_classes: Number of output classes.
    """

    def __init__(self,
                 width_coefficient=1.0,
                 depth_coefficient=1.0,
                 dropout_rate=0.2,
                 num_classes=10):
        super(EfficientNet, self).__init__()
        # Baseline (B0) stage configuration.
        cfg = {
            'num_blocks': [1, 2, 2, 3, 3, 4, 1],  # blocks per stage
            'expansion': [1, 6, 6, 6, 6, 6, 6],  # expansion ratio per stage
            'out_channels': [16, 24, 40, 80, 112, 192, 320],  # output channels per stage
            'kernel_size': [3, 3, 5, 3, 5, 5, 3],  # depthwise kernel size per stage
            'stride': [1, 2, 2, 2, 1, 2, 1],  # stride of the first block per stage
            'dropout_rate': dropout_rate,
            'drop_connect_rate': 0.2,
        }
        self.cfg = cfg
        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient

        # Stem layer.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        # MBConv stages.
        self.layers = self._make_layers(in_channels=32)

        # Classifier head. The width must be scaled exactly like the last
        # stage in _make_layers, otherwise the linear layer's input size does
        # not match the features for non-integer width coefficients.
        final_channels = self._scaled_channels(cfg['out_channels'][-1])
        self.linear = nn.Linear(final_channels, num_classes)

    def _scaled_channels(self, channels):
        """Return ``channels`` scaled by the width coefficient, truncated to int."""
        return int(channels * self.width_coefficient)

    def _scaled_depth(self, num_blocks):
        """Return ``num_blocks`` scaled by the depth coefficient, rounded up."""
        return int(math.ceil(num_blocks * self.depth_coefficient))

    def _make_layers(self, in_channels):
        """Build the sequence of MBConv blocks.

        Args:
            in_channels: Number of channels produced by the stem.

        Returns:
            An ``nn.Sequential`` holding every MBConv block in order.
        """
        layers = []
        cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size', 'stride']]
        # Use the depth-scaled block count so the drop-connect rate ramps
        # linearly from 0 and stays below the configured maximum.
        blocks = sum(self._scaled_depth(n) for n in self.cfg['num_blocks'])
        b = 0  # running block index used for the drop-connect schedule
        for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg):
            out_channels = self._scaled_channels(out_channels)
            num_blocks = self._scaled_depth(num_blocks)
            for i in range(num_blocks):
                # Only the first block of a stage applies the stage stride.
                stride_i = stride if i == 0 else 1
                drop_rate = self.cfg['drop_connect_rate'] * b / blocks
                layers.append(
                    MBConv(in_channels,
                           out_channels,
                           kernel_size,
                           stride_i,
                           expansion,
                           se_ratio=0.25,
                           drop_rate=drop_rate))
                in_channels = out_channels
                b += 1
        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem.
        out = swish(self.bn1(self.conv1(x)))
        # MBConv stages.
        out = self.layers(out)
        # Head: global average pooling, flatten, dropout, classifier.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        if self.training and self.cfg['dropout_rate'] > 0:
            out = F.dropout(out, p=self.cfg['dropout_rate'])
        out = self.linear(out)
        return out
def EfficientNetB0(num_classes=10):
    """Build EfficientNet-B0 (width x1.0, depth x1.0, dropout 0.2)."""
    scaling = {
        'width_coefficient': 1.0,
        'depth_coefficient': 1.0,
        'dropout_rate': 0.2,
    }
    return EfficientNet(num_classes=num_classes, **scaling)
def EfficientNetB1(num_classes=10):
    """Build EfficientNet-B1 (width x1.0, depth x1.1, dropout 0.2)."""
    scaling = {
        'width_coefficient': 1.0,
        'depth_coefficient': 1.1,
        'dropout_rate': 0.2,
    }
    return EfficientNet(num_classes=num_classes, **scaling)
def EfficientNetB2(num_classes=10):
    """Build EfficientNet-B2 (width x1.1, depth x1.2, dropout 0.3)."""
    scaling = {
        'width_coefficient': 1.1,
        'depth_coefficient': 1.2,
        'dropout_rate': 0.3,
    }
    return EfficientNet(num_classes=num_classes, **scaling)
def EfficientNetB3(num_classes=10):
    """Build EfficientNet-B3 (width x1.2, depth x1.4, dropout 0.3)."""
    scaling = {
        'width_coefficient': 1.2,
        'depth_coefficient': 1.4,
        'dropout_rate': 0.3,
    }
    return EfficientNet(num_classes=num_classes, **scaling)
def EfficientNetB4(num_classes=10):
    """Build EfficientNet-B4 (width x1.4, depth x1.8, dropout 0.4)."""
    scaling = {
        'width_coefficient': 1.4,
        'depth_coefficient': 1.8,
        'dropout_rate': 0.4,
    }
    return EfficientNet(num_classes=num_classes, **scaling)
def EfficientNetB5(num_classes=10):
    """Build EfficientNet-B5 (width x1.6, depth x2.2, dropout 0.4)."""
    scaling = {
        'width_coefficient': 1.6,
        'depth_coefficient': 2.2,
        'dropout_rate': 0.4,
    }
    return EfficientNet(num_classes=num_classes, **scaling)
def EfficientNetB6(num_classes=10):
    """Build EfficientNet-B6 (width x1.8, depth x2.6, dropout 0.5)."""
    scaling = {
        'width_coefficient': 1.8,
        'depth_coefficient': 2.6,
        'dropout_rate': 0.5,
    }
    return EfficientNet(num_classes=num_classes, **scaling)
def EfficientNetB7(num_classes=10):
    """Build EfficientNet-B7 (width x2.0, depth x3.1, dropout 0.5)."""
    scaling = {
        'width_coefficient': 2.0,
        'depth_coefficient': 3.1,
        'dropout_rate': 0.5,
    }
    return EfficientNet(num_classes=num_classes, **scaling)
def test():
    """Smoke-test EfficientNet-B0 on a random 32x32 input.

    Runs one forward pass and prints the output size, then prints a layer
    summary via ``torchinfo`` when that optional package is installed.
    """
    net = EfficientNetB0()
    x = torch.randn(1, 3, 32, 32)
    y = net(x)
    print(y.size())

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = net.to(device)

    # torchinfo is only needed for the summary; do not fail the smoke test
    # after a successful forward pass just because it is missing.
    try:
        from torchinfo import summary
    except ImportError:
        print('torchinfo is not installed; skipping model summary')
        return
    summary(net, (1, 3, 32, 32))
# Run the smoke test only when executed as a script.
if __name__ == '__main__':
    test()