Daniel Rasmussen committed
Commit ea8b1b4 · Parent(s): ee0736e

Add AutoFeatureExtractor support

config.json CHANGED
@@ -6,7 +6,8 @@
   "auto_map": {
     "AutoConfig": "config.Config",
     "AutoModel": "model.Model",
-    "AutoModelForCTC": "model.Model"
+    "AutoModelForCTC": "model.Model",
+    "AutoFeatureExtractor": "feature_extraction.FeatureExtractor"
   },
   "input_features": 80,
   "vocab_size": 256,
feature_extraction.py ADDED
@@ -0,0 +1,381 @@
+"""Feature extraction for ASR model."""
+
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import SequenceFeatureExtractor
+from transformers.audio_utils import mel_filter_bank
+
+
+class FeatureExtractor(SequenceFeatureExtractor):
+    """Feature extractor for ASR model that extracts log-mel features from audio.
+
+    Parameters
+    ----------
+    window_size_ms : int
+        Window size in milliseconds for STFT, default 25.
+    window_stride_ms : int
+        Window stride in milliseconds for STFT, default 10.
+    mel_lower_edge_hertz : int
+        Lower edge of mel frequency range, default 0.
+    mel_upper_edge_hertz : int
+        Upper edge of mel frequency range, default 8000.
+    mel_num_bins : int
+        Number of mel filterbank features, default 80.
+    sample_rate : int
+        Sample rate of audio input, default 16000.
+    padding_value : float
+        Value to use for padding variable-length inputs, default 1000.0.
+    """
+
+    model_input_names = ["input_features"]
+
+    def __init__(
+        self,
+        window_size_ms: int = 25,
+        window_stride_ms: int = 10,
+        mel_lower_edge_hertz: int = 0,
+        mel_upper_edge_hertz: int = 8000,
+        mel_num_bins: int = 80,
+        sample_rate: int = 16000,
+        padding_value: float = 1000.0,
+        **kwargs,
+    ):
+        super().__init__(
+            feature_size=mel_num_bins,
+            sampling_rate=sample_rate,
+            padding_value=padding_value,
+            **kwargs,
+        )
+
+        self.window_size_ms = window_size_ms
+        self.window_stride_ms = window_stride_ms
+        self.mel_lower_edge_hertz = mel_lower_edge_hertz
+        self.mel_upper_edge_hertz = mel_upper_edge_hertz
+        self.mel_num_bins = mel_num_bins
+        self.log_epsilon = 1e-12
+
+        # Calculate window parameters
+        self.window_size_samples = int(
+            round(self.sampling_rate * self.window_size_ms / 1000.0)
+        )
+        self.window_stride_samples = int(
+            round(self.sampling_rate * self.window_stride_ms / 1000.0)
+        )
+        self.fft_len = self.window_size_samples
+
+        # Precompute mel filterbank matrix
+        self.mel_matrix = mel_filter_bank(
+            num_frequency_bins=self.fft_len // 2 + 1,
+            num_mel_filters=self.mel_num_bins,
+            min_frequency=self.mel_lower_edge_hertz,
+            max_frequency=self.mel_upper_edge_hertz,
+            sampling_rate=self.sampling_rate,
+        )
+
+        # Cache for device-specific mel matrix (avoids repeated conversions)
+        self._mel_matrix_cache = {}  # (device, dtype) -> torch.Tensor
+
+        # Default device for feature extraction
+        self._device = torch.device("cpu")
+
+    @property
+    def device(self):
+        """Get the device for feature extraction."""
+        return self._device
+
+    def to(self, device):
+        """Move feature extractor to a device.
+
+        Parameters
+        ----------
+        device : torch.device or str
+            Device to move to (e.g., 'cuda', 'cpu', torch.device('cuda:0'))
+
+        Returns
+        -------
+        self
+        """
+        self._device = torch.device(device)
+        return self
+
+    def cuda(self, device=None):
+        """Move feature extractor to CUDA device.
+
+        Parameters
+        ----------
+        device : int, optional
+            CUDA device index. If None, uses the default CUDA device.
+
+        Returns
+        -------
+        self
+        """
+        if device is None:
+            self._device = torch.device("cuda")
+        else:
+            self._device = torch.device(f"cuda:{device}")
+        return self
+
+    def cpu(self):
+        """Move feature extractor to CPU.
+
+        Returns
+        -------
+        self
+        """
+        self._device = torch.device("cpu")
+        return self
+
+    def to_dict(self):
+        """Serialize to dict, excluding non-serializable attributes."""
+        output = super().to_dict()
+        # Remove non-serializable attributes
+        output.pop("_device", None)
+        output.pop("_mel_matrix_cache", None)
+        return output
+
+    def __call__(
+        self,
+        raw_speech: Union[
+            np.ndarray,
+            torch.Tensor,
+            List[float],
+            List[np.ndarray],
+            List[torch.Tensor],
+            List[List[float]],
+        ],
+        sampling_rate: Optional[int] = None,
+        mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        **kwargs,
+    ):
+        """Extract log-mel features from raw audio.
+
+        Parameters
+        ----------
+        raw_speech : np.ndarray or torch.Tensor or List[float] or List[np.ndarray] or List[torch.Tensor] or List[List[float]]
+            The raw audio waveform(s) to extract features from. Can be:
+            - A single waveform as a 1D array/tensor
+            - A batch of waveforms as a 2D array/tensor
+            - A list of waveforms (can be variable length, mask auto-generated)
+        sampling_rate : int, optional
+            Sampling rate of the audio. If provided, must match the feature
+            extractor's sampling_rate.
+        mask : np.ndarray or torch.Tensor, optional
+            Mask for the input audio when input is array/tensor. Should have the same
+            shape as raw_speech. Values should be 1 for real audio and 0 for padding.
+            Not used when raw_speech is a list (mask is auto-generated in that case).
+
+        Returns
+        -------
+        torch.Tensor or dict
+            If no output mask is needed, returns the features tensor directly with
+            shape (batch, time, features). If an output mask is computed, returns a
+            dictionary containing:
+            - input_features: Extracted log-mel features of shape (batch, time, features)
+            - mask: Mask for the features of shape (batch, time)
+        """
+
+        # Validate sampling rate
+        if sampling_rate is not None and sampling_rate != self.sampling_rate:
+            raise ValueError(
+                f"The sampling_rate of the provided audio ({sampling_rate}) "
+                f"doesn't match the feature extractor's sampling_rate ({self.sampling_rate})"
+            )
+
+        input_mask = None
+
+        # Handle tensor/array inputs directly (no padding needed)
+        if isinstance(raw_speech, (torch.Tensor, np.ndarray)):
+            # Ensure input is 2D
+            if raw_speech.ndim == 1:
+                raw_speech = (
+                    raw_speech[np.newaxis, :]
+                    if isinstance(raw_speech, np.ndarray)
+                    else raw_speech.unsqueeze(0)
+                )
+                if mask is not None:
+                    mask = (
+                        mask[np.newaxis, :]
+                        if isinstance(mask, np.ndarray)
+                        else mask.unsqueeze(0)
+                    )
+            elif raw_speech.ndim != 2:
+                raise ValueError(f"Input must be 1D or 2D, got {raw_speech.ndim}D")
+
+            # Convert to torch
+            batched_speech = (
+                raw_speech
+                if isinstance(raw_speech, torch.Tensor)
+                else torch.from_numpy(raw_speech)
+            )
+            # Move to device
+            batched_speech = batched_speech.to(self._device)
+
+            if mask is not None:
+                input_mask = (
+                    mask if isinstance(mask, torch.Tensor) else torch.from_numpy(mask)
+                )
+                # Move to device
+                input_mask = input_mask.to(self._device)
+        else:
+            # Handle list inputs (may need padding)
+            if not isinstance(raw_speech, list):
+                raw_speech = [raw_speech]
+            # Treat a bare list of numbers (List[float]) as one waveform, not a batch
+            if raw_speech and not isinstance(raw_speech[0], (list, np.ndarray, torch.Tensor)):
+                raw_speech = [raw_speech]
+
+            # Convert to torch tensors and move to device
+            torch_speech = []
+            for speech in raw_speech:
+                if isinstance(speech, torch.Tensor):
+                    torch_speech.append(speech.float().to(self._device))
+                else:
+                    torch_speech.append(
+                        torch.from_numpy(np.asarray(speech, dtype=np.float32)).to(
+                            self._device
+                        )
+                    )
+
+            # Find max length and pad to it
+            max_length = max(len(speech) for speech in torch_speech)
+
+            # Pad all sequences to max_length and create mask
+            padded_speech = []
+            masks = []
+            for speech in torch_speech:
+                original_length = len(speech)
+                if original_length < max_length:
+                    padding = torch.full(
+                        (max_length - original_length,),
+                        self.padding_value,
+                        dtype=speech.dtype,
+                        device=self._device,
+                    )
+                    speech = torch.cat([speech, padding])
+
+                    # Create mask: 1 for real data, 0 for padding
+                    mask = torch.ones(max_length, dtype=torch.bool, device=self._device)
+                    mask[original_length:] = 0
+                else:
+                    mask = torch.ones(max_length, dtype=torch.bool, device=self._device)
+
+                padded_speech.append(speech)
+                masks.append(mask)
+
+            # Stack into batch
+            batched_speech = torch.stack(padded_speech, dim=0)
+            input_mask = torch.stack(masks, dim=0)
+
+        # Extract features
+        with torch.no_grad():
+            features = self._extract_features(batched_speech)
+
+        # Compute output mask if we have an input mask
+        output_mask = None
+        if input_mask is not None:
+            output_mask = self._compute_mask(input_mask)
+            # Set masked features to padding_value
+            # output_mask is (batch, time), features is (batch, time, features)
+            # Need to expand mask to broadcast: (batch, time, 1)
+            mask_expanded = output_mask.unsqueeze(-1)
+            features = torch.where(
+                mask_expanded,
+                features,
+                torch.tensor(
+                    self.padding_value, dtype=features.dtype, device=features.device
+                ),
+            )
+
+        # Return features directly if no mask, otherwise return dict
+        if output_mask is not None:
+            return {
+                "input_features": features,
+                "mask": output_mask,
+            }
+        else:
+            return features
+
+    def _extract_features(self, waveform: torch.Tensor) -> torch.Tensor:
+        """Extract log-mel features from waveform.
+
+        Parameters
+        ----------
+        waveform : torch.Tensor
+            Input waveform of shape (batch, time)
+
+        Returns
+        -------
+        torch.Tensor
+            Log mel spectrogram features of shape (batch, time_frames, mel_bins)
+        """
+        # Zero pad if there isn't enough data for at least one frame
+        if waveform.shape[1] < self.window_size_samples:
+            padding = self.window_size_samples - waveform.shape[1]
+            waveform = torch.nn.functional.pad(waveform, (0, padding))
+
+        # Compute spectrogram using STFT
+        spectrogram = torch.stft(
+            waveform,
+            n_fft=self.fft_len,
+            hop_length=self.window_stride_samples,
+            win_length=self.window_size_samples,
+            window=torch.hann_window(self.window_size_samples, device=waveform.device),
+            center=False,
+            return_complex=True,
+        )
+
+        # Take absolute value to get magnitude
+        spectrogram = torch.abs(spectrogram)
+
+        # Get mel matrix from cache or create it
+        device = spectrogram.device
+        dtype = spectrogram.dtype
+        cache_key = (device, dtype)
+
+        if cache_key not in self._mel_matrix_cache:
+            # Convert and cache the mel matrix for this device/dtype combination
+            self._mel_matrix_cache[cache_key] = torch.from_numpy(self.mel_matrix).to(
+                device=device, dtype=dtype
+            )
+
+        mel_matrix = self._mel_matrix_cache[cache_key]
+
+        # Apply mel filterbank: (batch, freq, time) @ (freq, mel) -> (batch, time, mel)
+        # Need to transpose spectrogram from (batch, freq, time) to (batch, time, freq)
+        spectrogram = spectrogram.transpose(1, 2)
+        mel_spectrogram = torch.matmul(spectrogram, mel_matrix)
+
+        # Compute log (with epsilon for stability)
+        log_mel_spectrogram = torch.log(
+            torch.clamp(mel_spectrogram, min=self.log_epsilon)
+        )
+
+        return log_mel_spectrogram
+
+    def _compute_mask(self, input_mask: torch.Tensor) -> torch.Tensor:
+        """Compute output mask for features based on input mask.
+
+        Parameters
+        ----------
+        input_mask : torch.Tensor
+            Input mask of shape (batch, time) with 1 for real data, 0 for padding
+
+        Returns
+        -------
+        torch.Tensor
+            Output mask of shape (batch, time_frames) where a frame is True only if
+            all samples in that frame were valid (not padded)
+        """
+        # Split mask into frames using unfold
+        # unfold(dimension, size, step)
+        mask_frames = input_mask.unfold(
+            1, self.window_size_samples, self.window_stride_samples
+        )
+
+        # A frame is valid only if ALL samples in that frame are valid
+        output_mask = torch.all(mask_frames, dim=-1)
+
+        return output_mask
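
For reference, a usage sketch of the new extractor under its default settings (the waveforms below are synthetic placeholders). At 16 kHz, the 25 ms window is 400 samples and the 10 ms stride is 160 samples, so a waveform of N samples yields 1 + (N - 400) // 160 frames of 80 log-mel bins:

    import numpy as np
    from feature_extraction import FeatureExtractor

    extractor = FeatureExtractor()  # defaults: 16 kHz, 25 ms window, 10 ms stride
    # extractor.to("cuda") would run extraction on GPU instead

    # Variable-length inputs are padded and masked automatically.
    audio = [
        np.random.randn(16000).astype(np.float32),  # 1.0 s
        np.random.randn(24000).astype(np.float32),  # 1.5 s
    ]
    out = extractor(audio, sampling_rate=16000)
    print(out["input_features"].shape)  # (2, 148, 80): 1 + (24000 - 400) // 160 frames
    print(out["mask"].shape)            # (2, 148); False where a frame overlaps padding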
preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
+{
+  "feature_extractor_type": "FeatureExtractor",
+  "processor_class": "FeatureExtractor",
+  "auto_map": {
+    "AutoFeatureExtractor": "feature_extraction.FeatureExtractor"
+  },
+  "window_size_ms": 25,
+  "window_stride_ms": 10,
+  "mel_lower_edge_hertz": 0,
+  "mel_upper_edge_hertz": 8000,
+  "mel_num_bins": 80,
+  "sample_rate": 16000,
+  "padding_value": 1000.0
+}
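
Each key in this file is passed back to FeatureExtractor.__init__ when the extractor is loaded, so the saved window and mel settings survive a round trip. A sketch, assuming an illustrative local checkout containing this file and feature_extraction.py:

    from transformers import AutoFeatureExtractor

    extractor = AutoFeatureExtractor.from_pretrained(
        "./asr-model", trust_remote_code=True  # "./asr-model" is an illustrative path
    )
    assert extractor.mel_num_bins == 80
    assert extractor.sampling_rate == 16000  # stored above as "sample_rate"
    extractor.save_pretrained("./asr-model-copy")  # rewrites preprocessor_config.json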
speech_features.py DELETED
@@ -1,125 +0,0 @@
-"""A layer for extracting features from speech data."""
-
-from typing import Iterable, Optional
-
-import keras
-import torch
-from keras import ops
-from numpy.typing import NDArray
-from transformers.audio_utils import mel_filter_bank
-
-
-class SpeechFeatures(keras.layers.Layer):
-    """
-    Computes MFCC features from audio signals.
-    """
-
-    def __init__(
-        self,
-        window_size_ms=25,
-        window_stride_ms=10,
-        mel_lower_edge_hertz=0,
-        mel_upper_edge_hertz=8000,
-        mel_num_bins=80,
-        sample_rate=16000,
-    ):
-        super().__init__()
-
-        self.window_size_ms = window_size_ms
-        self.window_stride_ms = window_stride_ms
-        self.mel_lower_edge_hertz = mel_lower_edge_hertz
-        self.mel_upper_edge_hertz = mel_upper_edge_hertz
-        self.mel_num_bins = mel_num_bins
-        self.sample_rate = sample_rate
-        self.log_epsilon = 1e-12
-
-        self.window_size_samples = int(
-            round(self.sample_rate * self.window_size_ms / 1000.0)
-        )
-        self.window_stride_samples = int(
-            round(self.sample_rate * self.window_stride_ms / 1000.0)
-        )
-
-        self.supports_masking = True
-        self.fft_len = self.window_size_samples
-
-    def build(self, input_shape: Iterable[int]) -> None:
-        # precompute the mel matrix
-        self.mel_matrix = mel_filter_bank(
-            num_frequency_bins=self.fft_len // 2 + 1,
-            num_mel_filters=self.mel_num_bins,
-            min_frequency=self.mel_lower_edge_hertz,
-            max_frequency=self.mel_upper_edge_hertz,
-            sampling_rate=self.sample_rate,
-        )
-
-    def call(self, inputs: NDArray) -> NDArray:
-        """Apply this layer to inputs."""
-
-        if len(inputs.shape) != 2:  # [Batch, Time]
-            raise ValueError(f"Input rank ({len(inputs.shape)}) must be 2")
-
-        # Zero pad if there isn't enough data for at least one frame (so we don't end up
-        # with size 0 axes)
-        inp = ops.pad(
-            inputs,
-            [
-                (0, 0),
-                (
-                    0,
-                    ops.maximum(self.window_size_samples - ops.shape(inputs)[1], 0),
-                ),
-            ],
-        )
-
-        # compute spectrogram
-        spectrogram = self.spectrogram(inp)
-
-        # compute mel spectrogram
-        outputs = self.log_mel(spectrogram)
-
-        return outputs
-
-    def spectrogram(self, inputs: NDArray) -> NDArray:
-        """Compute spectrogram from raw audio."""
-
-        spectrogram = ops.stft(
-            inputs,
-            self.window_size_samples,
-            self.window_stride_samples,
-            fft_length=self.fft_len,
-            center=False,
-        )
-
-        spectrogram = torch.complex(*spectrogram)
-
-        spectrogram = ops.abs(spectrogram)
-
-        return spectrogram
-
-    def log_mel(self, spectrogram: NDArray) -> NDArray:
-        """Transform spectrogram into (log) Mel scale."""
-
-        # multiply spectrogram by mel matrix
-        mel_spectrogram = ops.tensordot(spectrogram, self.mel_matrix, 1)
-
-        # compute log (with epsilon for stability)
-        log_mel_spectrogram = ops.log(ops.maximum(mel_spectrogram, self.log_epsilon))
-
-        return log_mel_spectrogram
-
-    def compute_mask(
-        self, inputs: NDArray, previous_mask: Optional[NDArray] = None
-    ) -> Optional[NDArray]:
-        if previous_mask is None:
-            return None
-
-        # split up mask into frames
-        mask = ops.extract_sequences(
-            previous_mask,
-            self.window_size_samples,
-            self.window_stride_samples,
-        )
-        # mask all the frames that had masked samples in them
-        mask = ops.all(mask, axis=-1)
-        return mask