# picklebot_demo / mobilenet.py (hbfreed)
'''
Implementing Mobilenet v3 as seen in
"Searching for MobileNetV3" for video classification,
note that balls are 0 and strikes are 1.
'''
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
class SEBlock3D(nn.Module):
    '''Squeeze-and-excitation block for 3D feature maps.

    Globally average-pools the input, pushes it through a two-layer 1x1x1
    conv bottleneck (reduction factor 4), and rescales every channel by the
    resulting Hardsigmoid gate in [0, 1].
    '''
    def __init__(self, channels):
        super().__init__()
        reduced = channels // 4
        gate_layers = [
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(channels, reduced, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(reduced, channels, kernel_size=1),
            nn.Hardsigmoid(),
        ]
        self.se = nn.Sequential(*gate_layers)

    def forward(self, x):
        #per-channel gate, broadcast over the (D, H, W) dimensions
        return x * self.se(x)
class SEBlock2D(nn.Module):
    '''Squeeze-and-excitation block for 2D feature maps.

    Globally average-pools the input, pushes it through a two-layer 1x1
    conv bottleneck (reduction factor 4), and rescales every channel by the
    resulting Hardsigmoid gate in [0, 1].
    '''
    def __init__(self, channels):
        super().__init__()
        reduced = channels // 4
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, reduced, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, channels, kernel_size=1),
            nn.Hardsigmoid(),
        ]
        self.se = nn.Sequential(*gate_layers)

    def forward(self, x):
        #per-channel gate, broadcast over the (H, W) dimensions
        return x * self.se(x)
#Bottleneck for Mobilenets
class Bottleneck3D(nn.Module):
    '''MobileNetV3-style inverted bottleneck with 3D convolutions.

    Pipeline: 1x1x1 expand -> (1,k,k) depthwise (spatial-only filtering) ->
    optional squeeze-and-excite -> 1x1x1 project -> optional batchnorm ->
    nonlinearity -> dropout.

    NOTE(review): an int `stride` is forwarded straight to Conv3d, so it also
    subsamples the temporal axis — presumably intentional to cut compute;
    pass a (1, s, s) tuple if temporal resolution must be kept.
    '''
    def __init__(self, in_channels, out_channels, expanded_channels, stride=1, use_se=False, kernel_size=3, nonlinearity=nn.Hardswish(), batchnorm=True, dropout=0, bias=False):
        super().__init__()
        #pointwise conv1x1x1 (expand channels)
        self.pointwise_conv1 = nn.Conv3d(in_channels, expanded_channels, kernel_size=1, bias=bias)
        #depthwise (spatial filtering); groups preserve channel-wise information.
        #The depth kernel is 1, so the depth padding must be 0: the previous
        #padding=kernel_size//2 also zero-padded the temporal axis, silently
        #growing T by 2*(kernel_size//2) at every block.
        self.depthwise_conv = nn.Conv3d(
            expanded_channels,  #in channels
            expanded_channels,  #out channels
            groups=expanded_channels,
            kernel_size=(1, kernel_size, kernel_size),
            stride=stride,
            padding=(0, kernel_size // 2, kernel_size // 2),
            bias=bias
        )
        #squeeze-and-excite (recalibrate channel-wise)
        self.squeeze_excite = SEBlock3D(expanded_channels) if use_se else None
        #pointwise conv1x1x1 (project to output channels)
        self.pointwise_conv2 = nn.Conv3d(expanded_channels, out_channels, kernel_size=1, bias=bias)
        self.batchnorm = nn.BatchNorm3d(out_channels) if batchnorm else None
        self.nonlinearity = nonlinearity
        self.dropout = nn.Dropout3d(p=dropout)

    def forward(self, x):
        x = self.pointwise_conv1(x)
        x = self.depthwise_conv(x)
        if self.squeeze_excite is not None:
            x = self.squeeze_excite(x)
        x = self.pointwise_conv2(x)
        #guard: batchnorm is None when batchnorm=False (used to crash here)
        if self.batchnorm is not None:
            x = self.batchnorm(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        return x
#2D bottleneck for our 2d convnet with LSTM
class Bottleneck2D(nn.Module):
    '''MobileNetV3-style inverted bottleneck with 2D convolutions.

    Pipeline: 1x1 expand -> kxk depthwise -> optional squeeze-and-excite ->
    1x1 project -> optional batchnorm -> nonlinearity -> dropout.
    '''
    def __init__(self, in_channels, out_channels, expanded_channels, stride=1, use_se=False, kernel_size=3, nonlinearity=nn.Hardswish(), batchnorm=True, dropout=0, bias=False):
        super().__init__()
        #pointwise conv1x1 (expand channels)
        self.pointwise_conv1 = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, bias=bias)
        #depthwise (spatial filtering); groups preserve channel-wise information
        self.depthwise_conv = nn.Conv2d(
            expanded_channels,  #in channels
            expanded_channels,  #out channels
            groups=expanded_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=bias
        )
        #squeeze-and-excite (recalibrate channel-wise)
        self.squeeze_excite = SEBlock2D(expanded_channels) if use_se else None
        #pointwise conv1x1 (project to output channels)
        self.pointwise_conv2 = nn.Conv2d(expanded_channels, out_channels, kernel_size=1, bias=bias)
        self.batchnorm = nn.BatchNorm2d(out_channels) if batchnorm else None
        self.nonlinearity = nonlinearity
        self.dropout = nn.Dropout2d(p=dropout)

    def forward(self, x):
        x = self.pointwise_conv1(x)
        x = self.depthwise_conv(x)
        if self.squeeze_excite is not None:
            x = self.squeeze_excite(x)
        x = self.pointwise_conv2(x)
        #guard: batchnorm is None when batchnorm=False (used to crash here)
        if self.batchnorm is not None:
            x = self.batchnorm(x)
        x = self.nonlinearity(x)
        #dropout was constructed but never applied; callers pass dropout=0.2,
        #so apply it for consistency with Bottleneck3D (no-op in eval mode)
        x = self.dropout(x)
        return x
#mobilenet large 3d convolutions
class MobileNetLarge3D(nn.Module):
    '''MobileNetV3-Large built from 3D convolutions for video classification
    (balls are 0, strikes are 1).

    Expects input of shape (batch, 3, T, H, W) and returns logits of shape
    (batch, num_classes). The AdaptiveAvgPool3d in the classifier makes the
    head tolerant of varying temporal/spatial extents.
    '''
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        #conv3d (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv3d(in_channels=3, out_channels=16, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm3d(16),
            nn.Hardswish()
        )
        #3x3 bottlenecks1 (3, ReLU): 112x112x16 -> 56x56x24
        self.block2 = nn.Sequential(
            Bottleneck3D(in_channels=16,out_channels=16,expanded_channels=16,stride=1,nonlinearity=nn.ReLU(),dropout=0.2),
            Bottleneck3D(in_channels=16,out_channels=24,expanded_channels=64,stride=2,nonlinearity=nn.ReLU(),dropout=0.2),
            Bottleneck3D(in_channels=24,out_channels=24,expanded_channels=72,stride=1,nonlinearity=nn.ReLU(),dropout=0.2)
        )
        #5x5 bottlenecks1 (3, ReLU, squeeze-excite): 56x56x24 -> 28x28x40
        self.block3 = nn.Sequential(
            Bottleneck3D(in_channels=24,out_channels=40,expanded_channels=72,stride=2,use_se=True,kernel_size=5,nonlinearity=nn.ReLU(),dropout=0.2),
            Bottleneck3D(in_channels=40,out_channels=40,expanded_channels=120,stride=1,use_se=True,kernel_size=5,nonlinearity=nn.ReLU(),dropout=0.2),
            Bottleneck3D(in_channels=40,out_channels=40,expanded_channels=120,stride=1,use_se=True,kernel_size=5,nonlinearity=nn.ReLU(),dropout=0.2)
        )
        #3x3 bottlenecks2 (6, h-swish, last two get squeeze-excite): 28x28x40 -> 14x14x112
        self.block4 = nn.Sequential(
            Bottleneck3D(in_channels=40,out_channels=80,expanded_channels=240,stride=2,dropout=0.2),
            Bottleneck3D(in_channels=80,out_channels=80,expanded_channels=240,stride=1,dropout=0.2),
            Bottleneck3D(in_channels=80,out_channels=80,expanded_channels=184,stride=1,dropout=0.2),
            Bottleneck3D(in_channels=80,out_channels=80,expanded_channels=184,stride=1,dropout=0.2),
            Bottleneck3D(in_channels=80,out_channels=112,expanded_channels=480,stride=1,use_se=True,dropout=0.2),
            Bottleneck3D(in_channels=112,out_channels=112,expanded_channels=672,stride=1,use_se=True,dropout=0.2)
        )
        #5x5 bottlenecks2 (3, h-swish, squeeze-excite): 14x14x112 -> 7x7x160
        self.block5 = nn.Sequential(
            Bottleneck3D(in_channels=112,out_channels=160,expanded_channels=672,stride=2,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck3D(in_channels=160,out_channels=160,expanded_channels=960,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck3D(in_channels=160,out_channels=160,expanded_channels=960,stride=1,use_se=True,kernel_size=5,dropout=0.2)
        )
        #conv3d (h-swish): 7x7x160 -> 7x7x960
        self.block6 = nn.Sequential(
            nn.Conv3d(in_channels=160, out_channels=960, stride=1, kernel_size=1),
            nn.BatchNorm3d(960),
            nn.Hardswish()
        )
        #classifier: global pool, then two 1x1x1 convs (no batchnorm)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Conv3d(in_channels=960, out_channels=1280, kernel_size=1, stride=1, padding=0),
            nn.Hardswish(),
            #num_classes outputs (2 classes for ball/strike)
            nn.Conv3d(in_channels=1280, out_channels=self.num_classes, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.classifier(x)
        #(N, num_classes, 1, 1, 1) -> (N, num_classes)
        x = x.view(x.shape[0], self.num_classes)
        return x

    def initialize_weights(self):
        '''Kaiming-init conv/linear weights; reset batchnorm to identity.

        The previous version gated on `module.nonlinearity`, an attribute
        Conv3d/Linear modules do not have, so the conv branch never ran.
        '''
        for module in self.modules():
            if isinstance(module, (nn.Conv3d, nn.Linear)):
                #fan_out + relu is a reasonable default for ReLU/Hardswish nets
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm3d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
#mobilenet small 3d convolutions
class MobileNetSmall3D(nn.Module):
    '''MobileNetV3-Small built from 3D convolutions for video classification
    (balls are 0, strikes are 1).

    Expects input of shape (batch, 3, T, H, W) and returns logits of shape
    (batch, num_classes). The AdaptiveAvgPool3d in the classifier makes the
    head tolerant of varying temporal/spatial extents.
    '''
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        #conv3d (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv3d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(16),
            nn.Hardswish()
        )
        #3x3 bottlenecks (3, LeakyReLU, first gets squeeze-excite): 112x112x16 -> 28x28x24
        self.block2 = nn.Sequential(
            Bottleneck3D(in_channels=16,out_channels=16,expanded_channels=16,stride=2,use_se=True,nonlinearity=nn.LeakyReLU(),dropout=0.2),
            Bottleneck3D(in_channels=16,out_channels=24,expanded_channels=72,stride=2,nonlinearity=nn.LeakyReLU(),dropout=0.2),
            Bottleneck3D(in_channels=24,out_channels=24,expanded_channels=88,stride=1,nonlinearity=nn.LeakyReLU(),dropout=0.2)
        )
        #5x5 bottlenecks (8, h-swish, squeeze-excite): 28x28x24 -> 7x7x96
        self.block3 = nn.Sequential(
            Bottleneck3D(in_channels=24,out_channels=40,expanded_channels=96,stride=2,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck3D(in_channels=40,out_channels=40,expanded_channels=240,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck3D(in_channels=40,out_channels=40,expanded_channels=240,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck3D(in_channels=40,out_channels=48,expanded_channels=120,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck3D(in_channels=48,out_channels=48,expanded_channels=144,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck3D(in_channels=48,out_channels=96,expanded_channels=288,stride=2,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck3D(in_channels=96,out_channels=96,expanded_channels=576,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck3D(in_channels=96,out_channels=96,expanded_channels=576,stride=1,use_se=True,kernel_size=5,dropout=0.2)
        )
        #conv3d (h-swish, squeeze-excite): 7x7x96 -> 7x7x576
        self.block4 = nn.Sequential(
            nn.Conv3d(in_channels=96, out_channels=576, kernel_size=1, stride=1, padding=0),
            SEBlock3D(channels=576),
            nn.BatchNorm3d(576),
            nn.Hardswish()
        )
        #classifier: global pool, then two 1x1x1 convs (no batchnorm)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Conv3d(in_channels=576, out_channels=1024, kernel_size=1, stride=1, padding=0),
            nn.Hardswish(),
            #num_classes outputs (2 classes for ball/strike)
            nn.Conv3d(in_channels=1024, out_channels=self.num_classes, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.classifier(x)
        #(N, num_classes, 1, 1, 1) -> (N, num_classes)
        x = x.view(x.shape[0], self.num_classes)
        return x

    def initialize_weights(self):
        '''Kaiming-init conv/linear weights; reset batchnorm to identity.

        The previous version gated on `module.nonlinearity` (an attribute
        Conv3d/Linear do not have) and used the always-truthy expression
        `== 'relu' or 'leaky_relu'`, so the conv branch never ran correctly.
        '''
        for module in self.modules():
            if isinstance(module, (nn.Conv3d, nn.Linear)):
                #fan_out + relu suits the ReLU/LeakyReLU/Hardswish layers here
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm3d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
#MobileNetV3-Large 2D + LSTM for helping with the temporal dimension
class MobileNetLarge2D(nn.Module):
    '''MobileNetV3-Large 2D backbone + LSTM over the temporal dimension
    (balls are 0, strikes are 1).

    Expects input of shape (batch, timesteps, 3, 224, 224); each frame is
    encoded to a 960-d vector, the sequence is fed to an LSTM, and the last
    timestep's hidden state is classified. Returns logits of shape
    (batch, num_classes).

    NOTE(review): block6 uses a fixed AvgPool2d(kernel_size=7), so frames
    must be 224x224 (7x7 after the /32 downsampling) — confirm callers.
    '''
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        #conv2d (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.Hardswish()
        )
        #3x3 bottlenecks1 (3, ReLU): 112x112x16 -> 56x56x24
        self.block2 = nn.Sequential(
            Bottleneck2D(in_channels=16,out_channels=16,expanded_channels=16,stride=1,nonlinearity=nn.ReLU(),dropout=0.2),
            Bottleneck2D(in_channels=16,out_channels=24,expanded_channels=64,stride=2,nonlinearity=nn.ReLU()),
            Bottleneck2D(in_channels=24,out_channels=24,expanded_channels=72,stride=1,nonlinearity=nn.ReLU(),dropout=0.2)
        )
        #5x5 bottlenecks1 (3, ReLU, squeeze-excite): 56x56x24 -> 28x28x40
        self.block3 = nn.Sequential(
            Bottleneck2D(in_channels=24,out_channels=40,expanded_channels=72,stride=2,use_se=True,kernel_size=5,nonlinearity=nn.ReLU(),dropout=0.2),
            Bottleneck2D(in_channels=40,out_channels=40,expanded_channels=120,stride=1,use_se=True,kernel_size=5,nonlinearity=nn.ReLU()),
            Bottleneck2D(in_channels=40,out_channels=40,expanded_channels=120,stride=1,use_se=True,kernel_size=5,nonlinearity=nn.ReLU(),dropout=0.2)
        )
        #3x3 bottlenecks2 (6, h-swish, last two get squeeze-excite): 28x28x40 -> 14x14x112
        self.block4 = nn.Sequential(
            Bottleneck2D(in_channels=40,out_channels=80,expanded_channels=240,stride=2,dropout=0.2),
            Bottleneck2D(in_channels=80,out_channels=80,expanded_channels=240,stride=1),
            Bottleneck2D(in_channels=80,out_channels=80,expanded_channels=184,stride=1,dropout=0.2),
            Bottleneck2D(in_channels=80,out_channels=80,expanded_channels=184,stride=1),
            Bottleneck2D(in_channels=80,out_channels=112,expanded_channels=480,stride=1,use_se=True,dropout=0.2),
            Bottleneck2D(in_channels=112,out_channels=112,expanded_channels=672,stride=1,use_se=True,dropout=0.2)
        )
        #5x5 bottlenecks2 (3, h-swish, squeeze-excite): 14x14x112 -> 7x7x160
        self.block5 = nn.Sequential(
            Bottleneck2D(in_channels=112,out_channels=160,expanded_channels=672,stride=2,use_se=True,kernel_size=5),
            Bottleneck2D(in_channels=160,out_channels=160,expanded_channels=960,stride=1,use_se=True,kernel_size=5),
            Bottleneck2D(in_channels=160,out_channels=160,expanded_channels=960,stride=1,use_se=True,kernel_size=5)
        )
        #conv2d (h-swish), avg pool 7x7: 7x7x160 -> 1x1x960
        self.block6 = nn.Sequential(
            nn.Conv2d(in_channels=160, out_channels=960, stride=1, kernel_size=1),
            nn.BatchNorm2d(960),
            nn.Hardswish(),
            nn.AvgPool2d(kernel_size=7, stride=1)
        )
        #LSTM over the per-frame 960-d features
        self.lstm = nn.LSTM(input_size=960, hidden_size=32, num_layers=5, batch_first=True)
        #classifier on the last LSTM hidden state (2 classes for ball/strike)
        self.classifier = nn.Sequential(
            nn.Linear(32, self.num_classes)
        )

    def forward(self, x):
        #x is shape (batch_size, timesteps, C, H, W)
        batch_size, timesteps, C, H, W = x.size()
        #encode each frame independently through the CNN backbone
        features = []
        for t in range(timesteps):
            frame = x[:, t]
            frame = self.block1(frame)
            frame = self.block2(frame)
            frame = self.block3(frame)
            frame = self.block4(frame)
            frame = self.block5(frame)
            frame = self.block6(frame)
            #flatten to (batch, 960)
            features.append(frame.view(frame.size(0), -1))
        #stack instead of writing into a preallocated float32 buffer: keeps
        #dtype/device of x and avoids in-place index assignment
        x = torch.stack(features, dim=1)  #(batch, timesteps, 960)
        x, _ = self.lstm(x)
        #classify from the last timestep only
        x = x[:, -1, :]
        x = self.classifier(x)
        return x

    def initialize_weights(self):
        '''Kaiming-init conv/linear weights; reset batchnorm to identity.

        The previous version gated on `module.nonlinearity`, an attribute
        Conv2d/Linear modules do not have, so the conv branch never ran.
        '''
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                #fan_out + relu is a reasonable default for ReLU/Hardswish nets
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
#MobileNetV3-Small 2d with lstm for helping with the temporal dimension
class MobileNetSmall2D(nn.Module):
    '''MobileNetV3-Small 2D backbone + LSTM over the temporal dimension
    (balls are 0, strikes are 1).

    Expects input of shape (batch, timesteps, 3, 224, 224); each frame is
    encoded to a 576-d vector, the sequence is fed to an LSTM, and the last
    timestep's hidden state is classified. Returns logits of shape
    (batch, num_classes).

    NOTE(review): block4 uses a fixed AvgPool2d(kernel_size=7), so frames
    must be 224x224 (7x7 after the /32 downsampling) — confirm callers.
    '''
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        #conv2d (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.Hardswish()
        )
        #3x3 bottlenecks (3, ReLU, first gets squeeze-excite): 112x112x16 -> 28x28x24
        self.block2 = nn.Sequential(
            Bottleneck2D(in_channels=16,out_channels=16,expanded_channels=16,stride=2,use_se=True,nonlinearity=nn.ReLU(),dropout=0.2),
            Bottleneck2D(in_channels=16,out_channels=24,expanded_channels=72,stride=2,nonlinearity=nn.ReLU(),dropout=0.2),
            Bottleneck2D(in_channels=24,out_channels=24,expanded_channels=88,stride=1,nonlinearity=nn.ReLU(),dropout=0.2)
        )
        #5x5 bottlenecks (8, h-swish, squeeze-excite): 28x28x24 -> 7x7x96
        self.block3 = nn.Sequential(
            Bottleneck2D(in_channels=24,out_channels=40,expanded_channels=96,stride=2,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck2D(in_channels=40,out_channels=40,expanded_channels=240,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck2D(in_channels=40,out_channels=40,expanded_channels=240,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck2D(in_channels=40,out_channels=48,expanded_channels=120,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck2D(in_channels=48,out_channels=48,expanded_channels=144,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck2D(in_channels=48,out_channels=96,expanded_channels=288,stride=2,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck2D(in_channels=96,out_channels=96,expanded_channels=576,stride=1,use_se=True,kernel_size=5,dropout=0.2),
            Bottleneck2D(in_channels=96,out_channels=96,expanded_channels=576,stride=1,use_se=True,kernel_size=5,dropout=0.2)
        )
        #conv2d (h-swish, squeeze-excite), avg pool 7x7: 7x7x96 -> 1x1x576
        self.block4 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=576, kernel_size=1, stride=1, padding=0),
            SEBlock2D(channels=576),
            nn.BatchNorm2d(576),
            nn.Hardswish(),
            nn.AvgPool2d(kernel_size=7, stride=1)
        )
        #LSTM over the per-frame 576-d features
        self.lstm = nn.LSTM(input_size=576, hidden_size=64, num_layers=1, batch_first=True)
        #classifier on the last LSTM hidden state (2 classes for ball/strike)
        self.classifier = nn.Sequential(
            nn.Linear(64, self.num_classes)
        )

    def forward(self, x):
        #x is of shape (batch_size, timesteps, C, H, W)
        batch_size, timesteps, C, H, W = x.size()
        #encode each frame independently through the CNN backbone
        features = []
        for t in range(timesteps):
            frame = x[:, t]
            frame = self.block1(frame)
            frame = self.block2(frame)
            frame = self.block3(frame)
            frame = self.block4(frame)
            #flatten to (batch, 576)
            features.append(frame.view(frame.size(0), -1))
        #stack instead of writing into a preallocated float32 buffer: keeps
        #dtype/device of x and avoids in-place index assignment
        x = torch.stack(features, dim=1)  #(batch, timesteps, 576)
        x, _ = self.lstm(x)
        #classify from the last timestep only
        x = x[:, -1, :]
        x = self.classifier(x)
        return x

    def initialize_weights(self):
        '''Kaiming-init conv/linear weights; reset batchnorm to identity.

        The previous version gated on `module.nonlinearity`, an attribute
        Conv2d/Linear modules do not have, so the conv branch never ran.
        '''
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                #fan_out + relu is a reasonable default for ReLU/Hardswish nets
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)