Spaces:
Runtime error
Runtime error
Update training settings
Browse files
cnn.py
CHANGED
@@ -52,7 +52,7 @@ class CNNetwork(nn.Module):
|
|
52 |
nn.MaxPool2d(kernel_size=2)
|
53 |
)
|
54 |
self.flatten = nn.Flatten()
|
55 |
-
self.linear = nn.Linear(128 *
|
56 |
self.softmax = nn.Softmax(dim=1)
|
57 |
|
58 |
def forward(self, input_data):
|
|
|
52 |
nn.MaxPool2d(kernel_size=2)
|
53 |
)
|
54 |
self.flatten = nn.Flatten()
|
55 |
+
self.linear = nn.Linear(128 * 9 * 31, 3)
|
56 |
self.softmax = nn.Softmax(dim=1)
|
57 |
|
58 |
def forward(self, input_data):
|
train.py
CHANGED
@@ -111,9 +111,9 @@ if __name__ == "__main__":
|
|
111 |
# instantiating our dataset object and create data loader
|
112 |
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
113 |
sample_rate=SAMPLE_RATE,
|
114 |
-
n_fft=
|
115 |
hop_length=512,
|
116 |
-
n_mels=
|
117 |
)
|
118 |
|
119 |
train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, device)
|
|
|
111 |
# instantiating our dataset object and create data loader
|
112 |
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
113 |
sample_rate=SAMPLE_RATE,
|
114 |
+
n_fft=2048,
|
115 |
hop_length=512,
|
116 |
+
n_mels=128
|
117 |
)
|
118 |
|
119 |
train_dataset = VoiceDataset(TRAIN_FILE, mel_spectrogram, device)
|