"""MobileNetV3 ("Searching for MobileNetV3") adapted for video classification.

Provides 3D-convolution variants (MobileNetLarge3D / MobileNetSmall3D) that
consume whole clips, and 2D-convolution + LSTM variants (MobileNetLarge2D /
MobileNetSmall2D) that run a CNN per frame and aggregate temporally.

Label convention: balls are 0 and strikes are 1.
"""

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


class SEBlock3D(nn.Module):
    """Squeeze-and-excite for 3D feature maps.

    Global average pool -> 1x1x1 bottleneck (channels//4) -> 1x1x1 expand ->
    Hardsigmoid gate, multiplied channel-wise onto the input.
    """

    def __init__(self, channels):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Conv3d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(channels // 4, channels, kernel_size=1),
            nn.Hardsigmoid(),
        )

    def forward(self, x):
        # Recalibrate each channel by its learned gate in [0, 1].
        return x * self.se(x)


class SEBlock2D(nn.Module):
    """Squeeze-and-excite for 2D feature maps (see SEBlock3D)."""

    def __init__(self, channels):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 4, channels, kernel_size=1),
            nn.Hardsigmoid(),
        )

    def forward(self, x):
        return x * self.se(x)


class Bottleneck3D(nn.Module):
    """MobileNetV3 inverted bottleneck with (1, k, k) depthwise kernels.

    expand (1x1x1 conv) -> depthwise (1 x k x k) -> optional squeeze-excite ->
    project (1x1x1 conv) -> optional BatchNorm -> nonlinearity -> dropout.

    NOTE(review): unlike the paper, no residual skip connection is added.
    """

    def __init__(self, in_channels, out_channels, expanded_channels, stride=1,
                 use_se=False, kernel_size=3, nonlinearity=None, batchnorm=True,
                 dropout=0, bias=False):
        super().__init__()
        # None sentinel instead of a shared mutable default module instance;
        # Hardswish is the MobileNetV3 default activation.
        self.nonlinearity = nn.Hardswish() if nonlinearity is None else nonlinearity
        # pointwise 1x1x1: expand channels
        self.pointwise_conv1 = nn.Conv3d(in_channels, expanded_channels,
                                         kernel_size=1, bias=bias)
        # Depthwise spatial filtering; groups=channels keeps each channel's
        # information separate. Padding is spatial-only: the temporal kernel
        # is 1, so temporal padding would grow the clip length every block
        # (the original scalar padding=k//2 did exactly that).
        self.depthwise_conv = nn.Conv3d(
            expanded_channels,
            expanded_channels,
            groups=expanded_channels,
            kernel_size=(1, kernel_size, kernel_size),
            stride=stride,
            padding=(0, kernel_size // 2, kernel_size // 2),
            bias=bias,
        )
        # squeeze-and-excite: channel-wise recalibration
        self.squeeze_excite = SEBlock3D(expanded_channels) if use_se else None
        # pointwise 1x1x1: project to output channels
        self.pointwise_conv2 = nn.Conv3d(expanded_channels, out_channels,
                                         kernel_size=1, bias=bias)
        self.batchnorm = nn.BatchNorm3d(out_channels) if batchnorm else None
        self.dropout = nn.Dropout3d(p=dropout)

    def forward(self, x):
        x = self.pointwise_conv1(x)
        x = self.depthwise_conv(x)
        if self.squeeze_excite is not None:
            x = self.squeeze_excite(x)
        x = self.pointwise_conv2(x)
        # Guard: batchnorm=False used to crash here (self.batchnorm is None).
        if self.batchnorm is not None:
            x = self.batchnorm(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        return x


class Bottleneck2D(nn.Module):
    """2D MobileNetV3 inverted bottleneck for the per-frame CNN + LSTM models.

    Same layout as Bottleneck3D with 2D ops. Dropout is now applied in
    forward (the original constructed it but never used it).
    """

    def __init__(self, in_channels, out_channels, expanded_channels, stride=1,
                 use_se=False, kernel_size=3, nonlinearity=None, batchnorm=True,
                 dropout=0, bias=False):
        super().__init__()
        self.nonlinearity = nn.Hardswish() if nonlinearity is None else nonlinearity
        # pointwise 1x1: expand channels
        self.pointwise_conv1 = nn.Conv2d(in_channels, expanded_channels,
                                         kernel_size=1, bias=bias)
        # depthwise k x k spatial filtering, one filter per channel
        self.depthwise_conv = nn.Conv2d(
            expanded_channels,
            expanded_channels,
            groups=expanded_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            bias=bias,
        )
        self.squeeze_excite = SEBlock2D(expanded_channels) if use_se else None
        # pointwise 1x1: project to output channels
        self.pointwise_conv2 = nn.Conv2d(expanded_channels, out_channels,
                                         kernel_size=1, bias=bias)
        self.batchnorm = nn.BatchNorm2d(out_channels) if batchnorm else None
        self.dropout = nn.Dropout2d(p=dropout)

    def forward(self, x):
        x = self.pointwise_conv1(x)
        x = self.depthwise_conv(x)
        if self.squeeze_excite is not None:
            x = self.squeeze_excite(x)
        x = self.pointwise_conv2(x)
        if self.batchnorm is not None:
            x = self.batchnorm(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        return x


class MobileNetLarge3D(nn.Module):
    """MobileNetV3-Large built from 3D convolutions for clip classification.

    Spatial plan (for 224x224 input): 224 -> 112 -> 56 -> 28 -> 14 -> 7 -> 1.
    Input shape: (batch, 3, frames, H, W). Output: (batch, num_classes) logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        # stem conv (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv3d(in_channels=3, out_channels=16, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm3d(16),
            nn.Hardswish(),
        )
        # 3x3 bottlenecks (3, ReLU): 112x112x16 -> 56x56x24
        self.block2 = nn.Sequential(
            Bottleneck3D(in_channels=16, out_channels=16, expanded_channels=16, stride=1, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck3D(in_channels=16, out_channels=24, expanded_channels=64, stride=2, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck3D(in_channels=24, out_channels=24, expanded_channels=72, stride=1, nonlinearity=nn.ReLU(), dropout=0.2),
        )
        # 5x5 bottlenecks (3, ReLU, squeeze-excite): 56x56x24 -> 28x28x40
        self.block3 = nn.Sequential(
            Bottleneck3D(in_channels=24, out_channels=40, expanded_channels=72, stride=2, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=40, expanded_channels=120, stride=1, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=40, expanded_channels=120, stride=1, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2),
        )
        # 3x3 bottlenecks (6, h-swish, last two get squeeze-excite): 28x28x40 -> 14x14x112
        self.block4 = nn.Sequential(
            Bottleneck3D(in_channels=40, out_channels=80, expanded_channels=240, stride=2, dropout=0.2),
            Bottleneck3D(in_channels=80, out_channels=80, expanded_channels=240, stride=1, dropout=0.2),
            Bottleneck3D(in_channels=80, out_channels=80, expanded_channels=184, stride=1, dropout=0.2),
            Bottleneck3D(in_channels=80, out_channels=80, expanded_channels=184, stride=1, dropout=0.2),
            Bottleneck3D(in_channels=80, out_channels=112, expanded_channels=480, stride=1, use_se=True, dropout=0.2),
            Bottleneck3D(in_channels=112, out_channels=112, expanded_channels=672, stride=1, use_se=True, dropout=0.2),
        )
        # 5x5 bottlenecks (3, h-swish, squeeze-excite): 14x14x112 -> 7x7x160
        self.block5 = nn.Sequential(
            Bottleneck3D(in_channels=112, out_channels=160, expanded_channels=672, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=160, out_channels=160, expanded_channels=960, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=160, out_channels=160, expanded_channels=960, stride=1, use_se=True, kernel_size=5, dropout=0.2),
        )
        # conv3d (h-swish): 7x7x160 -> 7x7x960
        self.block6 = nn.Sequential(
            nn.Conv3d(in_channels=160, out_channels=960, stride=1, kernel_size=1),
            nn.BatchNorm3d(960),
            nn.Hardswish(),
        )
        # classifier head: global pool then two 1x1x1 convs (no batchnorm)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Conv3d(in_channels=960, out_channels=1280, kernel_size=1, stride=1, padding=0),
            nn.Hardswish(),
            # num_classes=2 -> ball/strike
            nn.Conv3d(in_channels=1280, out_channels=self.num_classes, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.classifier(x)
        # (batch, num_classes, 1, 1, 1) -> (batch, num_classes)
        return x.view(x.shape[0], self.num_classes)

    def initialize_weights(self):
        """Kaiming-init conv/linear weights; standard BatchNorm init.

        The original implementation gated on ``module.nonlinearity``, an
        attribute conv/linear layers never have, so it silently initialized
        nothing.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv3d, nn.Linear)):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm3d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)


class MobileNetSmall3D(nn.Module):
    """MobileNetV3-Small built from 3D convolutions for clip classification.

    Input shape: (batch, 3, frames, H, W). Output: (batch, num_classes) logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        # stem conv (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv3d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(16),
            nn.Hardswish(),
        )
        # 3x3 bottlenecks (3, LeakyReLU, first gets squeeze-excite): 112x112x16 -> 28x28x24
        self.block2 = nn.Sequential(
            Bottleneck3D(in_channels=16, out_channels=16, expanded_channels=16, stride=2, use_se=True, nonlinearity=nn.LeakyReLU(), dropout=0.2),
            Bottleneck3D(in_channels=16, out_channels=24, expanded_channels=72, stride=2, nonlinearity=nn.LeakyReLU(), dropout=0.2),
            Bottleneck3D(in_channels=24, out_channels=24, expanded_channels=88, stride=1, nonlinearity=nn.LeakyReLU(), dropout=0.2),
        )
        # 5x5 bottlenecks (8, h-swish, squeeze-excite): 28x28x24 -> 7x7x96
        self.block3 = nn.Sequential(
            Bottleneck3D(in_channels=24, out_channels=40, expanded_channels=96, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=40, expanded_channels=240, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=40, expanded_channels=240, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=40, out_channels=48, expanded_channels=120, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=48, out_channels=48, expanded_channels=144, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=48, out_channels=96, expanded_channels=288, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=96, out_channels=96, expanded_channels=576, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck3D(in_channels=96, out_channels=96, expanded_channels=576, stride=1, use_se=True, kernel_size=5, dropout=0.2),
        )
        # conv3d (h-swish) with squeeze-excite: 7x7x96 -> 7x7x576
        self.block4 = nn.Sequential(
            nn.Conv3d(in_channels=96, out_channels=576, kernel_size=1, stride=1, padding=0),
            SEBlock3D(channels=576),
            nn.BatchNorm3d(576),
            nn.Hardswish(),
        )
        # classifier head: global pool then two 1x1x1 convs (no batchnorm)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d((1, 1, 1)),
            nn.Conv3d(in_channels=576, out_channels=1024, kernel_size=1, stride=1, padding=0),
            nn.Hardswish(),
            # num_classes=2 -> ball/strike
            nn.Conv3d(in_channels=1024, out_channels=self.num_classes, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.classifier(x)
        # (batch, num_classes, 1, 1, 1) -> (batch, num_classes)
        return x.view(x.shape[0], self.num_classes)

    def initialize_weights(self):
        """Kaiming-init conv/linear weights; standard BatchNorm init.

        Fixes two bugs in the original: the ``hasattr(module, "nonlinearity")``
        gate was never true for conv/linear layers, and
        ``== 'relu' or 'leaky_relu'`` was always truthy.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv3d, nn.Linear)):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm3d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)


class MobileNetLarge2D(nn.Module):
    """MobileNetV3-Large 2D backbone per frame + LSTM over the time dimension.

    Expects 224x224 frames (block6's AvgPool2d(7) requires a 7x7 map).
    Input shape: (batch, timesteps, 3, H, W). Output: (batch, num_classes).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        # stem conv (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.Hardswish(),
        )
        # 3x3 bottlenecks (3, ReLU): 112x112x16 -> 56x56x24
        self.block2 = nn.Sequential(
            Bottleneck2D(in_channels=16, out_channels=16, expanded_channels=16, stride=1, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck2D(in_channels=16, out_channels=24, expanded_channels=64, stride=2, nonlinearity=nn.ReLU()),
            Bottleneck2D(in_channels=24, out_channels=24, expanded_channels=72, stride=1, nonlinearity=nn.ReLU(), dropout=0.2),
        )
        # 5x5 bottlenecks (3, ReLU, squeeze-excite): 56x56x24 -> 28x28x40
        self.block3 = nn.Sequential(
            Bottleneck2D(in_channels=24, out_channels=40, expanded_channels=72, stride=2, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck2D(in_channels=40, out_channels=40, expanded_channels=120, stride=1, use_se=True, kernel_size=5, nonlinearity=nn.ReLU()),
            Bottleneck2D(in_channels=40, out_channels=40, expanded_channels=120, stride=1, use_se=True, kernel_size=5, nonlinearity=nn.ReLU(), dropout=0.2),
        )
        # 3x3 bottlenecks (6, h-swish, last two get squeeze-excite): 28x28x40 -> 14x14x112
        self.block4 = nn.Sequential(
            Bottleneck2D(in_channels=40, out_channels=80, expanded_channels=240, stride=2, dropout=0.2),
            Bottleneck2D(in_channels=80, out_channels=80, expanded_channels=240, stride=1),
            Bottleneck2D(in_channels=80, out_channels=80, expanded_channels=184, stride=1, dropout=0.2),
            Bottleneck2D(in_channels=80, out_channels=80, expanded_channels=184, stride=1),
            Bottleneck2D(in_channels=80, out_channels=112, expanded_channels=480, stride=1, use_se=True, dropout=0.2),
            Bottleneck2D(in_channels=112, out_channels=112, expanded_channels=672, stride=1, use_se=True, dropout=0.2),
        )
        # 5x5 bottlenecks (3, h-swish, squeeze-excite): 14x14x112 -> 7x7x160
        self.block5 = nn.Sequential(
            Bottleneck2D(in_channels=112, out_channels=160, expanded_channels=672, stride=2, use_se=True, kernel_size=5),
            Bottleneck2D(in_channels=160, out_channels=160, expanded_channels=960, stride=1, use_se=True, kernel_size=5),
            Bottleneck2D(in_channels=160, out_channels=160, expanded_channels=960, stride=1, use_se=True, kernel_size=5),
        )
        # conv2d (h-swish) + 7x7 avg pool: 7x7x160 -> 1x1x960
        self.block6 = nn.Sequential(
            nn.Conv2d(in_channels=160, out_channels=960, stride=1, kernel_size=1),
            nn.BatchNorm2d(960),
            nn.Hardswish(),
            nn.AvgPool2d(kernel_size=7, stride=1),
        )
        # LSTM over the per-frame 960-d feature sequence
        self.lstm = nn.LSTM(input_size=960, hidden_size=32, num_layers=5, batch_first=True)
        # final linear classifier on the last LSTM output (2 classes: ball/strike)
        self.classifier = nn.Sequential(
            nn.Linear(32, self.num_classes)
        )

    def forward(self, x):
        # x: (batch, timesteps, C, H, W)
        batch_size, timesteps, C, H, W = x.size()
        # Run the CNN on each frame and collect the 960-d feature vectors.
        features = []
        for t in range(timesteps):
            frame = x[:, t]
            frame = self.block1(frame)
            frame = self.block2(frame)
            frame = self.block3(frame)
            frame = self.block4(frame)
            frame = self.block5(frame)
            frame = self.block6(frame)
            features.append(frame.view(batch_size, -1))
        # (batch, timesteps, 960); stacking keeps autograd simple vs. the
        # original pre-allocated zeros tensor filled in-place.
        seq = torch.stack(features, dim=1)
        out, _ = self.lstm(seq)
        # classify from the last timestep only
        return self.classifier(out[:, -1, :])

    def initialize_weights(self):
        """Kaiming-init conv/linear weights; standard BatchNorm init.

        The original gated on ``module.nonlinearity``, an attribute conv and
        linear layers never have, so it silently initialized nothing.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)


class MobileNetSmall2D(nn.Module):
    """MobileNetV3-Small 2D backbone per frame + LSTM over the time dimension.

    Expects 224x224 frames (block4's AvgPool2d(7) requires a 7x7 map).
    Input shape: (batch, timesteps, 3, H, W). Output: (batch, num_classes).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        # stem conv (h-swish): 224x224x3 -> 112x112x16
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.Hardswish(),
        )
        # 3x3 bottlenecks (3, ReLU, first gets squeeze-excite): 112x112x16 -> 28x28x24
        self.block2 = nn.Sequential(
            Bottleneck2D(in_channels=16, out_channels=16, expanded_channels=16, stride=2, use_se=True, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck2D(in_channels=16, out_channels=24, expanded_channels=72, stride=2, nonlinearity=nn.ReLU(), dropout=0.2),
            Bottleneck2D(in_channels=24, out_channels=24, expanded_channels=88, stride=1, nonlinearity=nn.ReLU(), dropout=0.2),
        )
        # 5x5 bottlenecks (8, h-swish, squeeze-excite): 28x28x24 -> 7x7x96
        self.block3 = nn.Sequential(
            Bottleneck2D(in_channels=24, out_channels=40, expanded_channels=96, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=40, out_channels=40, expanded_channels=240, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=40, out_channels=40, expanded_channels=240, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=40, out_channels=48, expanded_channels=120, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=48, out_channels=48, expanded_channels=144, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=48, out_channels=96, expanded_channels=288, stride=2, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=96, out_channels=96, expanded_channels=576, stride=1, use_se=True, kernel_size=5, dropout=0.2),
            Bottleneck2D(in_channels=96, out_channels=96, expanded_channels=576, stride=1, use_se=True, kernel_size=5, dropout=0.2),
        )
        # conv2d (h-swish) with squeeze-excite + 7x7 avg pool: 7x7x96 -> 1x1x576
        self.block4 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=576, kernel_size=1, stride=1, padding=0),
            SEBlock2D(channels=576),
            nn.BatchNorm2d(576),
            nn.Hardswish(),
            nn.AvgPool2d(kernel_size=7, stride=1),
        )
        # LSTM over the per-frame 576-d feature sequence
        self.lstm = nn.LSTM(input_size=576, hidden_size=64, num_layers=1, batch_first=True)
        # final linear classifier on the last LSTM output (2 classes: ball/strike)
        self.classifier = nn.Sequential(
            nn.Linear(64, self.num_classes)
        )

    def forward(self, x):
        # x: (batch, timesteps, C, H, W)
        batch_size, timesteps, C, H, W = x.size()
        # Run the CNN on each frame and collect the 576-d feature vectors.
        features = []
        for t in range(timesteps):
            frame = x[:, t]
            frame = self.block1(frame)
            frame = self.block2(frame)
            frame = self.block3(frame)
            frame = self.block4(frame)
            features.append(frame.view(batch_size, -1))
        # (batch, timesteps, 576)
        seq = torch.stack(features, dim=1)
        out, _ = self.lstm(seq)
        # classify from the last timestep only
        return self.classifier(out[:, -1, :])

    def initialize_weights(self):
        """Kaiming-init conv/linear weights; standard BatchNorm init.

        The original gated on ``module.nonlinearity``, an attribute conv and
        linear layers never have, so it silently initialized nothing.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)