from typing import Optional, Tuple

from torch import Tensor, nn

try:
    from torchsummary import summary
except ImportError:  # torchsummary is only a diagnostic aid; the model works without it.

    def summary(*args, **kwargs):
        """Stand-in used when torchsummary is missing; fails loudly on use."""
        raise ImportError(
            "torchsummary is required to print a model summary "
            "(pip install torchsummary)"
        )


class CNNetwork(nn.Module):
    """Small 2-D CNN classifier: 4 conv blocks / flatten / linear / softmax.

    Each conv block is ``Conv2d(kernel_size=3, stride=1, padding=2)`` ->
    ``ReLU`` -> ``MaxPool2d(kernel_size=2)``, with output channels
    16, 32, 64 and 128.

    Args:
        num_classes: Number of output classes (width of the final layer).
        in_channels: Number of channels in the input tensor.
        input_shape: Optional ``(height, width)`` of the input. When given,
            the size of the linear layer is derived from it. When ``None``,
            the final feature map is taken to be 9 x 31, i.e. the linear
            layer has ``128 * 9 * 31`` input features.

    Raises:
        ValueError: If ``num_classes``, ``in_channels`` or a dimension of
            ``input_shape`` is smaller than 1.
    """

    _KERNEL_SIZE = 3
    _PADDING = 2
    _POOL_SIZE = 2
    _BLOCK_CHANNELS = (16, 32, 64, 128)
    # Feature-map size after the last block when no input_shape is supplied.
    _DEFAULT_FEATURE_MAP = (9, 31)

    def __init__(
        self,
        num_classes: int = 3,
        in_channels: int = 1,
        input_shape: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")

        if input_shape is None:
            map_height, map_width = self._DEFAULT_FEATURE_MAP
        else:
            height, width = input_shape
            if height < 1 or width < 1:
                raise ValueError(
                    f"input_shape dimensions must be >= 1, got {input_shape}"
                )
            map_height = self._pooled_size(height)
            map_width = self._pooled_size(width)

        # Attribute names and Sequential layouts are kept stable so that
        # state_dict keys ("conv1.0.weight", ..., "linear.bias") do not change.
        c1, c2, c3, c4 = self._BLOCK_CHANNELS
        self.conv1 = self._conv_block(in_channels, c1)
        self.conv2 = self._conv_block(c1, c2)
        self.conv3 = self._conv_block(c2, c3)
        self.conv4 = self._conv_block(c3, c4)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(c4 * map_height * map_width, num_classes)
        self.softmax = nn.Softmax(dim=1)

    @classmethod
    def _conv_block(cls, in_channels: int, out_channels: int) -> nn.Sequential:
        """Build one Conv2d -> ReLU -> MaxPool2d block."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=cls._KERNEL_SIZE,
                stride=1,
                padding=cls._PADDING,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=cls._POOL_SIZE),
        )

    @classmethod
    def _pooled_size(cls, size: int) -> int:
        """Return one spatial dimension after all conv blocks."""
        for _ in cls._BLOCK_CHANNELS:
            # Stride-1 convolution: size + 2 * padding - kernel_size + 1.
            size = size + 2 * cls._PADDING - cls._KERNEL_SIZE + 1
            # Max pooling with the default stride (= kernel size) floors.
            size //= cls._POOL_SIZE
        return size

    def forward(self, input_data: Tensor) -> Tensor:
        """Run the network on a batch and return per-class probabilities.

        Args:
            input_data: Tensor of shape ``(batch, in_channels, height, width)``
                whose spatial size matches what the linear layer was built for.

        Returns:
            Tensor of shape ``(batch, num_classes)``; each row sums to 1.
        """
        x = self.conv1(input_data)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        logits = self.linear(x)
        # NOTE(review): the output is already softmax-normalised. If training
        # uses nn.CrossEntropyLoss (which expects raw logits), softmax would be
        # applied twice - confirm against the training script.
        predictions = self.softmax(logits)
        return predictions