amanmibra committed on
Commit
40f7298
1 Parent(s): b823f0d

Update voice dataset to process wavs at init

Browse files
Files changed (1) hide show
  1. dataset.py +24 -17
dataset.py CHANGED
@@ -28,28 +28,35 @@ class VoiceDataset(Dataset):
28
  self.target_sample_rate = target_sample_rate
29
  self.num_samples = time_limit_in_secs * self.target_sample_rate
30
 
 
 
 
31
  def __len__(self):
32
  return len(self.audio_files_labels)
33
 
34
  def __getitem__(self, index):
35
- # get file
36
- file, label = self.audio_files_labels[index]
37
- filepath = os.path.join(self._data_path, label, file)
38
-
39
- # load wav
40
- wav, sr = torchaudio.load(filepath, normalize=True)
41
-
42
- # modify wav file, if necessary
43
- wav = wav.to(self.device)
44
- wav = self._resample(wav, sr)
45
- wav = self._mix_down(wav)
46
- wav = self._cut_or_pad(wav)
 
 
 
 
 
 
 
 
47
 
48
- # apply transformation
49
- wav = self.transformation(wav)
50
-
51
- # return wav and integer representation of the label
52
- return wav, self.label_mapping[label]
53
 
54
 
55
  def _join_audio_files(self):
 
28
  self.target_sample_rate = target_sample_rate
29
  self.num_samples = time_limit_in_secs * self.target_sample_rate
30
 
31
+ # preprocess all wavs
32
+ self.wavs = self._process_wavs()
33
+
34
  def __len__(self):
35
  return len(self.audio_files_labels)
36
 
37
  def __getitem__(self, index):
38
+ return self.wavs[index]
39
+
40
+ def _process_wavs(self):
41
+ wavs = []
42
+ for file, label in self.audio_files_labels:
43
+ filepath = os.path.join(self._data_path, label, file)
44
+
45
+ # load wav
46
+ wav, sr = torchaudio.load(filepath, normalize=True)
47
+
48
+ # modify wav file, if necessary
49
+ wav = wav.to(self.device)
50
+ wav = self._resample(wav, sr)
51
+ wav = self._mix_down(wav)
52
+ wav = self._cut_or_pad(wav)
53
+
54
+ # apply transformation
55
+ wav = self.transformation(wav)
56
+
57
+ wavs.append((wav, self.label_mapping[label]))
58
 
59
+ return wavs
 
 
 
 
60
 
61
 
62
  def _join_audio_files(self):