File size: 4,247 Bytes
557fb53
 
4b8361a
557fb53
c914273
 
 
4b8361a
 
 
557fb53
ad4c4e2
557fb53
 
 
 
 
4b8361a
7b37b0e
c914273
557fb53
7b37b0e
557fb53
c914273
 
0030bc6
 
 
c914273
 
 
 
7b37b0e
 
557fb53
 
 
 
 
c914273
 
 
557fb53
 
 
0030bc6
c914273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557fb53
c914273
 
557fb53
c914273
7b37b0e
c914273
 
 
557fb53
 
 
c914273
557fb53
 
 
c914273
 
 
 
 
557fb53
 
 
 
 
 
 
c914273
 
 
 
 
 
 
 
 
 
 
 
 
4b8361a
 
 
557fb53
 
 
 
ba35f85
557fb53
 
 
 
 
 
 
51f4763
557fb53
 
 
51f4763
557fb53
 
 
 
 
ad4c4e2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import pytorch_lightning as pl
from pytorch_lightning import callbacks as cb
import torch
from torch import nn
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import torchaudio
import yaml
from models.training_environment import TrainingEnvironment
from models.utils import LabelWeightedBCELoss
from preprocessing.dataset import DanceDataModule, get_datasets
from preprocessing.pipelines import (
    SpectrogramTrainingPipeline,
    WaveformPreprocessing,
)

# Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py


class ResidualDancer(nn.Module):
    """Residual CNN mapping a single-channel spectrogram to per-class logits.

    Seven strided ``ResBlock`` stages are followed by global max pooling over
    the time axis and a two-layer dense head
    (Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear).

    Args:
        n_channels: Width of the first residual stage; later stages use
            ``2 * n_channels`` and finally ``4 * n_channels`` channels.
        n_classes: Number of output logits.
        dropout: Dropout probability applied before the final linear layer.
    """

    def __init__(self, n_channels=128, n_classes=50, dropout=0.2):
        super().__init__()

        self.n_channels = n_channels
        self.n_classes = n_classes

        # Normalizes the single input channel of the spectrogram.
        self.spec_bn = nn.BatchNorm2d(1)

        # CNN: every block uses stride 2, so both spatial axes are roughly
        # halved (ceil(size / 2)) at each of the seven stages.
        self.res_layers = nn.Sequential(
            ResBlock(1, n_channels, stride=2),
            ResBlock(n_channels, n_channels, stride=2),
            ResBlock(n_channels, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 4, stride=2),
        )

        # Dense head
        self.dense1 = nn.Linear(n_channels * 4, n_channels * 4)
        self.bn = nn.BatchNorm1d(n_channels * 4)
        self.dense2 = nn.Linear(n_channels * 4, n_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Compute class logits for a batch of spectrograms.

        Args:
            x: Tensor shaped ``(batch, 1, freq, time)`` (``spec_bn`` is a
                ``BatchNorm2d`` over one channel). The seven stride-2 stages
                reduce ``freq`` to 1 only when ``freq <= 128``.

        Returns:
            Tensor of shape ``(batch, n_classes)`` holding unnormalized
            logits (no sigmoid is applied here).

        Raises:
            ValueError: If the frequency axis is not reduced to size 1 by the
                residual stages.
        """
        x = self.spec_bn(x)

        # CNN -> (batch, n_channels * 4, freq', time')
        x = self.res_layers(x)
        if x.size(2) != 1:
            # Without this check squeeze(2) is a silent no-op and the model
            # fails later with an opaque pooling/Linear shape error.
            raise ValueError(
                "ResidualDancer expects the frequency axis to collapse to 1 "
                f"after the residual layers, got shape {tuple(x.shape)}; "
                "inputs must have at most 128 frequency bins."
            )
        x = x.squeeze(2)

        # Global max pooling over time. The functional form avoids building a
        # new nn.MaxPool1d module on every call; a kernel of size 1 is an
        # identity, so no special case is needed.
        x = F.max_pool1d(x, kernel_size=x.size(-1)).squeeze(2)

        # Dense
        x = self.dense1(x)
        x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.dense2(x)
        # Output stays as logits; presumably the training loss applies the
        # sigmoid itself — TODO confirm against LabelWeightedBCELoss.

        return x


class ResBlock(nn.Module):
    """Residual block: two convolutions plus an optional projected shortcut.

    Main path: conv_1 (strided) -> bn_1 -> ReLU -> conv_2 -> bn_2.
    Shortcut: the input itself, or conv_3 -> bn_3 when the stride is not 1
    or the channel count changes. The two paths are summed and passed
    through a final ReLU.

    Args:
        input_channels: Channels of the incoming feature map.
        output_channels: Channels produced by the block.
        shape: Square kernel size for every convolution (padding is
            ``shape // 2``). The projection shortcut also uses this kernel
            size rather than a 1x1 convolution.
        stride: Stride of conv_1 and of the projection shortcut.
    """

    def __init__(self, input_channels, output_channels, shape=3, stride=2):
        super().__init__()
        pad = shape // 2

        # Main path. Submodules are registered in the same order as before so
        # state_dict layout and seeded initialization stay unchanged.
        self.conv_1 = nn.Conv2d(
            input_channels, output_channels, shape, stride=stride, padding=pad
        )
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=pad)
        self.bn_2 = nn.BatchNorm2d(output_channels)

        # Shortcut path: project only when the output geometry differs.
        needs_projection = (stride != 1) or (input_channels != output_channels)
        if needs_projection:
            self.conv_3 = nn.Conv2d(
                input_channels, output_channels, shape, stride=stride, padding=pad
            )
            self.bn_3 = nn.BatchNorm2d(output_channels)
        self.diff = needs_projection
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the block to ``x`` and return the rectified residual sum."""
        hidden = self.relu(self.bn_1(self.conv_1(x)))
        hidden = self.bn_2(self.conv_2(hidden))

        shortcut = self.bn_3(self.conv_3(x)) if self.diff else x
        return self.relu(shortcut + hidden)


def train_residual_dancer(config: dict):
    """Build, train and test a ``ResidualDancer`` from a configuration dict.

    Keys read from ``config``: ``"dance_ids"`` (its length sets
    ``n_classes``), ``"device"`` (where the label weights are moved),
    ``"seed"``, ``"feature_extractor"``, ``"datasets"``, ``"data_module"``,
    ``"model"`` and ``"trainer"``. The whole dict is also forwarded to
    ``TrainingEnvironment``.

    Side effects: sets the global float32 matmul precision to ``"medium"``,
    seeds all RNGs, then runs ``Trainer.fit`` followed by ``Trainer.test``.
    Returns ``None``.

    Note: ``config["model"]`` must not contain ``n_classes``, since that
    keyword is supplied here (a duplicate raises ``TypeError``).
    """
    dance_ids = config["dance_ids"]
    device = config["device"]
    seed = config["seed"]

    torch.set_float32_matmul_precision("medium")
    # Seeding happens before any model/data construction so initialization
    # is reproducible.
    pl.seed_everything(seed, workers=True)

    pipeline = SpectrogramTrainingPipeline(**config["feature_extractor"])
    datasets = get_datasets(config["datasets"], pipeline)
    datamodule = DanceDataModule(datasets, **config["data_module"])

    network = ResidualDancer(n_classes=len(dance_ids), **config["model"])
    weights = datamodule.get_label_weights().to(device)
    loss_fn = LabelWeightedBCELoss(weights)
    environment = TrainingEnvironment(network, loss_fn, config)

    trainer = pl.Trainer(
        callbacks=[
            # Stop after 2 validation checks without improvement in val/loss.
            cb.EarlyStopping("val/loss", patience=2),
            cb.StochasticWeightAveraging(1e-2),
            cb.RichProgressBar(),
        ],
        **config["trainer"],
    )
    trainer.fit(environment, datamodule=datamodule)
    trainer.test(environment, datamodule=datamodule)