lijialudew committed on
Commit
5490fe0
1 Parent(s): 6333913

Update README.md

Files changed (1)
  1. README.md +85 -343
README.md CHANGED
@@ -39,350 +39,92 @@ We develop our fine-tuning recipe using the SpeechBrain toolkit, available at
39
  ## Quick Start [optional]
40
 
41
  <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
42
- If you wish to use fairseq framework, the following code snippet can be used to load our pretrained model
43
-
44
- '''
45
- """This lobe enables the integration of fairseq pretrained wav2vec models.
46
-
47
- Reference: https://arxiv.org/abs/2006.11477
48
- Reference: https://arxiv.org/abs/1904.05862
49
- FairSeq >= 1.0.0 needs to be installed: https://fairseq.readthedocs.io/en/latest/
50
-
51
- Authors
52
- * Titouan Parcollet 2021
53
- * Salima Mdhaffar 2021
54
- """
55
-
56
- import torch
57
- import torch.nn.functional as F
58
- from torch import nn
59
- from speechbrain.utils.data_utils import download_file
60
- import pdb
61
-
62
- # We check if fairseq is installed.
63
- try:
64
- import fairseq
65
- except ImportError:
66
- MSG = "Please install Fairseq to use pretrained wav2vec\n"
67
- MSG += "E.G. run: pip install fairseq"
68
- raise ImportError(MSG)
69
-
70
-
71
- class FairseqWav2Vec2(nn.Module):
72
- """This lobe enables the integration of fairseq pretrained wav2vec2.0 models.
73
-
74
- Source paper: https://arxiv.org/abs/2006.11477
75
- FairSeq >= 1.0.0 needs to be installed:
76
- https://fairseq.readthedocs.io/en/latest/
77
-
78
- The model can be used as a fixed features extractor or can be finetuned. It
79
- will download automatically the model if a url is given (e.g FairSeq
80
- repository from GitHub).
81
-
82
- Arguments
83
- ---------
84
- pretrained_path : str
85
- Path of the pretrained wav2vec2 model. It can be a url or a local path.
86
- save_path : str
87
- Path and filename of the downloaded model.
88
- input_norm : bool (default: None)
89
- If True, a layer_norm (affine) will be applied to the input waveform.
90
- By default, it is extracted from the checkpoint of the downloaded model
91
- in order to match the pretraining conditions. However, if this information
92
- is not given in the checkpoint, it has to be given manually.
93
- output_norm : bool (default: True)
94
- If True, a layer_norm (affine) will be applied to the output obtained
95
- from the wav2vec model.
96
- freeze : bool (default: True)
97
- If True, the model is frozen. If False, the model will be trained
98
- alongside with the rest of the pipeline.
99
- pretrain : bool (default: True)
100
- If True, the model is pretrained with the specified source.
101
- If False, the randomly-initialized model is instantiated.
102
- dropout : float (default: None)
103
- If different from None (0.0 to 1.0), it will override the given fairseq
104
- dropout rates. This is useful if the wav2vec2 model has been trained
105
- without dropout and one wants to reactivate it for downstream task
106
- fine-tuning (better performance observed).
107
-
108
- Example
109
- -------
110
- >>> inputs = torch.rand([10, 600])
111
- >>> model_url = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt"
112
- >>> save_path = "models_checkpoints/wav2vec2.pt"
113
- >>> model = FairseqWav2Vec2(model_url, save_path)
114
- >>> outputs = model(inputs)
115
- >>> outputs.shape
116
- torch.Size([10, 100, 768])
117
- """
118
-
119
- def __init__(
120
- self,
121
- pretrained_path,
122
- save_path,
123
- input_norm=None,
124
- output_norm=True,
125
- freeze=True,
126
- pretrain=True,
127
- dropout=None,
128
- encoder_dropout = None,
129
- output_all_hiddens=False,
130
- tgt_layer=None,
131
- include_CNN_layer=True,
132
- ):
133
- super().__init__()
134
-
135
- # Download the pretrained wav2vec2 model. It can be local or online.
136
- download_file(pretrained_path, save_path)
137
-
138
- # During pretraining dropout might be set to 0. However, we might want
139
- # to apply dropout when fine-tuning on a downstream task. Hence we need
140
- # to modify the fairseq cfg to activate dropout (if requested).
141
- overrides={}
142
- if encoder_dropout is not None:
143
- overrides = {
144
- "model": {
145
- "encoder_layerdrop": encoder_dropout,
146
- }
147
- }
148
- if not freeze:
149
- if dropout is not None and encoder_dropout is not None:
150
- overrides = {
151
- "model": {
152
- "dropout": dropout,
153
- "encoder_layerdrop": encoder_dropout,
154
- "dropout_input": dropout,
155
- "attention_dropout": dropout,
156
- }
157
- }
158
- elif dropout is not None:
159
- overrides = {
160
- "model": {
161
- "dropout": dropout,
162
- "dropout_input": dropout,
163
- "attention_dropout": dropout,
164
- }
165
- }
166
- (
167
- model,
168
- cfg,
169
- task,
170
- ) = fairseq.checkpoint_utils.load_model_ensemble_and_task(
171
- [save_path], arg_overrides=overrides
172
- )
173
-
174
- # wav2vec pretrained models may need the input waveform to be normalized
175
- # Hence, we check if the model has be trained with or without it.
176
- # If the information isn't contained in the checkpoint IT HAS TO BE GIVEN
177
- # BY THE USER.
178
- if input_norm is None:
179
- if hasattr(cfg["task"], "normalize"):
180
- self.normalize = cfg["task"].normalize
181
- elif hasattr(cfg, "normalize"):
182
- self.normalize = cfg.normalize
183
- else:
184
- self.normalize = False
185
- else:
186
- self.normalize = input_norm
187
-
188
- model = model[0]
189
- self.model = model
190
- self.freeze = freeze
191
- self.output_norm = output_norm
192
-
193
- if self.freeze:
194
- self.model.eval()
195
- # Freeze parameters
196
- for param in model.parameters():
197
- param.requires_grad = False
198
- else:
199
- self.model.train()
200
- for param in model.parameters():
201
- param.requires_grad = True
202
-
203
- # Randomly initialized layers if pretrain is False
204
- if not (pretrain):
205
- self.reset_layer(self.model)
206
-
207
- # Following the fairseq implementation of downstream training,
208
- # we remove some modules that are unnecessary.
209
- self.remove_pretraining_modules()
210
- self.output_all_hiddens = output_all_hiddens
211
- self.tgt_layer = tgt_layer
212
- self.include_CNN_layer = include_CNN_layer
213
-
214
- def forward(self, wav):
215
- """Takes an input waveform and return its corresponding wav2vec encoding.
216
-
217
- Arguments
218
- ---------
219
- wav : torch.Tensor (signal)
220
- A batch of audio signals to transform to features.
221
- """
222
-
223
- # If we freeze, we simply remove all grads and features from the graph.
224
- if self.freeze:
225
- with torch.no_grad():
226
- return self.extract_features(wav).detach()
227
-
228
- return self.extract_features(wav)
229
-
230
- def extract_features(self, wav):
231
- """Extracts the wav2vect embeddings"""
232
- # We normalize the input signal if needed.
233
- if self.normalize:
234
- wav = F.layer_norm(wav, wav.shape)
235
-
236
- # Extract wav2vec output
237
- if self.tgt_layer=="CNN": #initial embeddings from conv
238
- out = self.model.extract_features(wav, padding_mask=None, mask=False)
239
- out = self.model.post_extract_proj(out['features'])
240
- elif isinstance(self.tgt_layer, int):
241
- out = self.model.extract_features(wav, padding_mask=None, mask=False, layer=self.tgt_layer)['x']
242
- else: #
243
- out = self.model.extract_features(wav, padding_mask=None, mask=False, layer=self.tgt_layer)
244
- if self.output_all_hiddens or isinstance(self.tgt_layer, list):
245
- out = self.aggregate_features(out, include_CNN_layer=self.include_CNN_layer) # 13, B, T, D
246
- if isinstance(self.tgt_layer, list):
247
- out = out[self.tgt_layer]
248
- else:
249
- out = out['x']
250
-
251
- # We normalize the output if required
252
- if self.output_norm:
253
- out = F.layer_norm(out, out.shape)
254
-
255
- return out
256
 
257
- def aggregate_features(self, out, include_CNN_layer=True):
258
- features = []
259
- if include_CNN_layer:
260
- features = [self.model.post_extract_proj(out['features'])]
261
- self.model.layerdrop = 0
262
- for i in range(len(out['layer_results'])):
263
- curr_feature = out['layer_results'][i][0].transpose(0,1)
264
- features.append(curr_feature)
265
- features = torch.stack(features)
266
- return features
267
-
268
-
269
- def reset_layer(self, model):
270
- """Reinitializes the parameters of the network"""
271
- if hasattr(model, "reset_parameters"):
272
- model.reset_parameters()
273
- for child_layer in model.children():
274
- if model != child_layer:
275
- self.reset_layer(child_layer)
276
-
277
- def remove_pretraining_modules(self):
278
- """ Remove uneeded modules. Inspired by the same fairseq function."""
279
-
280
- self.model.quantizer = None
281
- self.model.project_q = None
282
- self.model.target_glu = None
283
- self.model.final_proj = None
284
-
285
-
286
- class FairseqWav2Vec1(nn.Module):
287
- """This lobes enables the integration of fairseq pretrained wav2vec1.0 models.
288
-
289
- Arguments
290
- ---------
291
- pretrained_path : str
292
- Path of the pretrained wav2vec1 model. It can be a url or a local path.
293
- save_path : str
294
- Path and filename of the downloaded model.
295
- output_norm : bool (default: True)
296
- If True, a layer_norm (affine) will be applied to the output obtained
297
- from the wav2vec model.
298
- freeze : bool (default: True)
299
- If True, the model is frozen. If False, the model will be trained
300
- alongside with the rest of the pipeline.
301
- pretrain : bool (default: True)
302
- If True, the model is pretrained with the specified source.
303
- If False, the randomly-initialized model is instantiated.
304
-
305
- Example
306
- -------
307
- >>> inputs = torch.rand([10, 600])
308
- >>> model_url = ""
309
- >>> save_path = "models_checkpoints/wav2vec.pt"
310
- >>> model = FairseqWav2Vec1(model_url, save_path)
311
- >>> outputs = model(inputs)
312
- >>> outputs.shape
313
- torch.Size([10, 100, 512])
314
- """
315
-
316
- def __init__(
317
- self,
318
- pretrained_path,
319
- save_path,
320
- output_norm=True,
321
- freeze=True,
322
- pretrain=True,
323
- ):
324
- super().__init__()
325
- self.freeze = freeze
326
- self.output_norm = output_norm
327
-
328
- # Download the pretrained wav2vec1 model. It can be local or online.
329
- download_file(pretrained_path, save_path)
330
-
331
- (
332
- model,
333
- cfg,
334
- task,
335
- ) = fairseq.checkpoint_utils.load_model_ensemble_and_task(
336
- [pretrained_path]
337
- )
338
-
339
- self.model = model
340
- self.model = self.model[0]
341
- if self.freeze:
342
- model.eval()
343
-
344
- # Randomly initialized layers if pretrain is False
345
- if not (pretrain):
346
- self.reset_layer(self.model)
347
-
348
- def forward(self, wav):
349
- """Takes an input waveform and return its corresponding wav2vec encoding.
350
-
351
- Arguments
352
- ---------
353
- wav : torch.Tensor (signal)
354
- A batch of audio signals to transform to features.
355
- """
356
-
357
- # If we freeze, we simply remove all grads and features from the graph.
358
- if self.freeze:
359
- with torch.no_grad():
360
- return self.extract_features(wav).detach()
361
-
362
- return self.extract_features(wav)
363
-
364
- def extract_features(self, wav):
365
- """Extracts the wav2vect embeddings"""
366
-
367
- out = self.model.feature_extractor(wav)
368
- out = self.model.feature_aggregator(out).squeeze(0)
369
- out = out.transpose(2, 1)
370
-
371
- # We normalize the output if required
372
- if self.output_norm:
373
- out = F.layer_norm(out, out.shape)
374
-
375
- return out
376
-
377
- def reset_layer(self, model):
378
- """Reinitializes the parameters of the network"""
379
- if hasattr(model, "reset_parameters"):
380
- model.reset_parameters()
381
- for child_layer in model.children():
382
- if model != child_layer:
383
- self.reset_layer(child_layer)
384
- '''
385
-
386
  # Evaluation
387
 
388
  <!-- This section describes the evaluation protocols and provides the results. -->
 
39
  ## Quick Start [optional]
40
 
41
  <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
42
+ If you wish to use the fairseq framework, the following code snippet can be used to load our pretrained model.
43
+
44
+ <pre><code>
45
+ import torch
46
+ import torch.nn.functional as F
47
+ from torch import nn
48
+ import fairseq
49
+ import torchaudio
50
+
51
+ def load_model(model_path, freeze=True):
52
+ '''
53
+ This function loads a pretrained model using the fairseq framework.
54
+ Arguments
55
+ ---------
56
+ model_path : str
57
+ Path and filename of the pretrained model
58
+ freeze : bool (default: True)
59
+ If True, the model is frozen with no parameter updates through training.
60
+ '''
61
+
62
+ model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_path])
63
+ model = model[0]
64
+
65
+ if freeze:
66
+ model.eval()
67
+ # Freeze parameters
68
+ for param in model.parameters():
69
+ param.requires_grad = False
70
+ else:
71
+ model.train()
72
+ for param in model.parameters():
73
+ param.requires_grad = True
74
+
75
+ # Remove pretraining components that are not needed for feature extraction or fine-tuning
76
+ model.quantizer = None
77
+ model.project_q = None
78
+ model.target_glu = None
79
+ model.final_proj = None
80
+
81
+ return model
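+
+ # A hypothetical usage sketch: by default the checkpoint is loaded frozen for feature extraction;
+ # pass freeze=False if the model should be fine-tuned together with the rest of your pipeline.
+ # w2v2 = load_model("your/path/to/LL_4300/checkpoint_best.pt", freeze=False)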
82
 
83
+ def extract_features(model, wav, input_norm=None, output_norm=True, tgt_layer=None, output_all_hiddens=False):
84
+ '''
85
+ This function extracts features from a wav2vec 2.0 model. By default, it returns the features of the
86
+ last transformer layer. It can also return features from a specific layer, or from all layers.
87
+ Arguments
88
+ ---------
89
+ model : fairseq wav2vec model
90
+ wav : torch.Tensor
91
+ A batch of audio waveforms [batch, time] to extract features from.
92
+ input_norm : bool (default: None)
93
+ If True, a layer_norm (affine) will be applied to the input waveform.
94
+ output_norm : bool (default: True)
95
+ If True, a layer_norm (affine) will be applied to the output obtained
96
+ from the wav2vec model.
97
+ tgt_layer : int (default: None)
98
+ Target transformer layer features, 0-indexed.
99
+ output_all_hiddens : bool (default: False)
100
+ If True, extract features from all transformer layers; tgt_layer must be None.
101
+ '''
102
+
103
+ if input_norm:
104
+ wav = F.layer_norm(wav, wav.shape)
105
+
106
+ # Extract the wav2vec output; keep the full output dict so that layer_results stays accessible
107
+ out = model.extract_features(wav, padding_mask=None, mask=False)
108
+ if isinstance(tgt_layer, int):
109
+ out = model.extract_features(wav, padding_mask=None, mask=False, layer=tgt_layer)['x']
110
+ elif output_all_hiddens:
111
+ features = []
112
+ model.layerdrop = 0
113
+ for i in range(len(out['layer_results'])):
114
+ curr_feature = out['layer_results'][i][0].transpose(0,1)  # [time, batch, dim] -> [batch, time, dim]
115
+ features.append(curr_feature)
116
+ out = torch.stack(features)  # [num_layers, batch, time, dim]
+ else:
+ out = out['x']  # last transformer layer output
117
+
118
+ if output_norm:
119
+ out = F.layer_norm(out, out.shape)
120
+ return out
121
+
122
+ model = load_model("your/path/to/LL_4300/checkpoint_best.pt")
123
+ audio, fs = torchaudio.load("sample.wav")  # [channels, time]
124
+ audio = audio.mean(dim=0, keepdim=True)  # downmix to mono; the model expects a batch of waveforms [batch, time]
125
+ features = extract_features(model, audio)
126
+ </code></pre>
127
+
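+ The helper above can also return a single intermediate transformer layer or all hidden states, which is useful for layer-wise analysis. A minimal sketch reusing the model and audio loaded above (the layer index is only an example):
+
+ <pre><code>
+ # Features from one specific transformer layer (0-indexed)
+ layer_features = extract_features(model, audio, tgt_layer=6)
+
+ # Stacked features from all transformer layers: [num_layers, batch, time, dim]
+ all_features = extract_features(model, audio, tgt_layer=None, output_all_hiddens=True)
+ </code></pre>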
128
  # Evaluation
129
 
130
  <!-- This section describes the evaluation protocols and provides the results. -->