amanmibra committed
Commit 47bf442
1 Parent(s): d6c7ff9

Add transformations to dataset (to spec)

__pycache__/dataset.cpython-39.pyc CHANGED
Binary files a/__pycache__/dataset.cpython-39.pyc and b/__pycache__/dataset.cpython-39.pyc differ
 
dataset.py CHANGED
@@ -1,16 +1,22 @@
 import os
 
+import torch
 from torch.utils.data import Dataset
 import pandas as pd
 import torchaudio
 
 class VoiceDataset(Dataset):
 
-    def __init__(self, data_directory):
+    def __init__(self, data_directory, transformation, target_sample_rate):
+        # file processing
         self._data_path = os.path.join(data_directory)
         self._labels = os.listdir(self._data_path)
 
-        self.audio_files, self.audio_labels = self._join_audio_files()
+        self.audio_files_labels = self._join_audio_files()
+
+        # audio processing
+        self.transformation = transformation
+        self.target_sample_rate = target_sample_rate
 
     def __len__(self):
         total_audio_files = 0
@@ -20,16 +26,39 @@ class VoiceDataset(Dataset):
         return total_audio_files
 
     def __getitem__(self, index):
-        return self.audio_files[index], self.audio_labels[index]
+        file, label = self.audio_files_labels[index]
+        filepath = os.path.join(self._data_path, label, file)
+
+        wav, sr = torchaudio.load(filepath, normalize=True)
+        wav = self._resample(wav, sr)
+        wav = self._mix_down(wav)
+        wav = self.transformation(wav)
+
+        return wav, label
+
 
     def _join_audio_files(self):
-        audio_files = []
-        audio_labels = []
+        """Join all the audio file names and labels into one single-dimensional array"""
+        audio_files_labels = []
 
         for label in self._labels:
             label_path = os.path.join(self._data_path, label)
             for f in os.listdir(label_path):
-                audio_files.append(f)
-                audio_labels.append(label)
+                audio_files_labels.append((f, label))
+
+        return audio_files_labels
+
+    def _resample(self, wav, current_sample_rate):
+        """Resample audio to the target sample rate, if necessary"""
+        if current_sample_rate != self.target_sample_rate:
+            resampler = torchaudio.transforms.Resample(current_sample_rate, self.target_sample_rate)
+            wav = resampler(wav)
+
+        return wav
+
+    def _mix_down(self, wav):
+        """Mix down audio to a single channel, if necessary"""
+        if wav.shape[0] > 1:
+            wav = torch.mean(wav, dim=0, keepdim=True)
 
-        return audio_files, audio_labels
+        return wav
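
For reference, the per-item pipeline added above reduces to three torch/torchaudio calls. Below is a minimal standalone sketch of the same steps on a single file; the path example.wav is a placeholder, and the 16 kHz / 64-mel settings are taken from the notebook cell further down, not from dataset.py itself.

import torch
import torchaudio

target_sample_rate = 16000  # matches the notebook cell below

# the transformation object that gets passed into VoiceDataset
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=target_sample_rate,
    n_fft=1024,
    hop_length=512,
    n_mels=64,
)

wav, sr = torchaudio.load("example.wav", normalize=True)  # placeholder path; shape (channels, samples)

# _resample: convert to the target sample rate only if needed
if sr != target_sample_rate:
    wav = torchaudio.transforms.Resample(sr, target_sample_rate)(wav)

# _mix_down: average multi-channel audio down to mono
if wav.shape[0] > 1:
    wav = torch.mean(wav, dim=0, keepdim=True)

# transformation: waveform -> mel spectrogram of shape (1, 64, time_frames)
mel = mel_spectrogram(wav)
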
notebooks/playground.ipynb CHANGED
@@ -3,7 +3,7 @@
   {
    "cell_type": "code",
    "execution_count": 8,
-   "id": "8b292047",
+   "id": "17f47516",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -14,7 +14,7 @@
   {
    "cell_type": "code",
    "execution_count": 10,
-   "id": "88db7a26",
+   "id": "3959c95c",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -25,7 +25,7 @@
   {
    "cell_type": "code",
    "execution_count": 18,
-   "id": "d4ac5e60",
+   "id": "53328491",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -35,7 +35,7 @@
   {
    "cell_type": "code",
    "execution_count": 14,
-   "id": "903c1d7d",
+   "id": "24923a03",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -45,7 +45,7 @@
   {
    "cell_type": "code",
    "execution_count": 15,
-   "id": "7dec6dd0",
+   "id": "08f1c4c3",
    "metadata": {},
    "outputs": [
     {
@@ -65,39 +65,67 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
-   "id": "1eea9cf8",
+   "execution_count": 46,
+   "id": "9554ab2c",
    "metadata": {},
    "outputs": [],
    "source": [
-    "dataset = VoiceDataset('../data')"
+    "mel_spectrogram = torchaudio.transforms.MelSpectrogram(\n",
+    "    sample_rate=16000,\n",
+    "    n_fft=1024,\n",
+    "    hop_length=512,\n",
+    "    n_mels=64\n",
+    "  )\n",
+    "dataset = VoiceDataset('../data', mel_spectrogram, 16000)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
-   "id": "cee3b661",
+   "execution_count": 47,
+   "id": "f1413af4",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "'../data'"
+       "5718"
       ]
      },
-     "execution_count": 20,
+     "execution_count": 47,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "dataset[1]"
+    "len(dataset)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "id": "e81b46ee",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([[ 0.0220,  0.0041, -0.0153,  ...,  0.0006, -0.0056, -0.0064]]),\n",
+       " 'aman')"
+      ]
+     },
+     "execution_count": 48,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dataset[0]"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d1a4615a",
+   "id": "48574640",
    "metadata": {},
    "outputs": [],
    "source": []