Add cut/pad files
- __pycache__/dataset.cpython-39.pyc +0 -0
- dataset.py +27 -3
- notebooks/playground.ipynb +43 -16
__pycache__/dataset.cpython-39.pyc CHANGED
Binary files a/__pycache__/dataset.cpython-39.pyc and b/__pycache__/dataset.cpython-39.pyc differ
dataset.py CHANGED
@@ -7,7 +7,7 @@ import torchaudio
 
 class VoiceDataset(Dataset):
 
-    def __init__(self, data_directory, transformation, target_sample_rate):
+    def __init__(self, data_directory, transformation, target_sample_rate, time_limit_in_secs=5):
         # file processing
         self._data_path = os.path.join(data_directory)
         self._labels = os.listdir(self._data_path)
@@ -17,6 +17,7 @@ class VoiceDataset(Dataset):
         # audio processing
         self.transformation = transformation
         self.target_sample_rate = target_sample_rate
+        self.num_samples = time_limit_in_secs * self.target_sample_rate
 
     def __len__(self):
         total_audio_files = 0
@@ -26,14 +27,20 @@ class VoiceDataset(Dataset):
         return total_audio_files
 
     def __getitem__(self, index):
+        # get file
         file, label = self.audio_files_labels[index]
         filepath = os.path.join(self._data_path, label, file)
 
+        # load wav
         wav, sr = torchaudio.load(filepath, normalize=True)
+
+        # modify wav file, if necessary
         wav = self._resample(wav, sr)
         wav = self._mix_down(wav)
+        wav = self._cut_or_pad(wav)
+
+        # apply transformation
         wav = self.transformation(wav)
-
         return wav, label
 
 
@@ -61,4 +68,21 @@ class VoiceDataset(Dataset):
         if wav.shape[0] > 1:
             wav = torch.mean(wav, dim=0, keepdim=True)
 
-        return wav
+        return wav
+
+    def _cut_or_pad(self, wav):
+        """Modify audio if number of samples != target number of samples of the dataset.
+
+        If there are too many samples, cut the audio.
+        If there are not enough samples, pad the audio with zeros.
+        """
+
+        length_signal = wav.shape[1]
+        if length_signal > self.num_samples:
+            wav = wav[:, :self.num_samples]
+        elif length_signal < self.num_samples:
+            num_of_missing_samples = self.num_samples - length_signal
+            pad = (0, num_of_missing_samples)
+            wav = torch.nn.functional.pad(wav, pad)
+
+        return wav
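The heart of the change is _cut_or_pad: every clip is forced to exactly time_limit_in_secs * target_sample_rate samples, so each item yields a fixed-size tensor regardless of the original recording length. A minimal standalone sketch of the same logic (not part of the commit; the 5 s / 16000 Hz figures mirror the defaults above):

import torch
import torch.nn.functional as F

NUM_SAMPLES = 5 * 16000  # time_limit_in_secs * target_sample_rate = 80000

def cut_or_pad(wav, num_samples=NUM_SAMPLES):
    # wav has shape (channels, samples), as produced by torchaudio.load
    length_signal = wav.shape[1]
    if length_signal > num_samples:
        wav = wav[:, :num_samples]        # cut: keep only the first num_samples
    elif length_signal < num_samples:
        missing = num_samples - length_signal
        wav = F.pad(wav, (0, missing))    # zero-pad on the right only
    return wav

assert cut_or_pad(torch.rand(1, 12_000)).shape == (1, NUM_SAMPLES)   # short clip padded
assert cut_or_pad(torch.rand(1, 100_000)).shape == (1, NUM_SAMPLES)  # long clip cut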
notebooks/playground.ipynb CHANGED
@@ -3,7 +3,7 @@
   {
    "cell_type": "code",
    "execution_count": 8,
-   "id": "…",
+   "id": "26db4cdb",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -14,7 +14,7 @@
   {
    "cell_type": "code",
    "execution_count": 10,
-   "id": "…",
+   "id": "c8244b70",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -25,7 +25,7 @@
   {
    "cell_type": "code",
    "execution_count": 18,
-   "id": "…",
+   "id": "f3fd2d28",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -35,7 +35,7 @@
   {
    "cell_type": "code",
    "execution_count": 14,
-   "id": "…",
+   "id": "da9fe647",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -45,7 +45,7 @@
   {
    "cell_type": "code",
    "execution_count": 15,
-   "id": "…",
+   "id": "70905d2d",
    "metadata": {},
    "outputs": [
     {
@@ -65,8 +65,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 64,
+   "id": "523d28f9",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -76,13 +76,13 @@
    "    hop_length=512,\n",
    "    n_mels=64\n",
    "    )\n",
-   "dataset = VoiceDataset('../data', mel_spectrogram, 16000)"
+   "dataset = VoiceDataset('../data', mel_spectrogram, 16000,)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 65,
+   "id": "0044724d",
    "metadata": {},
    "outputs": [
     {
@@ -91,7 +91,7 @@
       "5718"
      ]
     },
-    "execution_count": …,
+    "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -102,18 +102,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 66,
+   "id": "df7a9e58",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "(tensor([[…",
+       "(tensor([[[0.2647, 0.0247, 0.0324,  ..., 0.0000, 0.0000, 0.0000],\n",
+       "          [0.0812, 0.0178, 0.0890,  ..., 0.0000, 0.0000, 0.0000],\n",
+       "          [0.0052, 0.0212, 0.1341,  ..., 0.0000, 0.0000, 0.0000],\n",
+       "          ...,\n",
+       "          [0.5154, 0.3950, 0.4497,  ..., 0.0000, 0.0000, 0.0000],\n",
+       "          [0.1919, 0.4804, 0.5144,  ..., 0.0000, 0.0000, 0.0000],\n",
+       "          [0.1208, 0.4357, 0.4016,  ..., 0.0000, 0.0000, 0.0000]]]),\n",
        " 'aman')"
       ]
     },
-    "execution_count": …,
+    "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -122,10 +128,31 @@
    "dataset[0]"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": 67,
+  "id": "df064dbc",
+  "metadata": {},
+  "outputs": [
+   {
+    "data": {
+     "text/plain": [
+      "torch.Size([1, 64, 313])"
+     ]
+    },
+    "execution_count": 67,
+    "metadata": {},
+    "output_type": "execute_result"
+   }
+  ],
+  "source": [
+   "dataset[0][0].shape"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "…",
+  "id": "ed4899bf",
   "metadata": {},
   "outputs": [],
  "source": []
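Taken together, the notebook cells amount to this usage pattern (a sketch, not the notebook itself; n_fft=1024 is an assumption, since that argument is cut off in the diff, and only hop_length, n_mels, and the 16000 Hz rate are shown):

import torchaudio
from dataset import VoiceDataset

mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=1024,       # assumption: this argument is elided in the diff
    hop_length=512,
    n_mels=64,
)

dataset = VoiceDataset('../data', mel_spectrogram, 16000)

print(len(dataset))      # 5718 in the notebook run
wav, label = dataset[0]  # mel spectrogram tensor plus the folder name as label
print(wav.shape, label)  # e.g. torch.Size([1, 64, 313]) 'aman' in the notebook run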