'''
MobileNetV3 ("Searching for MobileNetV3", Howard et al., 2019) adapted for
video classification. Label convention: balls are 0 and strikes are 1.
'''
import torch | |
import torch.nn as nn | |
import torch.nn.init as init | |
import torch.nn.functional as F | |
class SEBlock3D(nn.Module):
    """Squeeze-and-Excitation block for 3D feature maps (channel reweighting)."""

    def __init__(self, channels):
        super().__init__()
        # Squeeze: global average pool -> 1/4-width bottleneck;
        # excite: restore width and gate into [0, 1] with Hardsigmoid.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(channels // 4, channels, kernel_size=1),
            nn.Hardsigmoid(),
        )

    def forward(self, x):
        # Rescale every channel of x by its learned gate.
        return x * self.se(x)
class SEBlock2D(nn.Module):
    """Squeeze-and-Excitation block for 2D feature maps (channel reweighting)."""

    def __init__(self, channels):
        super().__init__()
        # Squeeze: global average pool -> 1/4-width bottleneck;
        # excite: restore width and gate into [0, 1] with Hardsigmoid.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 4, channels, kernel_size=1),
            nn.Hardsigmoid(),
        )

    def forward(self, x):
        # Rescale every channel of x by its learned gate.
        return x * self.se(x)
#Bottleneck for Mobilenets | |
class Bottleneck3D(nn.Module):
    """MobileNetV3 inverted-residual bottleneck built from 3D convolutions.

    Pipeline: expand (1x1x1) -> depthwise (1 x k x k) -> optional
    squeeze-excite -> project (1x1x1) -> optional batchnorm ->
    nonlinearity -> dropout.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the projection conv.
        expanded_channels: width of the expanded (depthwise) stage.
        stride: conv stride. NOTE(review): an int stride is applied to the
            temporal dimension as well (the depthwise kernel has depth 1),
            so stride=2 also subsamples frames — pass (1, s, s) to stride
            spatially only; confirm which behavior is intended.
        use_se: apply a squeeze-and-excite block after the depthwise conv.
        kernel_size: spatial size k of the depthwise kernel.
        nonlinearity: activation module applied after the (optional) batchnorm.
        batchnorm: whether to apply BatchNorm3d after the projection conv.
        dropout: Dropout3d probability applied last.
        bias: whether the convolutions use bias terms.
    """
    def __init__(self, in_channels, out_channels, expanded_channels, stride=1, use_se=False, kernel_size=3, nonlinearity=nn.Hardswish(), batchnorm=True, dropout=0, bias=False):
        super().__init__()
        #pointwise conv1x1x1 (channel expansion to expanded_channels)
        self.pointwise_conv1 = nn.Conv3d(in_channels, expanded_channels, kernel_size=1, bias=bias)
        #depthwise (spatial filtering); groups preserve channel-wise information.
        #BUGFIX: pad only the spatial dims — the kernel has depth 1, so an int
        #padding would also pad the temporal dim and inflate it with zero frames.
        self.depthwise_conv = nn.Conv3d(
            expanded_channels,  #in channels
            expanded_channels,  #out channels
            groups=expanded_channels,
            kernel_size=(1, kernel_size, kernel_size),
            stride=stride,
            padding=(0, kernel_size // 2, kernel_size // 2),
            bias=bias
        )
        #squeeze-and-excite (recalibrate channel wise)
        self.squeeze_excite = SEBlock3D(expanded_channels) if use_se else None
        #pointwise conv1x1x1 (projection down to out_channels)
        self.pointwise_conv2 = nn.Conv3d(expanded_channels, out_channels, kernel_size=1, bias=bias)
        self.batchnorm = nn.BatchNorm3d(out_channels) if batchnorm else None
        self.nonlinearity = nonlinearity
        self.dropout = nn.Dropout3d(p=dropout)

    def forward(self, x):
        x = self.pointwise_conv1(x)
        x = self.depthwise_conv(x)
        if self.squeeze_excite is not None:
            x = self.squeeze_excite(x)
        x = self.pointwise_conv2(x)
        #BUGFIX: batchnorm is None when batchnorm=False; guard before calling
        if self.batchnorm is not None:
            x = self.batchnorm(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        return x
#2D bottleneck for our 2d convnet with LSTM | |
class Bottleneck2D(nn.Module):
    """MobileNetV3 inverted-residual bottleneck built from 2D convolutions.

    Pipeline: expand (1x1) -> depthwise (k x k) -> optional squeeze-excite ->
    project (1x1) -> optional batchnorm -> nonlinearity -> dropout.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the projection conv.
        expanded_channels: width of the expanded (depthwise) stage.
        stride: spatial stride of the depthwise conv.
        use_se: apply a squeeze-and-excite block after the depthwise conv.
        kernel_size: size k of the depthwise kernel.
        nonlinearity: activation module applied after the (optional) batchnorm.
        batchnorm: whether to apply BatchNorm2d after the projection conv.
        dropout: Dropout2d probability applied last.
        bias: whether the convolutions use bias terms.
    """
    def __init__(self, in_channels, out_channels, expanded_channels, stride=1, use_se=False, kernel_size=3, nonlinearity=nn.Hardswish(), batchnorm=True, dropout=0, bias=False):
        super().__init__()
        #pointwise conv1x1 (channel expansion to expanded_channels)
        self.pointwise_conv1 = nn.Conv2d(in_channels, expanded_channels, kernel_size=1, bias=bias)
        #depthwise (spatial filtering); groups preserve channel-wise information
        self.depthwise_conv = nn.Conv2d(
            expanded_channels,  #in channels
            expanded_channels,  #out channels
            groups=expanded_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=bias
        )
        #squeeze-and-excite (recalibrate channel wise)
        self.squeeze_excite = SEBlock2D(expanded_channels) if use_se else None
        #pointwise conv1x1 (projection down to out_channels)
        self.pointwise_conv2 = nn.Conv2d(expanded_channels, out_channels, kernel_size=1, bias=bias)
        self.batchnorm = nn.BatchNorm2d(out_channels) if batchnorm else None
        self.nonlinearity = nonlinearity
        self.dropout = nn.Dropout2d(p=dropout)

    def forward(self, x):
        x = self.pointwise_conv1(x)
        x = self.depthwise_conv(x)
        if self.squeeze_excite is not None:
            x = self.squeeze_excite(x)
        x = self.pointwise_conv2(x)
        #BUGFIX: batchnorm is None when batchnorm=False; guard before calling
        if self.batchnorm is not None:
            x = self.batchnorm(x)
        x = self.nonlinearity(x)
        #BUGFIX: dropout was constructed but never applied here, unlike the
        #3D twin — apply it so the `dropout` argument actually has an effect
        x = self.dropout(x)
        return x
#mobilenet large 3d convolutions | |
class MobileNetLarge3D(nn.Module):
    """MobileNetV3-Large built from 3D convolutions for video classification.

    Input: (batch, 3, T, H, W); per-block size comments assume 224x224 frames.
    Output: (batch, num_classes) logits. The classifier's AdaptiveAvgPool3d
    collapses whatever spatio-temporal extent remains, so other frame sizes
    also work.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        #conv3d (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv3d(in_channels=3, out_channels=16, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm3d(16),
            nn.Hardswish()
        )
        #3x3 bottlenecks1 (3, ReLU): 112x112x16 -> 56x56x24
        self.block2 = nn.Sequential(
            Bottleneck3D(in_channels=16, out_channels=16, expanded_channels=16, stride=1, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck3D(in_channels=16, out_channels=24, expanded_channels=64, stride=2, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck3D(in_channels=24, out_channels=24, expanded_channels=72, stride=1, nonlinearity=nn.ReLU(), dropout=0.2)
        )
        #5x5 bottlenecks1 (3, ReLU, squeeze-excite): 56x56x24 -> 28x28x40
        self.block3 = nn.Sequential(
            Bottleneck3D(in_channels=24, out_channels=40, expanded_channels=72, stride=2, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=40, expanded_channels=120, stride=1, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=40, expanded_channels=120, stride=1, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2)
        )
        #3x3 bottlenecks2 (6, h-swish, last two get squeeze-excite): 28x28x40 -> 14x14x112
        self.block4 = nn.Sequential(
            Bottleneck3D(in_channels=40, out_channels=80, expanded_channels=240, stride=2, dropout=0.2),
            Bottleneck3D(in_channels=80, out_channels=80, expanded_channels=240, stride=1, dropout=0.2),
            Bottleneck3D(in_channels=80, out_channels=80, expanded_channels=184, stride=1, dropout=0.2),
            Bottleneck3D(in_channels=80, out_channels=80, expanded_channels=184, stride=1, dropout=0.2),
            Bottleneck3D(in_channels=80, out_channels=112, expanded_channels=480, stride=1, use_se=True, dropout=0.2),
            Bottleneck3D(in_channels=112, out_channels=112, expanded_channels=672, stride=1, use_se=True, dropout=0.2)
        )
        #5x5 bottlenecks2 (3, h-swish, squeeze-excite): 14x14x112 -> 7x7x160
        self.block5 = nn.Sequential(
            Bottleneck3D(in_channels=112, out_channels=160, expanded_channels=672, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=160, out_channels=160, expanded_channels=960, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=160, out_channels=160, expanded_channels=960, stride=1, use_se=True, kernel_size=5, dropout=0.2)
        )
        #conv3d (h-swish): 7x7x160 -> 7x7x960
        self.block6 = nn.Sequential(
            nn.Conv3d(in_channels=160, out_channels=960, stride=1, kernel_size=1),
            nn.BatchNorm3d(960),
            nn.Hardswish()
        )
        #classifier: global pool then two 1x1x1 convs (first uses h-swish)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Conv3d(in_channels=960, out_channels=1280, kernel_size=1, stride=1, padding=0),
            nn.Hardswish(),
            nn.Conv3d(in_channels=1280, out_channels=self.num_classes, kernel_size=1, stride=1, padding=0)  #num_classes logits (ball/strike by default)
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.classifier(x)
        #classifier output is (N, num_classes, 1, 1, 1); flatten to (N, num_classes)
        x = x.view(x.shape[0], self.num_classes)
        return x

    def initialize_weights(self):
        """Initialize conv/linear weights (Kaiming) and batchnorm affine params.

        BUGFIX: the previous `hasattr(module, "nonlinearity")` check was never
        True for Conv3d/Linear modules, so conv weights were silently left at
        their default initialization; only the BatchNorm branch ever ran.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv3d, nn.Linear)):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm3d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
#mobilenet small 3d convolutions | |
class MobileNetSmall3D(nn.Module):
    """MobileNetV3-Small built from 3D convolutions for video classification.

    Input: (batch, 3, T, H, W); per-block size comments assume 224x224 frames.
    Output: (batch, num_classes) logits. The classifier's AdaptiveAvgPool3d
    collapses whatever spatio-temporal extent remains, so other frame sizes
    also work.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        #conv3d (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv3d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(16),
            nn.Hardswish()
        )
        #3x3 bottlenecks (3, LeakyReLU, first gets squeeze-excite): 112x112x16 -> 28x28x24
        self.block2 = nn.Sequential(
            Bottleneck3D(in_channels=16, out_channels=16, expanded_channels=16, stride=2, use_se=True, nonlinearity=nn.LeakyReLU(), dropout=0.2),
            Bottleneck3D(in_channels=16, out_channels=24, expanded_channels=72, stride=2, nonlinearity=nn.LeakyReLU(), dropout=0.2),
            Bottleneck3D(in_channels=24, out_channels=24, expanded_channels=88, stride=1, nonlinearity=nn.LeakyReLU(), dropout=0.2)
        )
        #5x5 bottlenecks (8, h-swish, squeeze-excite): 28x28x24 -> 7x7x96
        self.block3 = nn.Sequential(
            Bottleneck3D(in_channels=24, out_channels=40, expanded_channels=96, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=40, expanded_channels=240, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=40, expanded_channels=240, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=48, expanded_channels=120, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=48, out_channels=48, expanded_channels=144, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=48, out_channels=96, expanded_channels=288, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=96, out_channels=96, expanded_channels=576, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=96, out_channels=96, expanded_channels=576, stride=1, use_se=True, kernel_size=5, dropout=0.2)
        )
        #conv3d (h-swish) with squeeze-excite: 7x7x96 -> 7x7x576
        self.block4 = nn.Sequential(
            nn.Conv3d(in_channels=96, out_channels=576, kernel_size=1, stride=1, padding=0),
            SEBlock3D(channels=576),
            nn.BatchNorm3d(576),
            nn.Hardswish()
        )
        #classifier: global pool then two 1x1x1 convs (first uses h-swish)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Conv3d(in_channels=576, out_channels=1024, kernel_size=1, stride=1, padding=0),
            nn.Hardswish(),
            nn.Conv3d(in_channels=1024, out_channels=self.num_classes, kernel_size=1, stride=1, padding=0),  #num_classes logits (ball/strike by default)
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.classifier(x)
        #classifier output is (N, num_classes, 1, 1, 1); flatten to (N, num_classes)
        x = x.view(x.shape[0], self.num_classes)
        return x

    def initialize_weights(self):
        """Initialize conv/linear weights (Kaiming) and batchnorm affine params.

        BUGFIX: the previous `hasattr(module, "nonlinearity")` check was never
        True for Conv3d/Linear modules (so convs were never initialized), and
        `module.nonlinearity == 'relu' or 'leaky_relu'` was always truthy —
        the string `'leaky_relu'` made the condition unconditionally True.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv3d, nn.Linear)):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm3d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
#MobileNetV3-Large 2D + LSTM for helping with the temporal dimension | |
class MobileNetLarge2D(nn.Module):
    """MobileNetV3-Large 2D backbone + LSTM over frames for video classification.

    Input: (batch, timesteps, 3, H, W). The CNN runs per frame; an LSTM
    aggregates the per-frame 960-d features and the last timestep's hidden
    state is classified. Output: (batch, num_classes) logits.

    NOTE(review): block6 ends with AvgPool2d(kernel_size=7), which assumes a
    7x7 feature map — i.e. 224x224 input frames — for the flatten to yield
    exactly 960 features; confirm frame size against the data pipeline.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        #BUGFIX: all layer definitions below were previously indented inside
        #initialize_weights(), so __init__ never built the network and forward
        #failed unless initialize_weights() happened to be called first.
        #conv2d (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.Hardswish()
        )
        #3x3 bottlenecks1 (3, ReLU): 112x112x16 -> 56x56x24
        self.block2 = nn.Sequential(
            Bottleneck2D(in_channels=16, out_channels=16, expanded_channels=16, stride=1, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck2D(in_channels=16, out_channels=24, expanded_channels=64, stride=2, nonlinearity=nn.ReLU()),
            Bottleneck2D(in_channels=24, out_channels=24, expanded_channels=72, stride=1, nonlinearity=nn.ReLU(), dropout=0.2)
        )
        #5x5 bottlenecks1 (3, ReLU, squeeze-excite): 56x56x24 -> 28x28x40
        self.block3 = nn.Sequential(
            Bottleneck2D(in_channels=24, out_channels=40, expanded_channels=72, stride=2, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck2D(in_channels=40, out_channels=40, expanded_channels=120, stride=1, use_se=True, kernel_size=5, nonlinearity=nn.ReLU()),
            Bottleneck2D(in_channels=40, out_channels=40, expanded_channels=120, stride=1, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2)
        )
        #3x3 bottlenecks2 (6, h-swish, last two get squeeze-excite): 28x28x40 -> 14x14x112
        self.block4 = nn.Sequential(
            Bottleneck2D(in_channels=40, out_channels=80, expanded_channels=240, stride=2, dropout=0.2),
            Bottleneck2D(in_channels=80, out_channels=80, expanded_channels=240, stride=1),
            Bottleneck2D(in_channels=80, out_channels=80, expanded_channels=184, stride=1, dropout=0.2),
            Bottleneck2D(in_channels=80, out_channels=80, expanded_channels=184, stride=1),
            Bottleneck2D(in_channels=80, out_channels=112, expanded_channels=480, stride=1, use_se=True, dropout=0.2),
            Bottleneck2D(in_channels=112, out_channels=112, expanded_channels=672, stride=1, use_se=True, dropout=0.2)
        )
        #5x5 bottlenecks2 (3, h-swish, squeeze-excite): 14x14x112 -> 7x7x160
        self.block5 = nn.Sequential(
            Bottleneck2D(in_channels=112, out_channels=160, expanded_channels=672, stride=2, use_se=True, kernel_size=5),
            Bottleneck2D(in_channels=160, out_channels=160, expanded_channels=960, stride=1, use_se=True, kernel_size=5),
            Bottleneck2D(in_channels=160, out_channels=160, expanded_channels=960, stride=1, use_se=True, kernel_size=5)
        )
        #conv2d (h-swish), avg pool 7x7: 7x7x160 -> 1x1x960
        self.block6 = nn.Sequential(
            nn.Conv2d(in_channels=160, out_channels=960, stride=1, kernel_size=1),
            nn.BatchNorm2d(960),
            nn.Hardswish(),
            nn.AvgPool2d(kernel_size=7, stride=1)
        )
        #LSTM over the per-frame 960-d features
        self.lstm = nn.LSTM(input_size=960, hidden_size=32, num_layers=5, batch_first=True)
        #classifier on the last LSTM hidden state
        self.classifier = nn.Sequential(
            nn.Linear(32, self.num_classes)  #2 classes for ball/strike by default
        )

    def initialize_weights(self):
        """Initialize conv/linear weights (Kaiming) and batchnorm affine params.

        BUGFIX: the previous `hasattr(module, "nonlinearity")` check was never
        True for Conv2d/Linear modules, so conv weights were never initialized.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)

    def forward(self, x):
        #x is shape (batch_size, timesteps, C, H, W)
        batch_size, timesteps, C, H, W = x.size()
        #per-frame feature buffer; block6 outputs 960 channels at 1x1
        cnn_out = torch.zeros(batch_size, timesteps, 960, device=x.device, dtype=x.dtype)
        #run the CNN on each frame independently
        for i in range(timesteps):
            frame = x[:, i, :, :, :]
            frame = self.block1(frame)
            frame = self.block2(frame)
            frame = self.block3(frame)
            frame = self.block4(frame)
            frame = self.block5(frame)
            frame = self.block6(frame)
            #flatten everything but the batch dimension -> (batch, 960)
            cnn_out[:, i, :] = frame.view(frame.size(0), -1)
        #aggregate over time; keep only the last timestep's output
        out, _ = self.lstm(cnn_out)
        out = out[:, -1, :]
        return self.classifier(out)
#MobileNetV3-Small 2d with lstm for helping with the temporal dimension | |
class MobileNetSmall2D(nn.Module):
    """MobileNetV3-Small 2D backbone + LSTM over frames for video classification.

    Input: (batch, timesteps, 3, H, W). The CNN runs per frame; an LSTM
    aggregates the per-frame 576-d features and the last timestep's hidden
    state is classified. Output: (batch, num_classes) logits.

    NOTE(review): block4 ends with AvgPool2d(kernel_size=7), which assumes a
    7x7 feature map — i.e. 224x224 input frames — for the flatten to yield
    exactly 576 features; confirm frame size against the data pipeline.
    """
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        #conv2d (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.Hardswish()
        )
        #3x3 bottlenecks (3, ReLU, first gets squeeze-excite): 112x112x16 -> 28x28x24
        self.block2 = nn.Sequential(
            Bottleneck2D(in_channels=16, out_channels=16, expanded_channels=16, stride=2, use_se=True, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck2D(in_channels=16, out_channels=24, expanded_channels=72, stride=2, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck2D(in_channels=24, out_channels=24, expanded_channels=88, stride=1, nonlinearity=nn.ReLU(), dropout=0.2)
        )
        #5x5 bottlenecks (8, h-swish, squeeze-excite): 28x28x24 -> 7x7x96
        self.block3 = nn.Sequential(
            Bottleneck2D(in_channels=24, out_channels=40, expanded_channels=96, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=40, out_channels=40, expanded_channels=240, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=40, out_channels=40, expanded_channels=240, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=40, out_channels=48, expanded_channels=120, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=48, out_channels=48, expanded_channels=144, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=48, out_channels=96, expanded_channels=288, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=96, out_channels=96, expanded_channels=576, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=96, out_channels=96, expanded_channels=576, stride=1, use_se=True, kernel_size=5, dropout=0.2)
        )
        #conv2d (h-swish), avg pool 7x7: 7x7x96 -> 1x1x576
        self.block4 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=576, kernel_size=1, stride=1, padding=0),
            SEBlock2D(channels=576),
            nn.BatchNorm2d(576),
            nn.Hardswish(),
            nn.AvgPool2d(kernel_size=7, stride=1)
        )
        #LSTM over the per-frame 576-d features
        self.lstm = nn.LSTM(input_size=576, hidden_size=64, num_layers=1, batch_first=True)
        #classifier on the last LSTM hidden state
        self.classifier = nn.Sequential(
            nn.Linear(64, self.num_classes)  #2 classes for ball/strike by default
        )

    def forward(self, x):
        #x is of shape (batch_size, timesteps, C, H, W)
        batch_size, timesteps, C, H, W = x.size()
        #per-frame feature buffer; block4 outputs 576 channels at 1x1
        cnn_out = torch.zeros(batch_size, timesteps, 576, device=x.device, dtype=x.dtype)
        #run the CNN on each frame independently
        for i in range(timesteps):
            frame = x[:, i, :, :, :]
            frame = self.block1(frame)
            frame = self.block2(frame)
            frame = self.block3(frame)
            frame = self.block4(frame)
            #flatten everything but the batch dimension -> (batch, 576)
            cnn_out[:, i, :] = frame.view(frame.size(0), -1)
        #aggregate over time; keep only the last timestep's output
        out, _ = self.lstm(cnn_out)
        out = out[:, -1, :]
        return self.classifier(out)

    def initialize_weights(self):
        """Initialize conv/linear weights (Kaiming) and batchnorm affine params.

        BUGFIX: the previous `hasattr(module, "nonlinearity")` check was never
        True for Conv2d/Linear modules, so conv weights were never initialized.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)