File size: 2,257 Bytes
d9083bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class CustomModel(nn.Module):
    """Classifier that feeds tiled spectrogram images to a timm backbone.

    The input tensor carries 8 single-channel planes in its last dimension:
    planes 0-3 are treated as the "Kaggle" spectrograms and planes 4-7 as the
    EEG spectrograms. The planes are tiled into one monotone 3-channel image,
    passed through the backbone's feature extractor and classified by a small
    pooling + linear head.

    Args:
        config: Object exposing ``MODEL_NAME`` (timm model identifier),
            ``FREEZE`` (whether to freeze part of the backbone) and
            ``NUM_FROZEN_LAYERS`` (how many leading parameter tensors to
            freeze when ``FREEZE`` is truthy).
        num_classes: Number of outputs of the final linear layer.
        pretrained: Forwarded to ``timm.create_model``.
        use_kaggle_spectrograms: Include planes 0-3 in the image.
        use_eeg_spectrograms: Include planes 4-7 in the image.

    Note:
        If both ``use_*`` flags are false, the Kaggle spectrograms are used
        anyway (there is no "empty input" mode).
    """

    def __init__(
        self,
        config,
        num_classes: int = 6,
        pretrained: bool = True,
        *,
        use_kaggle_spectrograms: bool = True,
        use_eeg_spectrograms: bool = True,
    ):
        super().__init__()
        # Attribute names are kept upper-case because external code may
        # toggle them on an existing instance.
        self.USE_KAGGLE_SPECTROGRAMS = use_kaggle_spectrograms
        self.USE_EEG_SPECTROGRAMS = use_eeg_spectrograms
        self.model = timm.create_model(
            config.MODEL_NAME,
            pretrained=pretrained,
        )

        if config.FREEZE:
            # NOTE: despite its name, NUM_FROZEN_LAYERS counts parameter
            # tensors (weights and biases separately), not whole layers.
            for param in list(self.model.parameters())[:config.NUM_FROZEN_LAYERS]:
                param.requires_grad = False

        # Drop the last two child modules of the backbone — presumably its
        # pooling layer and classifier head; verify for config.MODEL_NAME.
        # The kept children are shared with self.model, so the frozen
        # parameters above stay frozen here as well.
        self.features = nn.Sequential(*list(self.model.children())[:-2])
        self.custom_layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(self.model.num_features, num_classes),
        )

    def __reshape_input(self, x: torch.Tensor) -> torch.Tensor:
        """Tile the 8 input planes into a monotone 3-channel image batch.

        Args:
            x: Tensor of shape ``(batch, H, W, 8)``.

        Returns:
            Tensor of shape ``(batch, 3, 4 * H, 2 * W)`` when both spectrogram
            kinds are enabled (e.g. ``(128, 256, 8)`` samples become
            ``3 x 512 x 512`` images), or ``(batch, 3, 4 * H, W)`` when only
            one kind is used. The three channels are identical copies.
        """
        # === Get spectrograms: planes 0-3 stacked vertically -> (B, 4H, W, 1) ===
        spectrograms = torch.cat([x[:, :, :, i:i + 1] for i in range(4)], dim=1)

        # === Get EEG spectrograms: planes 4-7 stacked vertically -> (B, 4H, W, 1) ===
        eegs = torch.cat([x[:, :, :, i:i + 1] for i in range(4, 8)], dim=1)

        # === Select / combine the two column images ===
        if self.USE_KAGGLE_SPECTROGRAMS and self.USE_EEG_SPECTROGRAMS:
            # Place the two columns side by side along the width axis (dim=2).
            x = torch.cat([spectrograms, eegs], dim=2)
        elif self.USE_EEG_SPECTROGRAMS:
            x = eegs
        else:
            x = spectrograms

        # Replicate the single-channel data to create a monotone image.
        x = torch.cat([x, x, x], dim=3)

        # Channel-last -> channel-first: (batch, channels, height, width).
        return x.permute(0, 3, 1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(batch, num_classes)`` output of the linear head for ``x``.

        Args:
            x: Tensor of shape ``(batch, H, W, 8)``; see ``__reshape_input``.
        """
        x = self.__reshape_input(x)
        x = self.features(x)
        x = self.custom_layers(x)
        return x