alefiury committed
Commit e5d0d0a
1 Parent(s): bb96b52

Update README.md

Files changed (1)
  1. README.md +142 -55
README.md CHANGED
@@ -24,54 +24,107 @@ It achieves the following results on the evaluation set:
### Compute your inferences

```python
-class DataColletor:
+import os
+from typing import List, Optional, Union, Dict
+
+import tqdm
+import torch
+import torchaudio
+import numpy as np
+import pandas as pd
+from torch import nn
+from torch.utils.data import DataLoader
+from torch.nn import functional as F
+from transformers import (
+    AutoFeatureExtractor,
+    AutoModelForAudioClassification,
+    Wav2Vec2Processor
+)
+
+
+class CustomDataset(torch.utils.data.Dataset):
    def __init__(
        self,
-        processor: Wav2Vec2Processor,
+        dataset: List,
+        basedir: Optional[str] = None,
        sampling_rate: int = 16000,
-        padding: Union[bool, str] = True,
-        max_length: Optional[int] = None,
-        pad_to_multiple_of: Optional[int] = None,
-        label2id: Dict = None,
-        max_audio_len: int = 5
+        max_audio_len: int = 5,
    ):
-        self.processor = processor
+        self.dataset = dataset
+        self.basedir = basedir

        self.sampling_rate = sampling_rate
-        self.padding = padding
-        self.max_length = max_length
-        self.pad_to_multiple_of = pad_to_multiple_of
-        self.label2id = label2id
-        self.max_audio_len = max_audio_len
+        self.max_audio_len = max_audio_len

-    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-        # split inputs and labels since they have to be of different lenghts and need
-        # different padding methods
-        input_features = []
-        label_features = []
-        for feature in features:
-            speech_array, sampling_rate = torchaudio.load(feature["input_values"])
-
-            # Transform to Mono
+    def __len__(self):
+        """
+        Return the length of the dataset
+        """
+        return len(self.dataset)
+
+    def _cutorpad(self, audio: np.ndarray) -> np.ndarray:
+        """
+        Cut or pad audio to the wished length
+        """
+        effective_length = self.sampling_rate * self.max_audio_len
+        len_audio = len(audio)
+
+        # If audio length is bigger than wished audio length
+        if len_audio > effective_length:
+            audio = audio[:effective_length]
+
+        # Expand one dimension related to the channel dimension
+        return audio
+
+    def __getitem__(self, index) -> torch.Tensor:
+        """
+        Return the audio and the sampling rate
+        """
+        if self.basedir is None:
+            filepath = self.dataset[index]
+        else:
+            filepath = os.path.join(self.basedir, self.dataset[index])
+
+        speech_array, sr = torchaudio.load(filepath)
+
+        # Transform to mono
+        if speech_array.shape[0] > 1:
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)

-            if sampling_rate != self.sampling_rate:
-                transform = torchaudio.transforms.Resample(sampling_rate, self.sampling_rate)
-                speech_array = transform(speech_array)
-                sampling_rate = self.sampling_rate
-
-            effective_size_len = sampling_rate * self.max_audio_len
-
-            if speech_array.shape[-1] > effective_size_len:
-                speech_array = speech_array[:, :effective_size_len]
-
-            speech_array = speech_array.squeeze().numpy()
-            input_tensor = self.processor(speech_array, sampling_rate=sampling_rate).input_values
-            input_tensor = np.squeeze(input_tensor)
+        if sr != self.sampling_rate:
+            transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
+            speech_array = transform(speech_array)
+            sr = self.sampling_rate
+
+        speech_array = speech_array.squeeze().numpy()
+
+        # Cut or pad audio
+        speech_array = self._cutorpad(speech_array)
+
+        return speech_array
+
+class CollateFunc:
+    def __init__(
+        self,
+        processor: Wav2Vec2Processor,
+        max_length: Optional[int] = None,
+        padding: Union[bool, str] = True,
+        pad_to_multiple_of: Optional[int] = None,
+        sampling_rate: int = 16000,
+    ):
+        self.padding = padding
+        self.processor = processor
+        self.max_length = max_length
+        self.sampling_rate = sampling_rate
+        self.pad_to_multiple_of = pad_to_multiple_of
+
+    def __call__(self, batch: List):
+        input_features = []
+
+        for audio in batch:
+            input_tensor = self.processor(audio, sampling_rate=self.sampling_rate).input_values
+            input_tensor = np.squeeze(input_tensor)
            input_features.append({"input_values": input_tensor})

        batch = self.processor.pad(
@@ -85,6 +138,63 @@ class DataColletor:
        return batch


+def predict(test_dataloader, model, device: torch.device):
+    """
+    Predict the class of the audio
+    """
+    model.to(device)
+    model.eval()
+    preds = []
+
+    with torch.no_grad():
+        for batch in tqdm.tqdm(test_dataloader):
+            input_values, attention_mask = batch['input_values'].to(device), batch['attention_mask'].to(device)
+
+            logits = model(input_values, attention_mask=attention_mask).logits
+            scores = F.softmax(logits, dim=-1)
+
+            pred = torch.argmax(scores, dim=1).cpu().detach().numpy()
+
+            preds.extend(pred)
+
+    return preds
+
+
+def get_gender(model_name_or_path: str, audio_paths: List[str], label2id: Dict, id2label: Dict, device: torch.device):
+    num_labels = 2
+
+    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
+    model = AutoModelForAudioClassification.from_pretrained(
+        pretrained_model_name_or_path=model_name_or_path,
+        num_labels=num_labels,
+        label2id=label2id,
+        id2label=id2label,
+    )
+
+    test_dataset = CustomDataset(audio_paths)
+    data_collator = CollateFunc(
+        processor=feature_extractor,
+        padding=True,
+        sampling_rate=16000,
+    )
+
+    test_dataloader = DataLoader(
+        dataset=test_dataset,
+        batch_size=16,
+        collate_fn=data_collator,
+        shuffle=False,
+        num_workers=10
+    )
+
+    preds = predict(test_dataloader=test_dataloader, model=model, device=device)
+
+    return preds
+
+
+model_name_or_path = "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech"
+audio_paths = [] # Must be a list with absolute paths of the audios that will be used in inference
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
label2id = {
    "female": 0,
    "male": 1
@@ -97,30 +207,7 @@ id2label = {

num_labels = 2

-feature_extractor = AutoFeatureExtractor.from_pretrained("alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech")
-model = AutoModelForAudioClassification.from_pretrained(
-    pretrained_model_name_or_path="alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech",
-    num_labels=num_labels,
-    label2id=label2id,
-    id2label=id2label,
-)
-
-data_collator = DataColletorTrain(
-    feature_extractor,
-    sampling_rate=16000,
-    padding=True,
-    label2id=label2id
-)
-
-test_dataloader = DataLoader(
-    dataset=test_dataset,
-    batch_size=16,
-    collate_fn=data_collator,
-    shuffle=False,
-    num_workers=10
-)
-
-preds = predict(test_dataloader=test_dataloader, model=model)
+preds = get_gender(model_name_or_path, audio_paths, label2id, id2label, device)
```
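After this change, the README example exposes inference through a single `get_gender` call. A minimal usage sketch, assuming the `get_gender`, `CustomDataset`, and `CollateFunc` definitions from the updated example above are already in scope; the audio paths below are placeholders, not files shipped with the model:

```python
import torch

# Placeholder paths; replace with absolute paths to your own audio files
audio_paths = ["/data/audio/sample_01.wav", "/data/audio/sample_02.wav"]

# Label maps as given in the updated README
label2id = {"female": 0, "male": 1}
id2label = {0: "female", 1: "male"}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# get_gender builds the dataset, collator, and dataloader, then runs predict()
preds = get_gender(
    "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech",
    audio_paths,
    label2id,
    id2label,
    device,
)

# predict() returns integer class ids; map them back to label names
print([id2label[int(p)] for p in preds])
```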