lijialudew committed on
Commit
bc98ab0
1 Parent(s): 449ceda

Upload fairseq_wav2vec.py

Files changed (1)
  1. fairseq_wav2vec.py +294 -0
fairseq_wav2vec.py ADDED
@@ -0,0 +1,294 @@
"""
This snippet is adapted from the original SpeechBrain codebase.
This lobe enables the integration of fairseq pretrained wav2vec models.

Reference: https://arxiv.org/abs/2006.11477
Reference: https://arxiv.org/abs/1904.05862
FairSeq >= 1.0.0 needs to be installed: https://fairseq.readthedocs.io/en/latest/

Original Authors
 * Titouan Parcollet 2021
 * Salima Mdhaffar 2021

Modified by
 * Jialu Li 2023
"""

import logging
import torch
import torch.nn.functional as F
from torch import nn
from speechbrain.utils.data_utils import download_file

# We check if fairseq is installed.
try:
    import fairseq
except ImportError:
    MSG = "Please install Fairseq to use pretrained wav2vec\n"
    MSG += "E.g. run: pip install fairseq"
    raise ImportError(MSG)
logger = logging.getLogger(__name__)


class FairseqWav2Vec2(nn.Module):
    """This lobe enables the integration of fairseq pretrained wav2vec2.0 models.

    Source paper: https://arxiv.org/abs/2006.11477
    FairSeq >= 1.0.0 needs to be installed:
    https://fairseq.readthedocs.io/en/latest/

    The model can be used as a fixed feature extractor or can be finetuned. It
    will automatically download the model if a URL is given (e.g. the FairSeq
    repository on GitHub).

    Arguments
    ---------
    save_path : str
        Path and filename of the downloaded model.
    input_norm : bool (default: None)
        If True, a layer_norm (affine) will be applied to the input waveform.
        By default, it is extracted from the checkpoint of the downloaded model
        in order to match the pretraining conditions. However, if this information
        is not given in the checkpoint, it has to be given manually.
    output_norm : bool (default: True)
        If True, a layer_norm (affine) will be applied to the output obtained
        from the wav2vec model.
    freeze : bool (default: True)
        If True, the model is frozen. If False, the model will be trained
        alongside the rest of the pipeline.
    pretrain : bool (default: True)
        If True, the model is initialized from the pretrained checkpoint given
        as the source. If False, a randomly-initialized model is instantiated.
    dropout : float (default: None)
        If different from None (0.0 to 1.0), it will override the fairseq
        dropout rates. This is useful if the wav2vec2 model has been trained
        without dropout and one wants to reactivate it for downstream-task
        fine-tuning (better performance has been observed).
    encoder_dropout : float (default: None)
        If different from None (0.0 to 1.0), it will override the fairseq
        encoder_layerdrop rate, i.e. the probability with which each transformer
        layer is dropped during training (LayerDrop).
    output_all_hiddens : bool (default: True)
        If True, output the features from all 12 transformer layers.
        If False, output the features from only the last transformer layer.
    tgt_layer : int or list of int (default: None)
        If not None, output the features of the front-end CNN or of the
        specified transformer layer(s)
        (0-indexed: 0 is the CNN front-end layer, 1-12 are the transformer layers).
        For extracting front-end CNN features, specify it as 0.
        For a single layer, specify it as an int.
        For multiple layers, specify it as a list of int.
    include_CNN_layer : bool (default: False)
        Only used when output_all_hiddens is True.
        If True, output the features from the front-end CNN layer as well as
        all 12 transformer layers.
    """

    def __init__(
        self,
        save_path,
        input_norm=None,
        output_norm=True,
        freeze=True,
        pretrain=True,
        dropout=None,
        encoder_dropout=None,
        output_all_hiddens=True,
        tgt_layer=None,
        include_CNN_layer=False,
    ):
        super().__init__()

        # During pretraining, dropout might be set to 0. However, we might want
        # to apply dropout when fine-tuning on a downstream task. Hence we need
        # to modify the fairseq cfg to activate dropout (if requested).
        overrides = {}
        if encoder_dropout is not None:
            overrides = {
                "model": {
                    "encoder_layerdrop": encoder_dropout,
                }
            }
        if not freeze:
            if dropout is not None and encoder_dropout is not None:
                overrides = {
                    "model": {
                        "dropout": dropout,
                        "encoder_layerdrop": encoder_dropout,
                        "dropout_input": dropout,
                        "attention_dropout": dropout,
                    }
                }
            elif dropout is not None:
                overrides = {
                    "model": {
                        "dropout": dropout,
                        "dropout_input": dropout,
                        "attention_dropout": dropout,
                    }
                }
        (
            model,
            cfg,
            task,
        ) = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [save_path], arg_overrides=overrides
        )

        # wav2vec pretrained models may need the input waveform to be normalized.
        # Hence, we check whether the model has been trained with or without it.
        # If the information isn't contained in the checkpoint, IT HAS TO BE GIVEN
        # BY THE USER.
        if input_norm is None:
            if hasattr(cfg["task"], "normalize"):
                self.normalize = cfg["task"].normalize
            elif hasattr(cfg, "normalize"):
                self.normalize = cfg.normalize
            else:
                self.normalize = False
        else:
            self.normalize = input_norm

        model = model[0]
        self.model = model
        self.freeze = freeze
        self.output_norm = output_norm

        if self.freeze:
            self.model.eval()
            # Freeze parameters
            for param in model.parameters():
                param.requires_grad = False
        else:
            self.model.train()
            for param in model.parameters():
                param.requires_grad = True

        # Randomly re-initialize the layers if pretrain is False
        if not pretrain:
            self.reset_layer(self.model)

        # Following the fairseq implementation of downstream training,
        # we remove some modules that are unnecessary.
        self.remove_pretraining_modules()
        self.output_all_hiddens = output_all_hiddens
        self.tgt_layer = tgt_layer
        self.include_CNN_layer = include_CNN_layer
        if not self.output_all_hiddens:
            logger.info(
                "include_CNN_layer is not used when output_all_hiddens is False"
            )
        if self.output_all_hiddens:
            self.tgt_layer = None
            logger.warning(
                "tgt_layer is set to None when output_all_hiddens is True"
            )

    def forward(self, wav):
        """Takes an input waveform of shape (Batch, Time) and returns its
        corresponding wav2vec encoding.

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of audio signals to transform to features.
        """

        # If the model is frozen, we simply detach all gradients and features
        # from the graph.
        if self.freeze:
            with torch.no_grad():
                return self.extract_features(wav).detach()

        return self.extract_features(wav)

    def extract_features(self, wav):
        """Extracts the wav2vec embeddings.

        wav : torch.Tensor
            A batch of audio signals of shape (Batch, Time).

        Returns features of shape (# of layers, Batch, Time, Dimension) when
        several layers are requested, or (Batch, Time, Dimension) for a single
        layer.
        """
        # We normalize the input signal if needed.
        if self.normalize:
            wav = F.layer_norm(wav, wav.shape)

        out = self.model.extract_features(wav, padding_mask=None, mask=False)
        # Extract the wav2vec output
        if isinstance(self.tgt_layer, int):
            features = out['layer_results'][self.tgt_layer][0].transpose(0, 1)
        elif isinstance(self.tgt_layer, list):
            features = []
            for i in self.tgt_layer:
                curr_feature = out['layer_results'][i][0].transpose(0, 1)
                features.append(curr_feature)
            features = torch.stack(features)
        elif self.output_all_hiddens:
            features = self.aggregate_features(
                out, include_CNN_layer=self.include_CNN_layer
            )  # 13, B, T, D
        else:  # output last layer only
            features = out['x']

        out = features
        # We normalize the output if required
        if self.output_norm:
            out = F.layer_norm(out, out.shape)

        return out

    def aggregate_features(self, out, include_CNN_layer=True):
        """Stacks the hidden states of all layers into a tensor of shape
        (# of layers, Batch, Time, Dimension), optionally dropping the
        front-end CNN features."""
        features = []
        self.model.layerdrop = 0
        for i in range(len(out['layer_results'])):
            curr_feature = out['layer_results'][i][0].transpose(0, 1)
            features.append(curr_feature)
        features = torch.stack(features)
        if not include_CNN_layer:
            features = features[1:]
        return features

    def reset_layer(self, model):
        """Reinitializes the parameters of the network."""
        if hasattr(model, "reset_parameters"):
            model.reset_parameters()
        for child_layer in model.children():
            if model != child_layer:
                self.reset_layer(child_layer)

    def _load_sb_pretrained_w2v2_parameters(self, path):
        """Loads the parameters of a w2v2 model pretrained with SpeechBrain and
        the HuggingFaceWav2Vec2Pretrain object. A custom loading is necessary
        because HuggingFace adds a level to the checkpoint when storing the
        model, breaking the compatibility between HuggingFaceWav2Vec2Pretrain
        and HuggingFaceWav2Vec2.

        In practice, a typical HuggingFaceWav2Vec2 checkpoint for a given
        parameter would be: model.conv.weight.data while for
        HuggingFaceWav2Vec2Pretrain it is: model.wav2vec2.weight.data
        (wav2vec2 must be removed before loading).
        """
        modified_state_dict = {}
        orig_state_dict = torch.load(path, map_location="cpu")

        # We strip the leading "model." prefix from the state dict keys.
        for key, params in orig_state_dict.items():
            if "model." in key:
                save_key = key.replace("model.", "")
                modified_state_dict[save_key] = params

        incompatible_keys = self.model.load_state_dict(
            modified_state_dict, strict=False
        )

        for missing_key in incompatible_keys.missing_keys:
            logger.warning(
                f"During parameter transfer to {self.model} loading from "
                + f"{path}, the transferred parameters did not have "
                + f"parameters for the key: {missing_key}"
            )

        for unexpected_key in incompatible_keys.unexpected_keys:
            logger.warning(
                f"The param with the key: {unexpected_key} is discarded as it "
                + "is useless for wav2vec 2.0 finetuning."
            )

    def remove_pretraining_modules(self):
        """Removes unneeded modules. Inspired by the same fairseq function."""

        self.model.quantizer = None
        self.model.project_q = None
        self.model.target_glu = None
        self.model.final_proj = None
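
For reference, a minimal usage sketch is given below. It is not part of the uploaded file: the checkpoint path "wav2vec_small.pt" and the dummy 16 kHz waveform are illustrative assumptions, while the class name, arguments, and output layout follow the code above.

# Minimal usage sketch (assumptions: "wav2vec_small.pt" is a local fairseq
# wav2vec 2.0 base checkpoint and the input audio is sampled at 16 kHz).
import torch
from fairseq_wav2vec import FairseqWav2Vec2

# Frozen feature extractor returning the hidden states of all layers,
# stacked as (# of layers, Batch, Time, Dimension).
w2v2 = FairseqWav2Vec2(
    save_path="wav2vec_small.pt",
    freeze=True,
    output_all_hiddens=True,
    include_CNN_layer=True,
)
wav = torch.randn(4, 16000)  # (Batch, Time): four 1-second signals at 16 kHz
feats = w2v2(wav)
print(feats.shape)  # e.g. (13, 4, 49, 768) for a base model with the CNN layer

# Alternatively, extract a single layer (0 is the CNN front-end,
# 1-12 are the transformer layers).
w2v2_layer7 = FairseqWav2Vec2(
    save_path="wav2vec_small.pt",
    freeze=True,
    output_all_hiddens=False,
    tgt_layer=7,
)
feats_7 = w2v2_layer7(wav)  # (Batch, Time, Dimension)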