File size: 2,257 Bytes
d9083bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
class CustomModel(nn.Module):
    """Classifier that tiles spectrogram planes into one image for a timm backbone.

    The input's eight last-axis planes are tiled into a single-channel image
    (see ``__reshape_input``), replicated to three channels, passed through the
    backbone's feature layers, then average-pooled and classified by a linear
    head.
    """

    def __init__(
        self,
        config,
        num_classes: int = 6,
        pretrained: bool = True,
        use_kaggle_spectrograms: bool = True,
        use_eeg_spectrograms: bool = True,
    ):
        """Build the backbone, optionally freeze parameters, and attach the head.

        Args:
            config: Object exposing ``MODEL_NAME`` (name passed to
                ``timm.create_model``), ``FREEZE`` (whether to freeze anything)
                and ``NUM_FROZEN_LAYERS`` (how many leading parameter tensors of
                the backbone to freeze).
            num_classes: Number of output logits.
            pretrained: Forwarded to ``timm.create_model``.
            use_kaggle_spectrograms: Include input planes 0-3 in the image.
            use_eeg_spectrograms: Include input planes 4-7 in the image.
        """
        super().__init__()
        self.USE_KAGGLE_SPECTROGRAMS = use_kaggle_spectrograms
        self.USE_EEG_SPECTROGRAMS = use_eeg_spectrograms
        self.model = timm.create_model(
            config.MODEL_NAME,
            pretrained=pretrained,
        )
        if config.FREEZE:
            # NUM_FROZEN_LAYERS counts parameter tensors in registration order
            # (a layer's weight and bias are counted separately), not layers.
            for param in list(self.model.parameters())[: config.NUM_FROZEN_LAYERS]:
                param.requires_grad = False
        # Keep every child except the last two. NOTE(review): assumes those are
        # the backbone's global pooling and classifier (as in e.g. timm
        # EfficientNet) -- confirm for config.MODEL_NAME. The child modules are
        # shared with self.model, not copied.
        self.features = nn.Sequential(*list(self.model.children())[:-2])
        self.custom_layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(self.model.num_features, num_classes),
        )

    @staticmethod
    def _stack_planes(x, planes):
        """Concatenate the given last-axis planes of ``x`` along dim 1 (height)."""
        return torch.cat([x[:, :, :, i : i + 1] for i in planes], dim=1)

    def __reshape_input(self, x):
        """Tile the eight input planes of ``x`` into one three-channel image.

        ``x`` is indexed as ``(batch, height, width, planes)`` and must have at
        least 8 planes. Planes 0-3 (Kaggle spectrograms) and planes 4-7 (EEG
        spectrograms) are each stacked vertically, giving ``(B, 4*H, W, 1)`` per
        group. With both groups enabled they are placed side by side along the
        width. The single channel is then replicated three times and moved to
        channels-first, e.g. ``(B, 128, 256, 8) -> (B, 3, 512, 512)``.

        If both ``USE_*`` flags are false, the Kaggle planes are used.
        """
        use_kaggle = self.USE_KAGGLE_SPECTROGRAMS
        use_eeg = self.USE_EEG_SPECTROGRAMS
        if use_kaggle and use_eeg:
            # Side by side along dim=2, which is the width axis, not channels.
            x = torch.cat(
                [self._stack_planes(x, range(4)), self._stack_planes(x, range(4, 8))],
                dim=2,
            )
        elif use_eeg:
            x = self._stack_planes(x, range(4, 8))
        else:
            # Also the fallback when neither flag is set.
            x = self._stack_planes(x, range(4))
        # Replicate the single channel into three identical channels
        # (presumably because the backbone expects 3-channel input).
        x = torch.cat([x, x, x], dim=3)
        # (B, H', W', C) -> (B, C, H', W')
        return x.permute(0, 3, 1, 2)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        x = self.__reshape_input(x)
        x = self.features(x)
        return self.custom_layers(x)
|