herwoww committed on
Commit
1547a56
1 Parent(s): 6ddbec2

first upload

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +2 -1
  2. app.py +59 -0
  3. artst/__init__.py +1 -0
  4. artst/__pycache__/__init__.cpython-38.pyc +0 -0
  5. artst/__pycache__/sequence_generator.cpython-38.pyc +0 -0
  6. artst/criterions/__init__.py +10 -0
  7. artst/criterions/__pycache__/__init__.cpython-38.pyc +0 -0
  8. artst/criterions/__pycache__/artst_criterion.cpython-38.pyc +0 -0
  9. artst/criterions/__pycache__/speech_pretrain_criterion.cpython-38.pyc +0 -0
  10. artst/criterions/__pycache__/speech_to_text_loss.cpython-38.pyc +0 -0
  11. artst/criterions/__pycache__/speecht5_criterion.cpython-38.pyc +0 -0
  12. artst/criterions/__pycache__/text_pretrain_criterion.cpython-38.pyc +0 -0
  13. artst/criterions/__pycache__/text_to_speech_loss.cpython-38.pyc +0 -0
  14. artst/criterions/artst_criterion.py +443 -0
  15. artst/criterions/speech_pretrain_criterion.py +265 -0
  16. artst/criterions/speech_to_text_loss.py +473 -0
  17. artst/criterions/text_pretrain_criterion.py +142 -0
  18. artst/criterions/text_to_speech_loss.py +425 -0
  19. artst/data/__init__.py +0 -0
  20. artst/data/__pycache__/__init__.cpython-38.pyc +0 -0
  21. artst/data/__pycache__/multitask_dataset.cpython-38.pyc +0 -0
  22. artst/data/__pycache__/speech_dataset.cpython-38.pyc +0 -0
  23. artst/data/__pycache__/speech_to_class_dataset.cpython-38.pyc +0 -0
  24. artst/data/__pycache__/speech_to_speech_dataset.cpython-38.pyc +0 -0
  25. artst/data/__pycache__/speech_to_text_dataset.cpython-38.pyc +0 -0
  26. artst/data/__pycache__/text_dataset.cpython-38.pyc +0 -0
  27. artst/data/__pycache__/text_to_speech_dataset.cpython-38.pyc +0 -0
  28. artst/data/multitask_dataset.py +263 -0
  29. artst/data/speech_dataset.py +475 -0
  30. artst/data/speech_to_class_dataset.py +260 -0
  31. artst/data/speech_to_speech_dataset.py +280 -0
  32. artst/data/speech_to_text_dataset.py +298 -0
  33. artst/data/text_dataset.py +474 -0
  34. artst/data/text_to_speech_dataset.py +344 -0
  35. artst/models/__init__.py +2 -0
  36. artst/models/__pycache__/__init__.cpython-38.pyc +0 -0
  37. artst/models/__pycache__/artst.cpython-38.pyc +0 -0
  38. artst/models/__pycache__/speecht5.cpython-38.pyc +0 -0
  39. artst/models/__pycache__/t5_transformer_lm.cpython-38.pyc +0 -0
  40. artst/models/artst.py +1448 -0
  41. artst/models/modules/__init__.py +0 -0
  42. artst/models/modules/__pycache__/__init__.cpython-38.pyc +0 -0
  43. artst/models/modules/__pycache__/decoder.cpython-38.pyc +0 -0
  44. artst/models/modules/__pycache__/encoder.cpython-38.pyc +0 -0
  45. artst/models/modules/__pycache__/multihead_attention.cpython-38.pyc +0 -0
  46. artst/models/modules/__pycache__/speaker_decoder_postnet.cpython-38.pyc +0 -0
  47. artst/models/modules/__pycache__/speech_decoder_postnet.cpython-38.pyc +0 -0
  48. artst/models/modules/__pycache__/speech_decoder_prenet.cpython-38.pyc +0 -0
  49. artst/models/modules/__pycache__/speech_encoder_postnet.cpython-38.pyc +0 -0
  50. artst/models/modules/__pycache__/speech_encoder_prenet.cpython-38.pyc +0 -0
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Artst
+ title: ArtstTTS
  emoji: 🔥
  colorFrom: yellow
  colorTo: gray
@@ -7,6 +7,7 @@ sdk: gradio
  sdk_version: 4.7.1
  app_file: app.py
  pinned: false
+ python_version: 3.8.2
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,59 @@
+ import os
+ import torch
+ import gradio as gr
+ import os.path as op
+ import pyarabic.araby as araby
+
+ from artst.tasks.artst import ArTSTTask
+ from transformers import SpeechT5HifiGan
+ from artst.models.artst import ArTSTTransformerModel
+ from fairseq.tasks.hubert_pretraining import LabelEncoder
+ from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ WORK_DIR = os.getcwd()
+ checkpoint = torch.load('ckpts/clartts_tts.pt')
+ checkpoint['cfg']['task'].t5_task = 't2s'
+ task = ArTSTTask.setup_task(checkpoint['cfg']['task'])
+
+ emb_path = 'embs/clartts.npy'
+ model = ArTSTTransformerModel.build_model(checkpoint['cfg']['model'], task)
+ model.load_state_dict(checkpoint['model'])
+
+ checkpoint['cfg']['task'].bpe_tokenizer = task.build_bpe(checkpoint['cfg']['model'])
+ tokenizer = checkpoint['cfg']['task'].bpe_tokenizer
+
+ processor = LabelEncoder(task.dicts['text'])
+
+ vocoder = SpeechT5HifiGan.from_pretrained('microsoft/speecht5_hifigan').to(device)
+
+ def get_embs(emb_path):
+     spkembs = get_features_or_waveform(emb_path)
+     spkembs = torch.from_numpy(spkembs).float().unsqueeze(0)
+     return spkembs
+
+ def process_text(text):
+     text = araby.strip_diacritics(text)
+     return processor(tokenizer.encode(text)).reshape(1, -1)
+
+ net_input = {}
+
+ def inference(text, spkr=emb_path):
+     net_input['src_tokens'] = process_text(text)
+     net_input['spkembs'] = get_embs(spkr)
+     outs, _, attn = task.generate_speech(
+         [model],
+         net_input,
+     )
+     with torch.no_grad():
+         gen_audio = vocoder(outs.to(device))
+     return (16000, gen_audio.cpu().numpy())
+
+ text_box = gr.Textbox(max_lines=2, label="Arabic Text")
+ out = gr.Audio(label="Synthesized Audio", type="numpy")
+ demo = gr.Interface(inference, \
+     inputs=text_box, outputs=out, title="ArTST")
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
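The app above wires ArTST text-to-speech into a Gradio Interface: it loads the CLArTTS fine-tuned checkpoint and a precomputed speaker embedding, strips diacritics from the input text, runs task.generate_speech, and converts the predicted spectrogram to a waveform with the SpeechT5 HiFi-GAN vocoder. The returned (16000, numpy_array) tuple is exactly the format gr.Audio(type="numpy") renders. As a rough sketch (not part of this commit), the same inference function could be exercised outside the UI, assuming the checkpoint and embedding paths above exist locally and that soundfile is installed; the output file name and sample text below are illustrative only.

# Hypothetical usage sketch, not part of this commit.
import soundfile as sf            # assumption: soundfile is installed
from app import inference        # importing app runs the checkpoint loading shown above

sample_rate, waveform = inference("السلام عليكم")   # -> (16000, numpy waveform)
sf.write("clartts_sample.wav", waveform.squeeze(), sample_rate)
print(f"wrote {waveform.squeeze().shape[0] / sample_rate:.2f} s of audio")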
artst/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import data, tasks, criterions, models  # noqa
artst/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (218 Bytes).
 
artst/__pycache__/sequence_generator.cpython-38.pyc ADDED
Binary file (26.2 kB).
 
artst/criterions/__init__.py ADDED
@@ -0,0 +1,10 @@
+ import importlib
+ import os
+
+
+ for file in os.listdir(os.path.dirname(__file__)):
+     if file.endswith(".py") and not file.startswith("_"):
+         criterion_name = file[: file.find(".py")]
+         importlib.import_module(
+             "artst.criterions." + criterion_name
+         )
artst/criterions/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (395 Bytes).
 
artst/criterions/__pycache__/artst_criterion.cpython-38.pyc ADDED
Binary file (16.3 kB).
 
artst/criterions/__pycache__/speech_pretrain_criterion.cpython-38.pyc ADDED
Binary file (9.08 kB).
 
artst/criterions/__pycache__/speech_to_text_loss.cpython-38.pyc ADDED
Binary file (12.8 kB).
 
artst/criterions/__pycache__/speecht5_criterion.cpython-38.pyc ADDED
Binary file (16.3 kB).
 
artst/criterions/__pycache__/text_pretrain_criterion.cpython-38.pyc ADDED
Binary file (5.55 kB).
 
artst/criterions/__pycache__/text_to_speech_loss.cpython-38.pyc ADDED
Binary file (14 kB).
 
artst/criterions/artst_criterion.py ADDED
@@ -0,0 +1,443 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import re
9
+ from dataclasses import dataclass
10
+
11
+ import math
12
+ from fairseq import metrics, utils
13
+ from fairseq.criterions import FairseqCriterion, register_criterion
14
+ from artst.criterions.text_to_speech_loss import TexttoSpeechLoss
15
+ from artst.criterions.text_pretrain_criterion import TextPretrainCriterion, TextPretrainCriterionConfig
16
+ from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterionConfig
17
+ from artst.criterions.speech_pretrain_criterion import SpeechPretrainCriterion, SpeechPretrainCriterionConfig
18
+ from artst.criterions.speech_to_text_loss import SpeechtoTextLoss, SpeechtoTextLossConfig
19
+ from fairseq.logging.meters import safe_round
20
+
21
+ @dataclass
22
+ class ArTSTCriterionConfig(
23
+ LabelSmoothedCrossEntropyCriterionConfig,
24
+ TextPretrainCriterionConfig,
25
+ SpeechPretrainCriterionConfig,
26
+ SpeechtoTextLossConfig
27
+ ):
28
+ pass
29
+
30
+ @register_criterion(
31
+ "artst", dataclass=ArTSTCriterionConfig
32
+ )
33
+ class ArTSTCriterion(FairseqCriterion):
34
+ def __init__(
35
+ self,
36
+ task,
37
+ sentence_avg,
38
+ label_smoothing,
39
+ pred_masked_weight,
40
+ pred_nomask_weight,
41
+ loss_weights=None,
42
+ log_keys=None,
43
+ ignore_prefix_size=0,
44
+ report_accuracy=False,
45
+ use_masking=True,
46
+ use_weighted_masking=False,
47
+ loss_type="L1",
48
+ bce_pos_weight=5.0,
49
+ bce_loss_lambda=1.0,
50
+ use_guided_attn_loss=False,
51
+ num_heads_applied_guided_attn=2,
52
+ ce_weight=1.0,
53
+ ctc_weight=0.0,
54
+ hubert_weight=1.0,
55
+ dec_weight=1.0,
56
+ bart_weight=1.0,
57
+ ):
58
+ super().__init__(task)
59
+ self.speech_criterion = TexttoSpeechLoss(
60
+ task,
61
+ sentence_avg,
62
+ use_masking,
63
+ use_weighted_masking,
64
+ loss_type,
65
+ bce_pos_weight,
66
+ bce_loss_lambda,
67
+ use_guided_attn_loss,
68
+ num_heads_applied_guided_attn=num_heads_applied_guided_attn,
69
+ )
70
+ self.text_criterion = SpeechtoTextLoss(
71
+ SpeechtoTextLossConfig,
72
+ task,
73
+ sentence_avg,
74
+ label_smoothing,
75
+ ignore_prefix_size,
76
+ report_accuracy,
77
+ ce_weight,
78
+ ctc_weight
79
+ )
80
+ self.text_pretrain_criterion = TextPretrainCriterion(
81
+ task,
82
+ sentence_avg,
83
+ bart_weight,
84
+ loss_weights,
85
+ )
86
+ self.speech_pretrain_criterion = SpeechPretrainCriterion(
87
+ task,
88
+ sentence_avg,
89
+ pred_masked_weight,
90
+ pred_nomask_weight,
91
+ loss_weights,
92
+ log_keys,
93
+ use_masking,
94
+ use_weighted_masking,
95
+ loss_type,
96
+ bce_pos_weight,
97
+ hubert_weight,
98
+ dec_weight
99
+ )
100
+
101
+ def forward(self, model, sample, reduce=True):
102
+ """Compute the loss for the given sample.
103
+
104
+ Returns a tuple with three elements:
105
+ 1) the loss
106
+ 2) the sample size, which is used as the denominator for the gradient
107
+ 3) logging outputs to display while training
108
+ """
109
+
110
+ task_name = sample['task_name']
111
+ if task_name == 's2t' or task_name == 's2c':
112
+ return self.text_criterion(model, sample, reduce)
113
+ elif task_name == 't2s' or task_name == 's2s':
114
+ return self.speech_criterion(model, sample)
115
+ elif task_name == 'text_pretrain':
116
+ return self.text_pretrain_criterion(model, sample, reduce)
117
+ elif task_name == 'speech_pretrain':
118
+ return self.speech_pretrain_criterion(model, sample, reduce)
119
+
120
+ @classmethod
121
+ def reduce_metrics(cls, logging_outputs):
122
+ """Aggregate logging outputs from data parallel training."""
123
+ logging_outputs_dict = {}
124
+ for logging_output in logging_outputs:
125
+ for task_name in logging_output:
126
+ if task_name not in ['s2t', 't2s', 's2c', 's2s', 'text_pretrain', 'speech_pretrain']:
127
+ continue
128
+
129
+ if task_name not in logging_outputs_dict:
130
+ logging_outputs_dict[task_name] = []
131
+ logging_outputs_dict[task_name].append(logging_output[task_name])
132
+
133
+ for task_name in logging_outputs_dict:
134
+ if task_name == 's2t':
135
+ # LabelSmoothedCrossEntropyCriterion.reduce_metrics([logging_output['s2t'] for logging_output in logging_outputs])
136
+ s2t_logging_output = logging_outputs_dict[task_name]
137
+ # s2t_sum = sum(log.get("ce_loss", 0) for log in logging_outputs)
138
+ loss_sum = sum(log.get("loss", 0) for log in s2t_logging_output)
139
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2t_logging_output)
140
+ ntokens = sum(log.get("ntokens", 0) for log in s2t_logging_output)
141
+ ce_loss_sum = sum(log.get("ce_loss", 0) for log in s2t_logging_output)
142
+ ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in s2t_logging_output)
143
+
144
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in s2t_logging_output))
145
+ metrics.log_scalar(
146
+ "s2t_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
147
+ )
148
+
149
+ metrics.log_scalar(
150
+ "s2t_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
151
+ )
152
+ metrics.log_derived(
153
+ "s2t_ppl", lambda meters: utils.get_perplexity(meters["s2t_nll_loss"].avg, 2)
154
+ )
155
+ metrics.log_scalar(
156
+ "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
157
+ )
158
+ metrics.log_scalar(
159
+ "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3
160
+ )
161
+
162
+ total = utils.item(sum(log.get("total", 0) for log in s2t_logging_output))
163
+ if total > 0:
164
+ metrics.log_scalar("s2t_total", total)
165
+ n_correct = utils.item(
166
+ sum(log.get("n_correct", 0) for log in s2t_logging_output)
167
+ )
168
+ metrics.log_scalar("s2t_n_correct", n_correct)
169
+ metrics.log_derived(
170
+ "s2t_accuracy",
171
+ lambda meters: round(
172
+ meters["s2t_n_correct"].sum * 100.0 / meters["s2t_total"].sum, 3
173
+ )
174
+ if meters["s2t_total"].sum > 0
175
+ else float("nan"),
176
+ 2
177
+ )
178
+ c_errors = sum(log.get("c_errors", 0) for log in s2t_logging_output)
179
+ metrics.log_scalar("_c_errors", c_errors)
180
+ c_total = sum(log.get("c_total", 0) for log in s2t_logging_output)
181
+ metrics.log_scalar("_c_total", c_total)
182
+ w_errors = sum(log.get("w_errors", 0) for log in s2t_logging_output)
183
+ metrics.log_scalar("_w_errors", w_errors)
184
+ wv_errors = sum(log.get("wv_errors", 0) for log in s2t_logging_output)
185
+ metrics.log_scalar("_wv_errors", wv_errors)
186
+ w_total = sum(log.get("w_total", 0) for log in s2t_logging_output)
187
+ metrics.log_scalar("_w_total", w_total)
188
+ if c_total > 0:
189
+ metrics.log_derived(
190
+ "uer",
191
+ lambda meters: safe_round(
192
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
193
+ )
194
+ if meters["_c_total"].sum > 0
195
+ else float("nan"),
196
+ )
197
+ if w_total > 0:
198
+ metrics.log_derived(
199
+ "wer",
200
+ lambda meters: safe_round(
201
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
202
+ )
203
+ if meters["_w_total"].sum > 0
204
+ else float("nan"),
205
+ )
206
+ metrics.log_derived(
207
+ "raw_wer",
208
+ lambda meters: safe_round(
209
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
210
+ )
211
+ if meters["_w_total"].sum > 0
212
+ else float("nan"),
213
+ )
214
+
215
+ if task_name == 't2s':
216
+ # TTSLossCriterion.reduce_metrics([logging_output['t2s'] for logging_output in logging_outputs])
217
+ # t2s_sum = sum(log.get("speech_loss", 0) for log in logging_outputs)
218
+ t2s_logging_output = logging_outputs_dict[task_name]
219
+ loss_sum = sum(log.get("loss", 0) for log in t2s_logging_output)
220
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in t2s_logging_output)
221
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in t2s_logging_output)
222
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in t2s_logging_output)
223
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in t2s_logging_output))
224
+ metrics.log_scalar(
225
+ "t2s_loss", loss_sum / sample_size, sample_size, 1, round=5
226
+ )
227
+ encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in t2s_logging_output)
228
+ decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in t2s_logging_output)
229
+ ngpu = sum(log.get("ngpu", 0) for log in t2s_logging_output)
230
+
231
+ metrics.log_scalar(
232
+ "t2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
233
+ )
234
+ metrics.log_scalar(
235
+ "t2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
236
+ )
237
+ metrics.log_scalar(
238
+ "t2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
239
+ )
240
+ metrics.log_scalar(
241
+ "t2s_encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
242
+ )
243
+ metrics.log_scalar(
244
+ "t2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
245
+ )
246
+
247
+ if "enc_dec_attn_loss" in t2s_logging_output[0]:
248
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in t2s_logging_output)
249
+ metrics.log_scalar(
250
+ "t2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
251
+ )
252
+
253
+ if task_name == 's2c':
254
+ s2c_logging_output = logging_outputs_dict[task_name]
255
+ loss_sum = sum(log.get("loss", 0) for log in s2c_logging_output)
256
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2c_logging_output)
257
+ ntokens = sum(log.get("ntokens", 0) for log in s2c_logging_output)
258
+
259
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in s2c_logging_output))
260
+ metrics.log_scalar(
261
+ "s2c_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
262
+ )
263
+
264
+ metrics.log_scalar(
265
+ "s2c_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
266
+ )
267
+
268
+ total = utils.item(sum(log.get("total", 0) for log in s2c_logging_output))
269
+ if total > 0:
270
+ metrics.log_scalar("s2c_total", total)
271
+ n_correct = utils.item(sum(log.get("n_correct", 0) for log in s2c_logging_output))
272
+ metrics.log_scalar("s2c_n_correct", n_correct)
273
+ metrics.log_derived(
274
+ "s2c_accuracy",
275
+ lambda meters: round(
276
+ meters["s2c_n_correct"].sum * 100.0 / meters["s2c_total"].sum, 3
277
+ )
278
+ if meters["s2c_total"].sum > 0
279
+ else float("nan"),
280
+ 2
281
+ )
282
+
283
+ if task_name == 's2s':
284
+ s2s_logging_output = logging_outputs_dict[task_name]
285
+ loss_sum = sum(log.get("loss", 0) for log in s2s_logging_output)
286
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in s2s_logging_output)
287
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in s2s_logging_output)
288
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in s2s_logging_output)
289
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in s2s_logging_output))
290
+ metrics.log_scalar(
291
+ "s2s_loss", loss_sum / sample_size, sample_size, 1, round=5
292
+ )
293
+ encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in s2s_logging_output)
294
+ decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in s2s_logging_output)
295
+ ngpu = sum(log.get("ngpu", 0) for log in s2s_logging_output)
296
+
297
+ metrics.log_scalar(
298
+ "s2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
299
+ )
300
+ metrics.log_scalar(
301
+ "s2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
302
+ )
303
+ metrics.log_scalar(
304
+ "s2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
305
+ )
306
+ metrics.log_scalar(
307
+ "s2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
308
+ )
309
+
310
+ if "enc_dec_attn_loss" in s2s_logging_output[0]:
311
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in s2s_logging_output)
312
+ metrics.log_scalar(
313
+ "s2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
314
+ )
315
+
316
+ if task_name == 'text_pretrain':
317
+ bart_logging_output = logging_outputs_dict[task_name]
318
+ loss_sum = sum(log.get("loss", 0) for log in bart_logging_output)
319
+ ntokens = sum(log.get("ntokens", 0) for log in bart_logging_output)
320
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in bart_logging_output))
321
+ bart_loss_sum = sum(log.get("bart_loss", 0) for log in bart_logging_output)
322
+
323
+ # we divide by log(2) to convert the loss from base e to base 2
324
+ metrics.log_scalar(
325
+ "text_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
326
+ )
327
+ metrics.log_scalar(
328
+ "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
329
+ )
330
+ if sample_size != ntokens:
331
+ metrics.log_scalar(
332
+ "bart_nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3
333
+ )
334
+ metrics.log_derived(
335
+ "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_nll_loss"].avg)
336
+ )
337
+ else:
338
+ metrics.log_derived(
339
+ "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg)
340
+ )
341
+ metrics.log_scalar("bart_wpb", ntokens, priority=180, round=1)
342
+
343
+ val_prob_perplexity = 0
344
+ val_code_perplexity = 0
345
+ sample_size_pp = 0
346
+ count_log_cp = 0
347
+ for log in bart_logging_output:
348
+ if "loss_prob_perplexity" in log:
349
+ val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
350
+ sample_size_pp = sample_size_pp + log["sample_size"]
351
+ if "code_perplexity" in log:
352
+ val_code_perplexity = val_code_perplexity + log["code_perplexity"]
353
+ count_log_cp = count_log_cp + 1
354
+ if val_prob_perplexity > 0:
355
+ metrics.log_scalar("text_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
356
+ if val_code_perplexity > 0:
357
+ metrics.log_scalar("text_code_perplexity", val_code_perplexity / count_log_cp, round=3)
358
+
359
+ if task_name == 'speech_pretrain':
360
+ hubert_logging_output = logging_outputs_dict[task_name]
361
+ loss_sum = sum(log.get("loss", 0) for log in hubert_logging_output)
362
+ ntokens = sum(log.get("ntokens", 0) for log in hubert_logging_output)
363
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in hubert_logging_output))
364
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in hubert_logging_output)
365
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in hubert_logging_output)
366
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in hubert_logging_output)
367
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in hubert_logging_output)
368
+ ngpu = sum(log.get("ngpu", 0) for log in hubert_logging_output)
369
+
370
+ metrics.log_scalar("hubert_loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
371
+ if sample_size != ntokens:
372
+ metrics.log_scalar("hubert_nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
373
+ metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_nll_loss"].avg))
374
+ else:
375
+ metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_loss"].avg))
376
+
377
+ counts = {}
378
+ for lk in hubert_logging_output[0].keys():
379
+ if lk.startswith("count_"):
380
+ val = sum(log[lk] for log in hubert_logging_output)
381
+ metrics.log_scalar("hubert_" + lk, val)
382
+ counts[lk] = val
383
+
384
+ for lk in hubert_logging_output[0].keys():
385
+ if lk.startswith("loss_") and lk != 'loss_prob_perplexity':
386
+ val = sum(log[lk] for log in hubert_logging_output)
387
+ metrics.log_scalar("hubert_" + lk, val / sample_size / math.log(2), round=3)
388
+ elif lk.startswith("correct_"):
389
+ val = sum(log[lk] for log in hubert_logging_output)
390
+ metrics.log_scalar("hubert_" + lk, val / counts[re.sub("correct", "count", lk)])
391
+ # elif lk == 'code_perplexity':
392
+ # val = sum(log[lk] for log in hubert_logging_output)
393
+ # metrics.log_scalar("hubert_" + lk, val / len(hubert_logging_output), round=3)
394
+
395
+ val_prob_perplexity = 0
396
+ val_code_perplexity = 0
397
+ sample_size_pp = 0
398
+ count_log_cp = 0
399
+ for log in hubert_logging_output:
400
+ if "loss_prob_perplexity" in log:
401
+ val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
402
+ sample_size_pp = sample_size_pp + log["sample_size"]
403
+ if "code_perplexity" in log:
404
+ val_code_perplexity = val_code_perplexity + log["code_perplexity"]
405
+ count_log_cp = count_log_cp + 1
406
+ if val_prob_perplexity > 0:
407
+ metrics.log_scalar("hubert_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
408
+ if val_code_perplexity > 0:
409
+ metrics.log_scalar("hubert_code_perplexity", val_code_perplexity / count_log_cp, round=3)
410
+
411
+ metrics.log_scalar(
412
+ "hubert_dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
413
+ )
414
+ metrics.log_scalar(
415
+ "hubert_l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
416
+ )
417
+ metrics.log_scalar(
418
+ "hubert_l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
419
+ )
420
+ metrics.log_scalar(
421
+ "hubert_bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
422
+ )
423
+ if "enc_dec_attn_loss" in hubert_logging_output[0]:
424
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in hubert_logging_output)
425
+ metrics.log_scalar(
426
+ "hubert_enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8
427
+ )
428
+ metrics.log_scalar("hubert_wpb", ntokens, priority=180, round=1)
429
+
430
+ loss = sum(log.get("loss", 0) for log in logging_outputs)
431
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
432
+ metrics.log_scalar(
433
+ "loss", loss / sample_size, sample_size, 1, round=5
434
+ )
435
+
436
+ @staticmethod
437
+ def logging_outputs_can_be_summed() -> bool:
438
+ """
439
+ Whether the logging outputs returned by `forward` can be summed
440
+ across workers prior to calling `reduce_metrics`. Setting this
441
+ to True will improve distributed training speed.
442
+ """
443
+ return False
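ArTSTCriterion is essentially a multitask dispatcher: forward() inspects sample['task_name'] and hands the batch to SpeechtoTextLoss (s2t, s2c), TexttoSpeechLoss (t2s, s2s), TextPretrainCriterion (text_pretrain) or SpeechPretrainCriterion (speech_pretrain), while reduce_metrics() groups logging outputs per task before aggregating. The standalone sketch below (illustration only, with dummy callables in place of the real fairseq criteria) shows the same routing pattern.

# Minimal sketch, not the committed code: task-name routing as in ArTSTCriterion.forward.
from typing import Callable, Dict

def make_dispatcher(criteria: Dict[str, Callable]) -> Callable:
    """Return a forward-like function that picks a criterion by task name."""
    def forward(sample: dict):
        task_name = sample["task_name"]
        for names, criterion in criteria.items():
            if task_name in names.split("|"):
                return criterion(sample)
        raise ValueError(f"unknown task_name: {task_name}")
    return forward

# Dummy criteria: each returns (loss, sample_size, logging_output), like a fairseq criterion.
dummy = lambda tag: (lambda sample: (0.0, 1, {tag: True, "task": sample["task_name"]}))
forward = make_dispatcher({
    "s2t|s2c": dummy("asr_or_cls_loss"),      # SpeechtoTextLoss in the real code
    "t2s|s2s": dummy("speech_loss"),          # TexttoSpeechLoss in the real code
    "text_pretrain": dummy("bart_loss"),      # TextPretrainCriterion
    "speech_pretrain": dummy("hubert_loss"),  # SpeechPretrainCriterion
})

print(forward({"task_name": "t2s"}))  # -> (0.0, 1, {'speech_loss': True, 'task': 't2s'})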
artst/criterions/speech_pretrain_criterion.py ADDED
@@ -0,0 +1,265 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ import re
10
+ from dataclasses import dataclass, field
11
+ from typing import List, Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from fairseq import metrics, utils
16
+ from fairseq.criterions import FairseqCriterion
17
+ from artst.criterions.text_to_speech_loss import TexttoSpeechLoss, TexttoSpeechLossConfig
18
+
19
+
20
+ @dataclass
21
+ class SpeechPretrainCriterionConfig(TexttoSpeechLossConfig):
22
+ pred_masked_weight: float = field(
23
+ default=1.0,
24
+ metadata={"help": "weight for predictive loss for masked frames"},
25
+ )
26
+ pred_nomask_weight: float = field(
27
+ default=0.0,
28
+ metadata={"help": "weight for predictive loss for unmasked frames"},
29
+ )
30
+ loss_weights: Optional[List[float]] = field(
31
+ default_factory=lambda: [10,],
32
+ metadata={"help": "weights for additional loss terms (not first one)"},
33
+ )
34
+ log_keys: List[str] = field(
35
+ default_factory=lambda: [],
36
+ metadata={"help": "output keys to log"},
37
+ )
38
+ hubert_weight: float = field(
39
+ default=1.0,
40
+ metadata={"help": "weight of hubert loss"},
41
+ )
42
+ dec_weight: float = field(
43
+ default=1.0,
44
+ metadata={"help": "weight of decoder loss"},
45
+ )
46
+
47
+
48
+ class SpeechPretrainCriterion(FairseqCriterion):
49
+ def __init__(
50
+ self,
51
+ task,
52
+ sentence_avg,
53
+ pred_masked_weight,
54
+ pred_nomask_weight,
55
+ loss_weights=None,
56
+ log_keys=None,
57
+ use_masking=True,
58
+ use_weighted_masking=False,
59
+ loss_type="L1",
60
+ bce_pos_weight=5.0,
61
+ hubert_weight=1.0,
62
+ dec_weight=1.0,
63
+ ):
64
+ super().__init__(task)
65
+ self.pred_masked_weight = pred_masked_weight
66
+ self.pred_nomask_weight = pred_nomask_weight
67
+ self.loss_weights = loss_weights
68
+ self.log_keys = [] if log_keys is None else log_keys
69
+ self.hubert_weight = hubert_weight
70
+ self.dec_weight = dec_weight
71
+
72
+ self.speech_criterion = TexttoSpeechLoss(
73
+ task,
74
+ sentence_avg,
75
+ use_masking,
76
+ use_weighted_masking,
77
+ loss_type,
78
+ bce_pos_weight,
79
+ )
80
+
81
+ def forward(self, model, sample, reduce=True, log_pred=False):
82
+ """Compute the loss for the given sample.
83
+ Returns a tuple with three elements:
84
+ 1) the loss
85
+ 2) the sample size, which is used as the denominator for the gradient
86
+ 3) logging outputs to display while training
87
+ """
88
+ if self.dec_weight == 0:
89
+ sample["net_input"]["only_hubert"] = True
90
+ net_output, net_output_dec = model(target_list=sample["target_list"], **sample["net_input"])
91
+ loss = 0.
92
+ sample_size = 0
93
+ logging_output = {}
94
+ reduction = "sum" if reduce else "none"
95
+
96
+ loss_m_list = []
97
+ logp_m_list = model.get_logits(net_output, True)
98
+ targ_m_list = model.get_targets(None, net_output, True)
99
+ assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
100
+ for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
101
+ loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
102
+ loss_m_list.append(loss_m)
103
+ logging_output[f"loss_m_{i}"] = loss_m.detach().item()
104
+ if self.pred_masked_weight > 0:
105
+ loss += self.pred_masked_weight * sum(loss_m_list)
106
+ sample_size += targ_m_list[0].numel()
107
+
108
+ loss_u_list = []
109
+ logp_u_list = model.get_logits(net_output, False)
110
+ targ_u_list = model.get_targets(None, net_output, False)
111
+ assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
112
+ for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
113
+ loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
114
+ loss_u_list.append(loss_u)
115
+ logging_output[f"loss_u_{i}"] = loss_u.detach().item()
116
+ if self.pred_nomask_weight > 0:
117
+ loss += self.pred_nomask_weight * sum(loss_u_list)
118
+ sample_size += targ_u_list[0].numel()
119
+
120
+ if self.loss_weights is not None:
121
+ assert hasattr(model, "get_extra_losses")
122
+ extra_losses, names = model.get_extra_losses(net_output)
123
+ if torch.is_tensor(extra_losses):
124
+ extra_losses = [extra_losses]
125
+ names = [names]
126
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
127
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
128
+ if len(self.loss_weights) > len(extra_losses):
129
+ modified_loss_weight = self.loss_weights[:len(extra_losses)]
130
+ else:
131
+ modified_loss_weight = self.loss_weights
132
+
133
+ # assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
134
+ for p, n, coef in zip(extra_losses, names, modified_loss_weight):
135
+ # print(n + str(coef))
136
+ if coef != 0 and p is not None:
137
+ p = coef * p.float() * sample_size
138
+ loss += p
139
+ logging_output[f"loss_{n}"] = p.detach().item()
140
+
141
+ logging_output = {
142
+ "ntokens": sample_size,
143
+ "nsentences": sample["id"].numel(),
144
+ "sample_size": sample_size,
145
+ "ngpu": 1,
146
+ **logging_output,
147
+ }
148
+
149
+ if 'loss_prob_perplexity' in logging_output:
150
+ logging_output['code_perplexity'] = net_output['code_perplexity'].detach().item()
151
+
152
+ for lk in self.log_keys:
153
+ if lk in net_output:
154
+ logging_output[lk] = float((net_output[lk].item()))
155
+
156
+ def compute_correct(logits):
157
+ if logits.numel() == 0:
158
+ return 0, 0
159
+ else:
160
+ assert logits.dim() > 1, logits.shape
161
+ max = logits.argmax(-1) == 0
162
+ min = logits.argmin(-1) == 0
163
+ both = max & min
164
+ corr = max.long().sum().item() - both.long().sum().item()
165
+ count = max.numel()
166
+ return corr, count
167
+
168
+ with torch.no_grad():
169
+ for i, logp_m in enumerate(logp_m_list):
170
+ corr_m, count_m = compute_correct(logp_m)
171
+ logging_output[f"correct_m_{i}"] = corr_m
172
+ logging_output[f"count_m_{i}"] = count_m
173
+
174
+ for i, logp_u in enumerate(logp_u_list):
175
+ corr_u, count_u = compute_correct(logp_u)
176
+ logging_output[f"correct_u_{i}"] = corr_u
177
+ logging_output[f"count_u_{i}"] = count_u
178
+
179
+ if self.dec_weight == 0.0:
180
+ logging_output["loss"] = loss.item() if reduce else loss
181
+ return loss, sample_size, logging_output
182
+
183
+ # ## dec loss
184
+ dec_loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.speech_criterion.compute_loss(model, net_output_dec, sample)
185
+
186
+ # Log tts loss
187
+ logging_output['dec_loss'] = dec_loss.item()
188
+ logging_output['l1_loss'] = l1_loss.item()
189
+ logging_output['l2_loss'] = l2_loss.item()
190
+ logging_output['bce_loss'] = bce_loss.item()
191
+ if enc_dec_attn_loss is not None:
192
+ logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item()
193
+
194
+ loss = self.hubert_weight * loss + self.dec_weight * sample_size * dec_loss
195
+ logging_output["loss"] = loss.item() if reduce else loss
196
+ return loss, sample_size, logging_output
197
+
198
+ @staticmethod
199
+ def reduce_metrics(logging_outputs) -> None:
200
+ """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
201
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
202
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
203
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
204
+ dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
205
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs)
206
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs)
207
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs)
208
+ ngpu = sum(log.get("ngpu", 0) for log in logging_outputs)
209
+
210
+ metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
211
+ if sample_size != ntokens:
212
+ metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
213
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
214
+ else:
215
+ metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
216
+
217
+ counts = {}
218
+ for lk in logging_outputs[0].keys():
219
+ if lk.startswith("count_"):
220
+ val = sum(log[lk] for log in logging_outputs)
221
+ metrics.log_scalar(lk, val)
222
+ counts[lk] = val
223
+
224
+ for lk in logging_outputs[0].keys():
225
+ if lk.startswith("loss_"):
226
+ val = sum(log[lk] for log in logging_outputs)
227
+ metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
228
+ elif lk.startswith("correct_"):
229
+ val = sum(log[lk] for log in logging_outputs)
230
+ metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
231
+ elif lk == 'code_perplexity':
232
+ val = sum(log[lk] for log in logging_outputs)
233
+ metrics.log_scalar(lk, val / len(logging_outputs), round=3)
234
+
235
+ metrics.log_scalar(
236
+ "dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
237
+ )
238
+ metrics.log_scalar(
239
+ "l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
240
+ )
241
+ metrics.log_scalar(
242
+ "l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
243
+ )
244
+ metrics.log_scalar(
245
+ "bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
246
+ )
247
+ if "enc_dec_attn_loss" in logging_outputs[0]:
248
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs)
249
+ metrics.log_scalar(
250
+ "enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8
251
+ )
252
+
253
+ @staticmethod
254
+ def aggregate_logging_outputs(logging_outputs):
255
+ """Aggregate logging outputs from data parallel training."""
256
+ raise NotImplementedError()
257
+
258
+ @staticmethod
259
+ def logging_outputs_can_be_summed() -> bool:
260
+ """
261
+ Whether the logging outputs returned by `forward` can be summed
262
+ across workers prior to calling `reduce_metrics`. Setting this
263
+ to True will improve distributed training speed.
264
+ """
265
+ return False
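SpeechPretrainCriterion follows the HuBERT recipe: cross-entropy is computed separately over masked and unmasked frames, weighted by pred_masked_weight and pred_nomask_weight, and sample_size is the number of target frames; when dec_weight > 0 the TexttoSpeechLoss reconstruction term is added on top, scaled by sample_size. Below is a minimal self-contained PyTorch sketch of just that masked/unmasked weighting; the tensor shapes and the 504-class codebook size are illustrative assumptions, not values from this repo.

# Illustrative sketch, not the committed code: HuBERT-style masked/unmasked loss weighting.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 504                              # assumed k-means cluster count
logp_masked = torch.randn(120, num_classes)    # logits for masked frames
targ_masked = torch.randint(num_classes, (120,))
logp_unmasked = torch.randn(300, num_classes)  # logits for unmasked frames
targ_unmasked = torch.randint(num_classes, (300,))

pred_masked_weight, pred_nomask_weight = 1.0, 0.0  # defaults from the config above

loss = 0.0
sample_size = 0
loss_m = F.cross_entropy(logp_masked, targ_masked, reduction="sum")
if pred_masked_weight > 0:
    loss = loss + pred_masked_weight * loss_m
    sample_size += targ_masked.numel()
loss_u = F.cross_entropy(logp_unmasked, targ_unmasked, reduction="sum")
if pred_nomask_weight > 0:
    loss = loss + pred_nomask_weight * loss_u
    sample_size += targ_unmasked.numel()

print(float(loss) / max(sample_size, 1))  # per-frame loss, as used for logging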
artst/criterions/speech_to_text_loss.py ADDED
@@ -0,0 +1,473 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ from argparse import Namespace
10
+ from dataclasses import dataclass, field
11
+ from omegaconf import II
12
+ from typing import Optional
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq import metrics, utils
17
+ from fairseq.criterions import FairseqCriterion, register_criterion
18
+ from fairseq.dataclass import FairseqDataclass
19
+ from fairseq.data.data_utils import post_process
20
+ from fairseq.tasks import FairseqTask
21
+ from fairseq.logging.meters import safe_round
22
+
23
+ import logging
24
+ logger = logging.getLogger(__name__)
25
+
26
+ @dataclass
27
+ class SpeechtoTextLossConfig(FairseqDataclass):
28
+ zero_infinity: bool = field(
29
+ default=False,
30
+ metadata={"help": "zero inf loss when source length <= target length"},
31
+ )
32
+ sentence_avg: bool = II("optimization.sentence_avg")
33
+ post_process: Optional[str] = field(
34
+ default="sentencepiece",
35
+ metadata={
36
+ "help": "how to post process predictions into words. can be letter, "
37
+ "wordpiece, BPE symbols, etc. "
38
+ "See fairseq.data.data_utils.post_process() for full list of options"
39
+ },
40
+ )
41
+ wer_kenlm_model: Optional[str] = field(
42
+ default=None,
43
+ metadata={
44
+ "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
45
+ },
46
+ )
47
+ wer_lexicon: Optional[str] = field(
48
+ default=None,
49
+ metadata={"help": "lexicon to use with wer_kenlm_model"},
50
+ )
51
+ wer_lm_weight: float = field(
52
+ default=2.0,
53
+ metadata={"help": "lm weight to use with wer_kenlm_model"},
54
+ )
55
+ wer_word_score: float = field(
56
+ default=-1.0,
57
+ metadata={"help": "lm word score to use with wer_kenlm_model"},
58
+ )
59
+
60
+ wer_args: Optional[str] = field(
61
+ default=None,
62
+ metadata={
63
+ "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
64
+ },
65
+ )
66
+
67
+ label_smoothing: float = field(
68
+ default=0.0,
69
+ metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
70
+ )
71
+ report_accuracy: bool = field(
72
+ default=False,
73
+ metadata={"help": "report accuracy metric"},
74
+ )
75
+ ignore_prefix_size: int = field(
76
+ default=0,
77
+ metadata={"help": "Ignore first N tokens"},
78
+ )
79
+ #: bool = II("optimization.sentence_avg")
80
+
81
+ ce_weight: float = field(
82
+ default=1.0,
83
+ metadata={"help": "loss weight for cross entropy"},
84
+ )
85
+ ctc_weight: float = field(
86
+ default=0.0,
87
+ metadata={"help": "loss weight for CTC in ASR"},
88
+ )
89
+
90
+
91
+ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
92
+ if target.dim() == lprobs.dim() - 1:
93
+ target = target.unsqueeze(-1)
94
+ nll_loss = -lprobs.gather(dim=-1, index=target)
95
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
96
+ if ignore_index is not None:
97
+ pad_mask = target.eq(ignore_index)
98
+ nll_loss.masked_fill_(pad_mask, 0.0)
99
+ smooth_loss.masked_fill_(pad_mask, 0.0)
100
+ else:
101
+ nll_loss = nll_loss.squeeze(-1)
102
+ smooth_loss = smooth_loss.squeeze(-1)
103
+ if reduce:
104
+ nll_loss = nll_loss.sum()
105
+ smooth_loss = smooth_loss.sum()
106
+ eps_i = epsilon / (lprobs.size(-1) - 1)
107
+ loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
108
+ return loss, nll_loss
109
+
110
+
111
+ class SpeechtoTextLoss(FairseqCriterion):
112
+ def __init__(
113
+ self,
114
+ cfg: SpeechtoTextLossConfig,
115
+ task: FairseqTask,
116
+ sentence_avg=True,
117
+ label_smoothing=0.1,
118
+ ignore_prefix_size=0,
119
+ report_accuracy=False,
120
+ ce_weight=1.0,
121
+ ctc_weight=0.0,
122
+ ):
123
+
124
+ super().__init__(task)
125
+ self.blank_idx = (
126
+ task.target_dictionary.index(task.blank_symbol)
127
+ if hasattr(task, "blank_symbol")
128
+ else 0
129
+ )
130
+ #print ("self.blank_idx: ", self.blank_idx)
131
+
132
+ self.pad_idx = task.target_dictionary.pad()
133
+ self.eos_idx = task.target_dictionary.eos()
134
+ self.post_process = cfg.post_process
135
+ self.ce_weight = ce_weight
136
+ self.ctc_weight = ctc_weight
137
+
138
+ ## for ce
139
+ self.sentence_avg = sentence_avg
140
+ self.eps = label_smoothing
141
+ self.ignore_prefix_size = ignore_prefix_size
142
+ self.report_accuracy = report_accuracy
143
+
144
+ if cfg.wer_args is not None:
145
+ (
146
+ cfg.wer_kenlm_model,
147
+ cfg.wer_lexicon,
148
+ cfg.wer_lm_weight,
149
+ cfg.wer_word_score,
150
+ ) = eval(cfg.wer_args)
151
+
152
+ if cfg.wer_kenlm_model is not None:
153
+ from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
154
+
155
+ dec_args = Namespace()
156
+ dec_args.nbest = 1
157
+ dec_args.criterion = "ctc"
158
+ dec_args.kenlm_model = cfg.wer_kenlm_model
159
+ dec_args.lexicon = cfg.wer_lexicon
160
+ dec_args.beam = 50
161
+ dec_args.beam_size_token = min(50, len(task.target_dictionary))
162
+ dec_args.beam_threshold = min(50, len(task.target_dictionary))
163
+ dec_args.lm_weight = cfg.wer_lm_weight
164
+ dec_args.word_score = cfg.wer_word_score
165
+ dec_args.unk_weight = -math.inf
166
+ dec_args.sil_weight = 0
167
+
168
+ self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
169
+ else:
170
+ self.w2l_decoder = None
171
+
172
+ self.zero_infinity = cfg.zero_infinity
173
+ #self.sentence_avg = cfg.sentence_avg
174
+
175
+ if self.ce_weight > 0 and self.ctc_weight > 0:
176
+ logger.info("Using cross entropy loss and CTC loss for ASR")
177
+ elif self.ce_weight > 0:
178
+ logger.info("Only using CE loss")
179
+ elif self.ctc_weight > 0:
180
+ logger.info("Only using CTC loss for ASR")
181
+ else:
182
+ logger.info("ERROR")
183
+
184
+ def forward(self, model, sample, reduce=True):
185
+
186
+ if self.ce_weight == 0 and self.ctc_weight > 0:
187
+ sample["only_ctc"] = True
188
+
189
+ net_output_decoder, net_output = model(**sample["net_input"])
190
+
191
+ if self.ce_weight > 0:
192
+ loss_ce, nll_loss_ce = self.compute_loss(model, net_output_decoder, sample, reduce=reduce)
193
+ #print ("loss_ce: ", loss_ce)
194
+ else:
195
+ nll_loss_ce = None
196
+
197
+ if self.ctc_weight > 0:
198
+ loss_ctc, lprobs, input_lengths = self.compute_loss_ctc(model, net_output, sample)
199
+
200
+ if self.ce_weight > 0 and self.ctc_weight > 0:
201
+ loss = self.ce_weight * loss_ce + self.ctc_weight * loss_ctc
202
+ elif self.ce_weight > 0:
203
+ loss = loss_ce
204
+ elif self.ctc_weight > 0:
205
+ loss = loss_ctc
206
+ else:
207
+ logger.info("ERROR: must ce_weight > 0 or ctc_weight > 0")
208
+
209
+ ntokens = (
210
+ sample["ntokens"] if "ntokens" in sample else sample["target_lengths"].sum().item()
211
+ )
212
+
213
+ sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
214
+
215
+ logging_output = {
216
+ "loss": loss.item(),
217
+ "ce_loss": loss_ce.item() if self.ce_weight > 0 else 0,
218
+ "ctc_loss": loss_ctc.item() if self.ctc_weight > 0 else 0,
219
+ "nll_loss": nll_loss_ce.item() if nll_loss_ce is not None else 0,
220
+ "ntokens": sample["ntokens"],
221
+ "nsentences": sample["target"].size(0),
222
+ "sample_size": sample_size,
223
+ }
224
+
225
+ if self.ce_weight > 0 and self.report_accuracy:
226
+ n_correct, total = self.compute_accuracy(model, net_output_decoder, sample)
227
+ logging_output["n_correct"] = utils.item(n_correct.item())
228
+ logging_output["total"] = utils.item(total.data)
229
+
230
+ if self.ctc_weight > 0 and not model.training:
231
+ import editdistance
232
+
233
+ with torch.no_grad():
234
+ lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
235
+
236
+ c_err = 0
237
+ c_len = 0
238
+ w_errs = 0
239
+ w_len = 0
240
+ wv_errs = 0
241
+ for lp, t, inp_l in zip(
242
+ lprobs_t,
243
+ sample["target_label"]
244
+ if "target_label" in sample
245
+ else sample["target"],
246
+ input_lengths,
247
+ ):
248
+ lp = lp[:inp_l].unsqueeze(0)
249
+
250
+ decoded = None
251
+ if self.w2l_decoder is not None:
252
+ decoded = self.w2l_decoder.decode(lp)
253
+ if len(decoded) < 1:
254
+ decoded = None
255
+ else:
256
+ decoded = decoded[0]
257
+ if len(decoded) < 1:
258
+ decoded = None
259
+ else:
260
+ decoded = decoded[0]
261
+
262
+ p = (t != self.task.target_dictionary.pad()) & (
263
+ t != self.task.target_dictionary.eos()
264
+ )
265
+ targ = t[p]
266
+ targ_units = self.task.target_dictionary.string(targ)
267
+ targ_units_arr = targ.tolist()
268
+
269
+ toks = lp.argmax(dim=-1).unique_consecutive()
270
+ pred_units_arr = toks[toks != self.blank_idx].tolist()
271
+
272
+ c_err += editdistance.eval(pred_units_arr, targ_units_arr)
273
+ c_len += len(targ_units_arr)
274
+
275
+ targ_words = post_process(targ_units, self.post_process).split()
276
+
277
+ pred_units = self.task.target_dictionary.string(pred_units_arr)
278
+ pred_words_raw = post_process(pred_units, self.post_process).split()
279
+
280
+ if decoded is not None and "words" in decoded:
281
+ pred_words = decoded["words"]
282
+ w_errs += editdistance.eval(pred_words, targ_words)
283
+ wv_errs += editdistance.eval(pred_words_raw, targ_words)
284
+ else:
285
+ dist = editdistance.eval(pred_words_raw, targ_words)
286
+ w_errs += dist
287
+ wv_errs += dist
288
+
289
+ w_len += len(targ_words)
290
+
291
+ logging_output["wv_errors"] = wv_errs
292
+ logging_output["w_errors"] = w_errs
293
+ logging_output["w_total"] = w_len
294
+ logging_output["c_errors"] = c_err
295
+ logging_output["c_total"] = c_len
296
+
297
+ return loss, sample_size, logging_output
298
+
299
+ def compute_loss_ctc(self, model, net_output, sample):
300
+ lprobs = model.get_normalized_probs_for_ctc(
301
+ net_output, log_probs=True
302
+ ).contiguous() # (T, B, C) from the encoder
303
+
304
+ if net_output["encoder_padding_mask"] is not None:
305
+ non_padding_mask = ~net_output["encoder_padding_mask"][0]
306
+ input_lengths = non_padding_mask.long().sum(-1)
307
+ else:
308
+ input_lengths = lprobs.new_full(
309
+ (lprobs.size(1),), lprobs.size(0), dtype=torch.long
310
+ )
311
+
312
+ pad_mask = (sample["target"] != self.pad_idx) & (
313
+ sample["target"] != self.eos_idx
314
+ )
315
+ targets_flat = sample["target"].masked_select(pad_mask)
316
+ if "target_lengths" in sample:
317
+ target_lengths = sample["target_lengths"]
318
+ else:
319
+ target_lengths = pad_mask.sum(-1)
320
+
321
+ ##processing
322
+ target_lengths = target_lengths - 1
323
+
324
+ with torch.backends.cudnn.flags(enabled=False):
325
+ loss_ctc = F.ctc_loss(
326
+ lprobs,
327
+ targets_flat,
328
+ input_lengths,
329
+ target_lengths,
330
+ blank=self.blank_idx,
331
+ reduction="sum",
332
+ zero_infinity=True,
333
+ )
334
+
335
+ return loss_ctc, lprobs, input_lengths
336
+
337
+ ## for ce
338
+ def get_lprobs_and_target(self, model, net_output, sample):
339
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
340
+ target = model.get_targets(sample, net_output)
341
+ if self.ignore_prefix_size > 0:
342
+ if getattr(lprobs, "batch_first", False):
343
+ lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
344
+ target = target[:, self.ignore_prefix_size :].contiguous()
345
+ else:
346
+ lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
347
+ target = target[self.ignore_prefix_size :, :].contiguous()
348
+ return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
349
+
350
+ def compute_loss(self, model, net_output, sample, reduce=True):
351
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
352
+ loss, nll_loss = label_smoothed_nll_loss(
353
+ lprobs,
354
+ target,
355
+ self.eps,
356
+ ignore_index=self.padding_idx,
357
+ reduce=reduce,
358
+ )
359
+ return loss, nll_loss
360
+
361
+ def compute_accuracy(self, model, net_output, sample):
362
+ lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
363
+ mask = target.ne(self.padding_idx)
364
+ n_correct = torch.sum(
365
+ lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
366
+ )
367
+ total = torch.sum(mask)
368
+ return n_correct, total
369
+
370
+
371
+ @staticmethod
372
+ def reduce_metrics(logging_outputs) -> None:
373
+ """Aggregate logging outputs from data parallel training."""
374
+
375
+ loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
376
+ nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
377
+ ce_loss_sum = sum(log.get("ce_loss", 0) for log in logging_outputs)
378
+ ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
379
+ ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
380
+ nsentences = utils.item(
381
+ sum(log.get("nsentences", 0) for log in logging_outputs)
382
+ )
383
+ sample_size = utils.item(
384
+ sum(log.get("sample_size", 0) for log in logging_outputs)
385
+ )
386
+
387
+ metrics.log_scalar(
388
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
389
+ )
390
+
391
+ metrics.log_scalar(
392
+ "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
393
+ )
394
+ metrics.log_scalar(
395
+ "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3
396
+ )
397
+ metrics.log_scalar(
398
+ "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
399
+ )
400
+ metrics.log_derived(
401
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, 2)
402
+ )
403
+
404
+ total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
405
+ if total > 0:
406
+ metrics.log_scalar("total", total)
407
+ n_correct = utils.item(
408
+ sum(log.get("n_correct", 0) for log in logging_outputs)
409
+ )
410
+ metrics.log_scalar("n_correct", n_correct)
411
+ metrics.log_derived(
412
+ "accuracy",
413
+ lambda meters: round(
414
+ meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
415
+ )
416
+ if meters["total"].sum > 0
417
+ else float("nan"),
418
+ 2
419
+ )
420
+
421
+ metrics.log_scalar("ntokens", ntokens)
422
+ metrics.log_scalar("nsentences", nsentences)
423
+ if sample_size != ntokens:
424
+ metrics.log_scalar(
425
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
426
+ )
427
+
428
+ c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
429
+ metrics.log_scalar("_c_errors", c_errors)
430
+ c_total = sum(log.get("c_total", 0) for log in logging_outputs)
431
+ metrics.log_scalar("_c_total", c_total)
432
+ w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
433
+ metrics.log_scalar("_w_errors", w_errors)
434
+ wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
435
+ metrics.log_scalar("_wv_errors", wv_errors)
436
+ w_total = sum(log.get("w_total", 0) for log in logging_outputs)
437
+ metrics.log_scalar("_w_total", w_total)
438
+
439
+ if c_total > 0:
440
+ metrics.log_derived(
441
+ "uer",
442
+ lambda meters: safe_round(
443
+ meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
444
+ )
445
+ if meters["_c_total"].sum > 0
446
+ else float("nan"),
447
+ )
448
+ if w_total > 0:
449
+ metrics.log_derived(
450
+ "wer",
451
+ lambda meters: safe_round(
452
+ meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
453
+ )
454
+ if meters["_w_total"].sum > 0
455
+ else float("nan"),
456
+ )
457
+ metrics.log_derived(
458
+ "raw_wer",
459
+ lambda meters: safe_round(
460
+ meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
461
+ )
462
+ if meters["_w_total"].sum > 0
463
+ else float("nan"),
464
+ )
465
+
466
+ @staticmethod
467
+ def logging_outputs_can_be_summed() -> bool:
468
+ """
469
+ Whether the logging outputs returned by `forward` can be summed
470
+ across workers prior to calling `reduce_metrics`. Setting this
471
+ to True will improve distributed training speed.
472
+ """
473
+ return True
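The label_smoothed_nll_loss helper above assigns weight 1 − ε to the gold token and spreads ε/(V − 1) over each of the remaining classes, i.e. loss = (1 − ε − ε/(V−1)) · nll + (ε/(V−1)) · Σ_k(−log p_k), with padded positions zeroed out before reduction. Below is a small standalone numeric check of the same arithmetic (re-implemented here rather than imported, so it runs without fairseq installed).

# Standalone numeric sketch (illustration only), mirroring label_smoothed_nll_loss.
import torch

def smoothed_nll(lprobs, target, epsilon):
    nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    smooth = -lprobs.sum(dim=-1)
    eps_i = epsilon / (lprobs.size(-1) - 1)          # mass given to each non-target class
    return ((1.0 - epsilon - eps_i) * nll + eps_i * smooth).sum(), nll.sum()

lprobs = torch.log_softmax(torch.tensor([[2.0, 0.5, -1.0],
                                         [0.1, 1.5, 0.3]]), dim=-1)
target = torch.tensor([0, 1])

loss0, nll0 = smoothed_nll(lprobs, target, epsilon=0.0)
loss1, _ = smoothed_nll(lprobs, target, epsilon=0.1)
print(round(float(loss0), 4), round(float(nll0), 4))  # equal: epsilon=0 is plain NLL
print(round(float(loss1), 4))                         # changes once smoothing is applied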
artst/criterions/text_pretrain_criterion.py ADDED
@@ -0,0 +1,142 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ from dataclasses import dataclass, field
10
+ from typing import List, Optional
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from fairseq import metrics, utils
15
+ from fairseq.criterions import FairseqCriterion, register_criterion
16
+ from fairseq.dataclass import FairseqDataclass
17
+ from omegaconf import II
18
+
19
+
20
+ @dataclass
21
+ class TextPretrainCriterionConfig(FairseqDataclass):
22
+ sentence_avg: bool = II("optimization.sentence_avg")
23
+ loss_weights: Optional[List[float]] = field(
24
+ default_factory=lambda: [0.1,],
25
+ metadata={"help": "weights for additional loss terms (not first one)"},
26
+ )
27
+ bart_weight: float = field(
28
+ default=1.0,
29
+ metadata={"help": "loss weight for cross entropy"},
30
+ )
31
+
32
+
33
+ class TextPretrainCriterion(FairseqCriterion):
34
+ def __init__(self, task, sentence_avg, bart_weight, loss_weights=None):
35
+ super().__init__(task)
36
+ self.sentence_avg = sentence_avg
37
+ self.loss_weights = loss_weights
38
+ self.bart_weight = bart_weight
39
+
40
+ def forward(self, model, sample, reduce=True):
41
+ """Compute the loss for the given sample.
42
+
43
+ Returns a tuple with three elements:
44
+ 1) the loss
45
+ 2) the sample size, which is used as the denominator for the gradient
46
+ 3) logging outputs to display while training
47
+ """
48
+ net_output, codebook_out, encoder_output = model(**sample["net_input"])
49
+ bart_loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
50
+ sample_size = (
51
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
52
+ )
53
+
54
+ loss = self.bart_weight * bart_loss
55
+ logging_output = {
56
+ "loss": loss.item(),
57
+ "ntokens": sample["ntokens"],
58
+ "nsentences": sample["target"].size(0),
59
+ "bart_loss": bart_loss.item(),
60
+ "sample_size": sample_size,
61
+ }
62
+
63
+ if "prob_perplexity" in codebook_out:
64
+ assert hasattr(model, "get_extra_losses")
65
+ extra_losses, names = model.get_extra_losses(codebook_out)
66
+ if torch.is_tensor(extra_losses):
67
+ extra_losses = [extra_losses]
68
+ names = [names]
69
+ if len(self.loss_weights) == 1 and len(extra_losses) != 1:
70
+ self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
71
+ if len(self.loss_weights) > len(extra_losses):
72
+ modified_loss_weight = self.loss_weights[len(extra_losses):]
73
+ else:
74
+ modified_loss_weight = self.loss_weights
75
+
76
+ # assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
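+ # each auxiliary loss (e.g. the codebook diversity penalty) is scaled by its weight and by sample_size so it is comparable to the summed cross-entropy loss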
77
+ for p, n, coef in zip(extra_losses, names, modified_loss_weight):
78
+ # print(n + str(coef))
79
+ if coef != 0 and p is not None:
80
+ p = coef * p.float() * sample_size
81
+ loss += p
82
+ logging_output[f"loss_{n}"] = p.item()
83
+
84
+ if 'loss_prob_perplexity' in logging_output:
85
+ logging_output['code_perplexity'] = codebook_out['code_perplexity'].item()
86
+
87
+ return loss, sample_size, logging_output
88
+
89
+ def compute_loss(self, model, net_output, sample, reduce=True):
90
+ lprobs = model.get_normalized_probs(net_output, log_probs=True)
91
+ lprobs = lprobs.view(-1, lprobs.size(-1))
92
+ target = model.get_targets(sample, net_output).view(-1)
93
+ loss = F.nll_loss(
94
+ lprobs,
95
+ target,
96
+ ignore_index=self.padding_idx,
97
+ reduction="sum" if reduce else "none",
98
+ )
99
+ return loss, loss
100
+
101
+ @staticmethod
102
+ def reduce_metrics(logging_outputs) -> None:
103
+ """Aggregate logging outputs from data parallel training."""
104
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
105
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
106
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
107
+ bart_loss_sum = sum(log.get("bart_loss", 0) for log in logging_outputs)
108
+
109
+ # we divide by log(2) to convert the loss from base e to base 2
110
+ metrics.log_scalar(
111
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
112
+ )
113
+ metrics.log_scalar(
114
+ "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
115
+ )
116
+ if sample_size != ntokens:
117
+ metrics.log_scalar(
118
+ "nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3
119
+ )
120
+ metrics.log_derived(
121
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
122
+ )
123
+ else:
124
+ metrics.log_derived(
125
+ "ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg)
126
+ )
127
+
128
+ if "loss_prob_perplexity" in logging_outputs[0].keys():
129
+ val = sum(log["loss_prob_perplexity"] for log in logging_outputs)
130
+ metrics.log_scalar("loss_prob_perplexity", val / sample_size / math.log(2), round=3)
131
+ if "code_perplexity" in logging_outputs[0].keys():
132
+ val = sum(log["code_perplexity"] for log in logging_outputs)
133
+ metrics.log_scalar("code_perplexity", val / len(logging_outputs), round=3)
134
+
135
+ @staticmethod
136
+ def logging_outputs_can_be_summed() -> bool:
137
+ """
138
+ Whether the logging outputs returned by `forward` can be summed
139
+ across workers prior to calling `reduce_metrics`. Setting this
140
+ to True will improve distributed training speed.
141
+ """
142
+ return True
artst/criterions/text_to_speech_loss.py ADDED
@@ -0,0 +1,425 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ from dataclasses import dataclass, field
9
+
10
+ import torch
11
+ from fairseq import metrics, utils
12
+ from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
13
+ from fairseq.criterions import FairseqCriterion, register_criterion
14
+ from fairseq.dataclass import FairseqDataclass
15
+ from artst.models.modules.speech_encoder_prenet import SpeechEncoderPrenet
16
+ from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss
17
+ from omegaconf import II
18
+ from typing import Any
19
+
20
+
21
+ @dataclass
22
+ class TexttoSpeechLossConfig(FairseqDataclass):
23
+ use_masking: bool = field(
24
+ default=True,
25
+ metadata={"help": "Whether to use masking in calculation of loss"},
26
+ )
27
+ use_weighted_masking: bool = field(
28
+ default=False,
29
+ metadata={"help": "Whether to use weighted masking in calculation of loss"},
30
+ )
31
+ loss_type: str = field(
32
+ default="L1",
33
+ metadata={"help": "How to calc loss"},
34
+ )
35
+ bce_pos_weight: float = field(
36
+ default=5.0,
37
+ metadata={"help": "Positive sample weight in BCE calculation (only for use-masking=True)"},
38
+ )
39
+ bce_loss_lambda: float = field(
40
+ default=1.0,
41
+ metadata={"help": "Lambda in bce loss"},
42
+ )
43
+ use_guided_attn_loss: bool = field(
44
+ default=False,
45
+ metadata={"help": "Whether to use guided attention loss"},
46
+ )
47
+ guided_attn_loss_sigma: float = field(
48
+ default=0.4,
49
+ metadata={"help": "Sigma in guided attention loss"},
50
+ )
51
+ guided_attn_loss_lambda: float = field(
52
+ default=10.0,
53
+ metadata={"help": "Lambda in guided attention loss"},
54
+ )
55
+ num_layers_applied_guided_attn: int = field(
56
+ default=2,
57
+ metadata={"help": "Number of layers to be applied guided attention loss, if set -1, all of the layers will be applied."},
58
+ )
59
+ num_heads_applied_guided_attn: int = field(
60
+ default=2,
61
+ metadata={"help": "Number of heads in each layer to be applied guided attention loss, if set -1, all of the heads will be applied."},
62
+ )
63
+ modules_applied_guided_attn: Any = field(
64
+ default=("encoder-decoder",),
65
+ metadata={"help": "Module name list to be applied guided attention loss"},
66
+ )
67
+ sentence_avg: bool = II("optimization.sentence_avg")
68
+
69
+
70
+ class TexttoSpeechLoss(FairseqCriterion):
71
+ def __init__(
72
+ self,
73
+ task,
74
+ sentence_avg,
75
+ use_masking=True,
76
+ use_weighted_masking=False,
77
+ loss_type="L1",
78
+ bce_pos_weight=5.0,
79
+ bce_loss_lambda=1.0,
80
+ use_guided_attn_loss=False,
81
+ guided_attn_loss_sigma=0.4,
82
+ guided_attn_loss_lambda=1.0,
83
+ num_layers_applied_guided_attn=2,
84
+ num_heads_applied_guided_attn=2,
85
+ modules_applied_guided_attn=["encoder-decoder"],
86
+ ):
87
+ super().__init__(task)
88
+ self.sentence_avg = sentence_avg
89
+ self.use_masking = use_masking
90
+ self.use_weighted_masking = use_weighted_masking
91
+ self.loss_type = loss_type
92
+ self.bce_pos_weight = bce_pos_weight
93
+ self.bce_loss_lambda = bce_loss_lambda
94
+ self.use_guided_attn_loss = use_guided_attn_loss
95
+ self.guided_attn_loss_sigma = guided_attn_loss_sigma
96
+ self.guided_attn_loss_lambda = guided_attn_loss_lambda
97
+ # define loss function
98
+ self.criterion = Tacotron2Loss(
99
+ use_masking=use_masking,
100
+ use_weighted_masking=use_weighted_masking,
101
+ bce_pos_weight=bce_pos_weight,
102
+ )
103
+ if self.use_guided_attn_loss:
104
+ self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
105
+ self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
106
+ self.modules_applied_guided_attn = modules_applied_guided_attn
107
+ if self.use_guided_attn_loss:
108
+ self.attn_criterion = GuidedMultiHeadAttentionLoss(
109
+ sigma=guided_attn_loss_sigma,
110
+ alpha=guided_attn_loss_lambda,
111
+ )
112
+
113
+ def forward(self, model, sample):
114
+ """Compute the loss for the given sample.
115
+
116
+ Returns a tuple with three elements:
117
+ 1) the loss
118
+ 2) the sample size, which is used as the denominator for the gradient
119
+ 3) logging outputs to display while training
120
+ """
121
+ net_output = model(**sample["net_input"])
122
+ loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.compute_loss(model, net_output, sample)
123
+ # sample_size = (
124
+ # sample["target"].size(0) if self.sentence_avg else sample["nframes"]
125
+ # )
126
+ sample_size = 1
127
+ logging_output = {
128
+ "loss": loss.item(),
129
+ "l1_loss": l1_loss.item(),
130
+ "l2_loss": l2_loss.item(),
131
+ "bce_loss": bce_loss.item(),
132
+ "sample_size": 1,
133
+ "ntokens": sample["ntokens"],
134
+ "nsentences": sample["target"].size(0),
135
+ }
136
+
137
+ if enc_dec_attn_loss is not None:
138
+ logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item()
139
+
140
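+ # alpha is the learnable scale of the scaled positional encoding in the text/speech pre-nets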
+ if hasattr(model, 'text_encoder_prenet'):
141
+ logging_output["encoder_alpha"] = model.text_encoder_prenet.encoder_prenet[-1].alpha.item()
142
+ logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
143
+ elif hasattr(model, "speech_encoder_prenet"):
144
+ logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
145
+ else:
146
+ if 'task' not in sample:
147
+ logging_output["encoder_alpha"] = model.encoder_prenet.encoder_prenet[-1].alpha.item()
148
+ logging_output["decoder_alpha"] = model.decoder_prenet.decoder_prenet[-1].alpha.item()
149
+
150
+ return loss, sample_size, logging_output
151
+
152
+ def compute_loss(self, model, net_output, sample):
153
+ before_outs, after_outs, logits, attn = net_output
154
+ labels = sample["labels"]
155
+ ys = sample["dec_target"]
156
+ olens = sample["dec_target_lengths"]
157
+ ilens = sample["src_lengths"]
158
+
159
+ # trim the ground truth so its length is a multiple of the reduction factor and adjust the lengths accordingly
160
+ if model.reduction_factor > 1:
161
+ olens_in = olens.new([torch.div(olen, model.reduction_factor, rounding_mode='floor') for olen in olens])
162
+ olens = olens.new([olen - olen % model.reduction_factor for olen in olens])
163
+ max_olen = max(olens)
164
+ ys = ys[:, :max_olen]
165
+ labels = labels[:, :max_olen]
166
+ labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0) # make sure at least one frame has 1
167
+ # labels[:, -1] = 1.0
168
+ else:
169
+ olens_in = olens
170
+
171
+ # calculate loss values
172
+ l1_loss, l2_loss, bce_loss = self.criterion(
173
+ after_outs, before_outs, logits, ys, labels, olens
174
+ )
175
+
176
+ # l1_loss = l1_loss / ys.size(2)
177
+ # l2_loss = l2_loss / ys.size(2)
178
+
179
+ if self.loss_type == "L1":
180
+ loss = l1_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss
181
+ elif self.loss_type == "L2":
182
+ loss = l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l2_loss
183
+ elif self.loss_type == "L1+L2":
184
+ loss = l1_loss + l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss + l2_loss
185
+ else:
186
+ raise ValueError("unknown --loss-type " + self.loss_type)
187
+
188
+ # calculate guided attention loss
189
+ enc_dec_attn_loss = None
190
+ if self.use_guided_attn_loss:
191
+ # calculate the encoder input lengths, which are determined by the encoder prenet
192
+ if hasattr(model, 'encoder_reduction_factor') and model.encoder_reduction_factor > 1:
193
+ ilens_in = ilens.new([ilen // model.encoder_reduction_factor for ilen in ilens])
194
+ else:
195
+ ilens_in = ilens
196
+ # work for speech to speech model's input
197
+ if "task_name" in sample and sample["task_name"] == "s2s":
198
+ m = None
199
+ if hasattr(model, 'encoder_prenet'):
200
+ m = model.encoder_prenet
201
+ elif hasattr(model, 'speech_encoder_prenet'):
202
+ m = model.speech_encoder_prenet
203
+ if m is not None and isinstance(m, SpeechEncoderPrenet):
204
+ ilens_in = m.get_src_lengths(ilens_in)
205
+ # calculate for encoder-decoder
206
+ if "encoder-decoder" in self.modules_applied_guided_attn:
207
+ attn = [att_l[:, : self.num_heads_applied_guided_attn] for att_l in attn]
208
+ att_ws = torch.cat(attn, dim=1) # (B, H*L, T_out, T_in)
209
+ enc_dec_attn_loss = self.attn_criterion(att_ws, ilens_in, olens_in)
210
+ loss = loss + enc_dec_attn_loss
211
+
212
+ return loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss
213
+
214
+ @classmethod
215
+ def reduce_metrics(cls, logging_outputs) -> None:
216
+ """Aggregate logging outputs from data parallel training."""
217
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
218
+ l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs)
219
+ l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs)
220
+ bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs)
221
+ sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
222
+ metrics.log_scalar(
223
+ "loss", loss_sum / sample_size, sample_size, 1, round=5
224
+ )
225
+ encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in logging_outputs)
226
+ decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in logging_outputs)
227
+ ngpu = sum(log.get("ngpu", 0) for log in logging_outputs)
228
+
229
+ metrics.log_scalar(
230
+ "l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
231
+ )
232
+ metrics.log_scalar(
233
+ "l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
234
+ )
235
+ metrics.log_scalar(
236
+ "bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
237
+ )
238
+ metrics.log_scalar(
239
+ "encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
240
+ )
241
+ metrics.log_scalar(
242
+ "decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
243
+ )
244
+
245
+ if "enc_dec_attn_loss" in logging_outputs[0]:
246
+ enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs)
247
+ metrics.log_scalar(
248
+ "enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
249
+ )
250
+
251
+
252
+ @staticmethod
253
+ def logging_outputs_can_be_summed() -> bool:
254
+ """
255
+ Whether the logging outputs returned by `forward` can be summed
256
+ across workers prior to calling `reduce_metrics`. Setting this
257
+ to True will improve distributed training speed.
258
+ """
259
+ return True
260
+
261
+ class Tacotron2Loss(torch.nn.Module):
262
+ """Loss function module for Tacotron2."""
263
+
264
+ def __init__(
265
+ self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0
266
+ ):
267
+ """Initialize Tactoron2 loss module.
268
+
269
+ Args:
270
+ use_masking (bool): Whether to apply masking
271
+ for padded part in loss calculation.
272
+ use_weighted_masking (bool):
273
+ Whether to apply weighted masking in loss calculation.
274
+ bce_pos_weight (float): Weight of positive sample of stop token.
275
+
276
+ """
277
+ super(Tacotron2Loss, self).__init__()
278
+ assert (use_masking != use_weighted_masking) or not use_masking
279
+ self.use_masking = use_masking
280
+ self.use_weighted_masking = use_weighted_masking
281
+
282
+ # define criterions
283
+ # reduction = "none" if self.use_weighted_masking else "sum"
284
+ reduction = "none" if self.use_weighted_masking else "mean"
285
+ self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
286
+ self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
287
+ self.bce_criterion = torch.nn.BCEWithLogitsLoss(
288
+ reduction=reduction, pos_weight=torch.tensor(bce_pos_weight)
289
+ )
290
+
291
+ # NOTE(kan-bayashi): register pre hook function for the compatibility
292
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
293
+
294
+ def forward(self, after_outs, before_outs, logits, ys, labels, olens):
295
+ """Calculate forward propagation.
296
+
297
+ Args:
298
+ after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
299
+ before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
300
+ logits (Tensor): Batch of stop logits (B, Lmax).
301
+ ys (Tensor): Batch of padded target features (B, Lmax, odim).
302
+ labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
303
+ olens (LongTensor): Batch of the lengths of each target (B,).
304
+
305
+ Returns:
306
+ Tensor: L1 loss value.
307
+ Tensor: Mean square error loss value.
308
+ Tensor: Binary cross entropy loss value.
309
+
310
+ """
311
+ # make mask and apply it
312
+ if self.use_masking:
313
+ masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
314
+ ys = ys.masked_select(masks)
315
+ after_outs = after_outs.masked_select(masks)
316
+ before_outs = before_outs.masked_select(masks)
317
+ labels = labels.masked_select(masks[:, :, 0])
318
+ logits = logits.masked_select(masks[:, :, 0])
319
+
320
+ # calculate loss
321
+ l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys)
322
+ mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
323
+ before_outs, ys
324
+ )
325
+ bce_loss = self.bce_criterion(logits, labels)
326
+
327
+ # make weighted mask and apply it
328
+ if self.use_weighted_masking:
329
+ masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
330
+ weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
331
+ out_weights = weights.div(ys.size(0) * ys.size(2))
332
+ logit_weights = weights.div(ys.size(0))
333
+
334
+ # apply weight
335
+ l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum()
336
+ mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum()
337
+ bce_loss = (
338
+ bce_loss.mul(logit_weights.squeeze(-1))
339
+ .masked_select(masks.squeeze(-1))
340
+ .sum()
341
+ )
342
+
343
+ return l1_loss, mse_loss, bce_loss
344
+
345
+ def _load_state_dict_pre_hook(
346
+ self,
347
+ state_dict,
348
+ prefix,
349
+ local_metadata,
350
+ strict,
351
+ missing_keys,
352
+ unexpected_keys,
353
+ error_msgs,
354
+ ):
355
+ """Apply pre hook fucntion before loading state dict.
356
+
357
+ From v.0.6.1 `bce_criterion.pos_weight` param is registered as a parameter but
358
+ old models do not include it and, as a result, it causes a missing key error when
359
+ loading old model parameters. This function solves the issue by adding the param to the
360
+ state dict before loading as a pre hook function
361
+ of the `load_state_dict` method.
362
+
363
+ """
364
+ key = prefix + "bce_criterion.pos_weight"
365
+ if key not in state_dict:
366
+ state_dict[key] = self.bce_criterion.pos_weight
367
+
368
+ class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
369
+ """Guided attention loss function module for multi head attention.
370
+ Args:
371
+ sigma (float, optional): Standard deviation to control
372
+ how close attention to a diagonal.
373
+ alpha (float, optional): Scaling coefficient (lambda).
374
+ reset_always (bool, optional): Whether to always reset masks.
375
+ """
376
+
377
+ def forward(self, att_ws, ilens, olens):
378
+ """Calculate forward propagation.
379
+ Args:
380
+ att_ws (Tensor):
381
+ Batch of multi head attention weights (B, H, T_max_out, T_max_in).
382
+ ilens (LongTensor): Batch of input lengths (B,).
383
+ olens (LongTensor): Batch of output lengths (B,).
384
+ Returns:
385
+ Tensor: Guided attention loss value.
386
+ """
387
+ if self.guided_attn_masks is None:
388
+ self.guided_attn_masks = (
389
+ self._make_guided_attention_masks(ilens, olens)
390
+ .to(att_ws.device)
391
+ .unsqueeze(1)
392
+ )
393
+ if self.masks is None:
394
+ self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
395
+ losses = self.guided_attn_masks * att_ws
396
+ loss = torch.mean(losses.masked_select(self.masks))
397
+ if self.reset_always:
398
+ self._reset_masks()
399
+
400
+ return self.alpha * loss
401
+
402
+ def _make_guided_attention_masks(self, ilens, olens):
403
+ n_batches = len(ilens)
404
+ max_ilen = max(ilens)
405
+ max_olen = max(olens)
406
+ guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=olens.device)
407
+ for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
408
+ guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(
409
+ ilen, olen, self.sigma
410
+ )
411
+ return guided_attn_masks
412
+
413
+ @staticmethod
414
+ def _make_guided_attention_mask(ilen, olen, sigma):
415
+ grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=olen.device))
416
+ grid_x, grid_y = grid_x.float(), grid_y.float()
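+ # W[x, y] = 1 - exp(-((y / ilen - x / olen)^2) / (2 * sigma^2)): close to 0 near the diagonal, close to 1 for off-diagonal attention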
417
+ return 1.0 - torch.exp(
418
+ -((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))
419
+ )
420
+
421
+ @staticmethod
422
+ def _make_masks(ilens, olens):
423
+ in_masks = make_non_pad_mask(ilens).to(ilens.device) # (B, T_in)
424
+ out_masks = make_non_pad_mask(olens).to(olens.device) # (B, T_out)
425
+ return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)
artst/data/__init__.py ADDED
File without changes
artst/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (140 Bytes). View file
 
artst/data/__pycache__/multitask_dataset.cpython-38.pyc ADDED
Binary file (8.91 kB). View file
 
artst/data/__pycache__/speech_dataset.cpython-38.pyc ADDED
Binary file (16.7 kB). View file
 
artst/data/__pycache__/speech_to_class_dataset.cpython-38.pyc ADDED
Binary file (8.35 kB). View file
 
artst/data/__pycache__/speech_to_speech_dataset.cpython-38.pyc ADDED
Binary file (9.76 kB). View file
 
artst/data/__pycache__/speech_to_text_dataset.cpython-38.pyc ADDED
Binary file (9.7 kB). View file
 
artst/data/__pycache__/text_dataset.cpython-38.pyc ADDED
Binary file (11.6 kB). View file
 
artst/data/__pycache__/text_to_speech_dataset.cpython-38.pyc ADDED
Binary file (12.1 kB). View file
 
artst/data/multitask_dataset.py ADDED
@@ -0,0 +1,263 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import bisect
9
+
10
+ import logging
11
+ import numpy as np
12
+ from torch.utils.data.dataloader import default_collate
13
+ from fairseq.data import data_utils
14
+
15
+ from fairseq.data.fairseq_dataset import FairseqDataset
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class MultitaskDataset(FairseqDataset):
20
+ @staticmethod
21
+ def cumsum(sequence):
22
+ r, s = [], 0
23
+ for e in sequence:
24
+ curr_len = len(e)
25
+ r.append(curr_len + s)
26
+ s += curr_len
27
+ return r
28
+
29
+ def __init__(self, datasets, sample_ratios=1, batch_ratio=None):
30
+ super(MultitaskDataset, self).__init__()
31
+ assert len(datasets) > 0, "datasets should not be an empty iterable"
32
+ self.datasets = list(datasets)
33
+ if isinstance(sample_ratios, int):
34
+ sample_ratios = [sample_ratios] * len(self.datasets)
35
+ if batch_ratio is not None:
36
+ logger.info('batch ratio is ' + str(batch_ratio))
37
+ self.batch_ratio = batch_ratio
38
+ else:
39
+ self.batch_ratio = None
40
+ else:
41
+ logger.info('set sample ratio to ' + str(sample_ratios))
42
+ if batch_ratio is not None:
43
+ logger.info('batch ratio is ' + str(batch_ratio))
44
+ self.batch_ratio = batch_ratio
45
+ else:
46
+ self.batch_ratio = None
47
+ self.sample_ratios = sample_ratios
48
+ self._ordered_indices = None
49
+ self._update_size()
50
+
51
+ def __len__(self):
52
+ return self.cumulative_sizes[-1]
53
+
54
+ def __getitem__(self, idx):
55
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
56
+ sample = self.datasets[dataset_idx][sample_idx]
57
+ if isinstance(sample, dict):
58
+ sample["dataset_idx"] = dataset_idx
59
+ else:
60
+ sample = sample + (dataset_idx,)
61
+ return sample
62
+
63
+ def _update_size(self):
64
+ self.cumulative_sizes = self.cumsum(self.datasets)
65
+ self.real_sizes = [len(d) for d in self.datasets]
66
+
67
+ def _get_dataset_and_sample_index(self, idx: int):
68
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
69
+ if dataset_idx == 0:
70
+ sample_idx = idx
71
+ else:
72
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
73
+ sample_idx = sample_idx % self.real_sizes[dataset_idx]
74
+ return dataset_idx, sample_idx
75
+
76
+ def collater(self, samples, **extra_args):
77
+ # For now this only supports datasets that share the same underlying collater implementation
78
+ if samples is not None and len(samples) > 0:
79
+ if isinstance(samples[0], dict):
80
+ dataset_idx = samples[0]["dataset_idx"]
81
+ else:
82
+ dataset_idx = samples[0][-1]
83
+ samples = [sample[:-1] for sample in samples]
84
+ else:
85
+ dataset_idx = 0
86
+
87
+ if hasattr(self.datasets[dataset_idx], "collater"):
88
+ return self.datasets[dataset_idx].collater(samples, **extra_args)
89
+ else:
90
+ return default_collate(samples, **extra_args)
91
+
92
+ def size(self, idx: int):
93
+ """
94
+ Return an example's size as a float or tuple.
95
+ """
96
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
97
+ return self.datasets[dataset_idx].size(sample_idx)
98
+
99
+ def num_tokens(self, index: int):
100
+ return np.max(self.size(index))
101
+
102
+ def attr(self, attr: str, index: int):
103
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
104
+ return getattr(self.datasets[dataset_idx], attr, None)
105
+
106
+ @property
107
+ def sizes(self):
108
+ _dataset_sizes = []
109
+ for ds in self.datasets:
110
+ if isinstance(ds.sizes, np.ndarray):
111
+ _dataset_sizes.append(ds.sizes)
112
+ else:
113
+ # Only support underlying dataset with single size array.
114
+ assert isinstance(ds.sizes, list)
115
+ _dataset_sizes.append(ds.sizes[0])
116
+ return np.concatenate(_dataset_sizes)
117
+
118
+ @property
119
+ def supports_prefetch(self):
120
+ return all(d.supports_prefetch for d in self.datasets)
121
+
122
+ def ordered_indices(self):
123
+ # ordered_indices = []
124
+ # for i, dataset in enumerate(self.datasets):
125
+ # indice = dataset.ordered_indices()
126
+ # ordered_indices.append(indice)
127
+ if self._ordered_indices is None:
128
+ # Call the underlying dataset's ordered_indices() here, so that we
129
+ # get the same random ordering as we would have from using the
130
+ # underlying sub-datasets directly.
131
+ self._ordered_indices = [
132
+ dataset.ordered_indices()
133
+ for dataset in self.datasets
134
+ ]
135
+ return np.arange(len(self))
136
+
137
+ def prefetch(self, indices):
138
+ frm = 0
139
+ for to, ds in zip(self.cumulative_sizes, self.datasets):
140
+ real_size = len(ds)
141
+ if getattr(ds, "supports_prefetch", False):
142
+ ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
143
+ frm = to
144
+
145
+ def batch_by_size(
146
+ self,
147
+ indices,
148
+ max_tokens=None,
149
+ max_sentences=None,
150
+ required_batch_size_multiple=1,
151
+ ):
152
+ if not hasattr(self, "max_tokens"):
153
+ self.max_tokens = max_tokens
154
+ if not hasattr(self, "max_sentences"):
155
+ self.max_sentences = max_sentences
156
+ if not hasattr(self, "required_batch_size_multiple"):
157
+ self.required_batch_size_multiple = required_batch_size_multiple
158
+ batch_samplers = []
159
+ for i, dataset in enumerate(self.datasets):
160
+ batch_sampler = dataset.batch_by_size(
161
+ self._ordered_indices[i],
162
+ max_tokens=max_tokens if self.batch_ratio is None else max_tokens * self.batch_ratio[i],
163
+ max_sentences=max_sentences,
164
+ required_batch_size_multiple=required_batch_size_multiple,
165
+ )
166
+ if i > 0:
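+ # shift this dataset's local batch indices into the global index space by adding the cumulative size of the preceding datasets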
167
+ for batch in batch_sampler:
168
+ batch += self.cumulative_sizes[i - 1]
169
+ if self.sample_ratios[i] != 1.0:
170
+ batch_sampler = np.array(batch_sampler)
171
+ batch_sampler = np.random.choice(batch_sampler, int(len(batch_sampler) * self.sample_ratios[i]))
172
+ batch_sampler = list(batch_sampler)
173
+ logger.info('Adjust batch by ratio ' + str(self.sample_ratios[i]) + ' and the number of batch is ' + str(int(len(batch_sampler))) + ' for dataset ' + str(i))
174
+ batch_samplers.extend(batch_sampler)
175
+ return batch_samplers
176
+
177
+ def filter_indices_by_size(self, indices, max_positions):
178
+ """
179
+ Filter each sub-dataset independently, then update the round robin to work
180
+ on the filtered sub-datasets.
181
+ """
182
+ if not hasattr(self, "max_positions"):
183
+ self.max_positions = max_positions
184
+ ignored_some = False
185
+ for i in range(len(self.datasets)):
186
+ # ignored = []
187
+ self._ordered_indices[i], ignored = self.datasets[i].filter_indices_by_size(
188
+ self._ordered_indices[i], self.max_positions[i]
189
+ )
190
+ if len(ignored) > 0:
191
+ ignored_some = True
192
+ logger.warning(
193
+ f"{len(ignored)} samples from {i} have invalid sizes and will be skipped, "
194
+ f"max_positions={self.max_positions[i]}, first few sample ids={ignored[:10]}"
195
+ )
196
+
197
+ logger.info('update dataset size')
198
+ self._update_size()
199
+
200
+ # Since we are modifying in place the _ordered_indices,
201
+ # it's not possible anymore to return valid ignored indices.
202
+ # Hopefully the extra debug information print above should be enough to debug.
203
+ # Ideally we would receive ignore_invalid_inputs so that we could have
204
+ # a proper error message.
205
+ return (np.arange(len(self)), [0] if ignored_some else [])
206
+
207
+ @property
208
+ def can_reuse_epoch_itr_across_epochs(self):
209
+ return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
210
+
211
+ def set_epoch(self, epoch):
212
+ super().set_epoch(epoch)
213
+ for ds in self.datasets:
214
+ if hasattr(ds, "set_epoch"):
215
+ ds.set_epoch(epoch)
216
+
217
+ def shuffle_batches(self, batches, seed):
218
+ logger.info("shuffle batches")
219
+ new_batches_fromlist = []
220
+ new_batches_notlist = []
221
+ new_batches = []
222
+ with data_utils.numpy_seed(seed):
223
+ np.random.shuffle(batches)
224
+ for batch in batches:
225
+ if isinstance(batch, list):
226
+ # np.random.shuffle(batch)
227
+ new_batches_fromlist.append(batch)
228
+ else:
229
+ new_batches_notlist.append(batch)
230
+ logger.info("Get " + str(len(new_batches_fromlist)) + " chunk from speech sides")
231
+ logger.info("Get " + str(sum([len(batch_list) for batch_list in new_batches_fromlist])) + " batches from speech sides")
232
+ logger.info("Get " + str(len(new_batches_notlist)) + " batches from text sides")
233
+ if len(new_batches_fromlist) == 0:
234
+ return new_batches_notlist
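+ # interleave roughly st_ratio text-side batches after each speech-side chunk so both sides are mixed within an epoch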
235
+ st_ratio = int(len(new_batches_notlist) / len(new_batches_fromlist))
236
+ logger.info("Get st_ratio " + str(st_ratio))
237
+ last_idx = 0
238
+ for i in range(len(new_batches_fromlist)):
239
+ if i == len(new_batches_fromlist) - 1:
240
+ new_batches_fromlist[i].extend(new_batches_notlist[last_idx:])
241
+ else:
242
+ new_batches_fromlist[i].extend(new_batches_notlist[last_idx : last_idx + st_ratio])
243
+ np.random.shuffle(new_batches_fromlist[i])
244
+ new_batches.extend(new_batches_fromlist[i])
245
+ last_idx = last_idx + st_ratio
246
+ logger.info("Finish shuffle")
247
+ return new_batches
248
+
249
+ def reset_batch_sampler(self):
250
+ logger.info("reset batch sampler")
251
+ self._ordered_indices = [
252
+ self.datasets[i].ordered_indices()
253
+ for i in range(len(self.datasets))
254
+ ]
255
+ self.filter_indices_by_size(None, None)
256
+
257
+ batch_samplers = self.batch_by_size(
258
+ None,
259
+ self.max_tokens,
260
+ self.max_sentences,
261
+ self.required_batch_size_multiple
262
+ )
263
+ return batch_samplers
artst/data/speech_dataset.py ADDED
@@ -0,0 +1,475 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+
5
+ # Based on speecht5, fairseq and espnet code bases
6
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
7
+ # --------------------------------------------------------
8
+
9
+ import itertools
10
+ import logging
11
+ import os
12
+ import sys
13
+ from typing import Any, List, Optional, Union
14
+
15
+ import numpy as np
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import librosa
20
+ from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
21
+ from fairseq.data import data_utils
22
+ from fairseq.data.fairseq_dataset import FairseqDataset
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ def _collate_frames(
27
+ frames: List[torch.Tensor], is_audio_input: bool = False
28
+ ):
29
+ """
30
+ Convert a list of 2D frames into a padded 3D tensor
31
+ Args:
32
+ frames (list): list of 2D frames of size L[i]*f_dim, where L[i] is the
33
+ length of the i-th frame and f_dim is the static feature dimension
34
+ Returns:
35
+ 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
36
+ """
37
+ max_len = max(frame.size(0) for frame in frames)
38
+ if is_audio_input:
39
+ out = frames[0].new_zeros((len(frames), max_len))
40
+ else:
41
+ out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
42
+ for i, v in enumerate(frames):
43
+ out[i, : v.size(0)] = v
44
+ return out
45
+
46
+ def add_first_frame_and_remove_last_frame(ys):
47
+ ys_in = torch.cat(
48
+ [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1
49
+ )
50
+ return ys_in
51
+
52
+ def load_audio(manifest_path, max_keep, min_keep):
53
+ n_long, n_short = 0, 0
54
+ names, inds, sizes, spk_embeds = [], [], [], []
55
+ with open(manifest_path) as f:
56
+ root = f.readline().strip()
57
+ for ind, line in enumerate(f):
58
+ items = line.strip().split("\t")
59
+ assert len(items) == 3, line
60
+ sz = int(items[1])
61
+ if min_keep is not None and sz < min_keep:
62
+ n_short += 1
63
+ elif max_keep is not None and sz > max_keep:
64
+ n_long += 1
65
+ else:
66
+ names.append(items[0])
67
+ spk_embeds.append(items[2])
68
+ inds.append(ind)
69
+ sizes.append(sz)
70
+ tot = ind + 1
71
+ logger.info(
72
+ (
73
+ f"max_keep={max_keep}, min_keep={min_keep}, "
74
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
75
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
76
+ )
77
+ )
78
+ return root, names, inds, tot, sizes, spk_embeds
79
+
80
+
81
+ def load_label(label_path, inds, tot):
82
+ with open(label_path) as f:
83
+ labels = [line.rstrip() for line in f]
84
+ assert (
85
+ len(labels) == tot
86
+ ), f"number of labels does not match ({len(labels)} != {tot})"
87
+ labels = [labels[i] for i in inds]
88
+ return labels
89
+
90
+
91
+ def load_label_offset(label_path, inds, tot):
92
+ with open(label_path) as f:
93
+ code_lengths = [len(line.encode("utf-8")) for line in f]
94
+ assert (
95
+ len(code_lengths) == tot
96
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
97
+ offsets = list(itertools.accumulate([0] + code_lengths))
98
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
99
+ return offsets
100
+
101
+
102
+ def verify_label_lengths(
103
+ audio_sizes,
104
+ audio_rate,
105
+ label_path,
106
+ label_rate,
107
+ inds,
108
+ tot,
109
+ tol=0.1, # tolerance in seconds
110
+ ):
111
+ if label_rate < 0:
112
+ logger.info(f"{label_path} is sequence label. skipped")
113
+ return
114
+
115
+ with open(label_path) as f:
116
+ lengths = [len(line.rstrip().split()) for line in f]
117
+ assert len(lengths) == tot
118
+ lengths = [lengths[i] for i in inds]
119
+ num_invalid = 0
120
+ for i, ind in enumerate(inds):
121
+ dur_from_audio = audio_sizes[i] / audio_rate
122
+ dur_from_label = lengths[i] / label_rate
123
+ if abs(dur_from_audio - dur_from_label) > tol:
124
+ logger.warning(
125
+ (
126
+ f"audio and label duration differ too much "
127
+ f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
128
+ f"in line {ind+1} of {label_path}. Check if `label_rate` "
129
+ f"is correctly set (currently {label_rate}). "
130
+ f"num. of samples = {audio_sizes[i]}; "
131
+ f"label length = {lengths[i]}"
132
+ )
133
+ )
134
+ num_invalid += 1
135
+ if num_invalid > 0:
136
+ logger.warning(
137
+ f"total {num_invalid} (audio, label) pairs with mismatched lengths"
138
+ )
139
+
140
+
141
+ def logmelfilterbank(
142
+ audio,
143
+ sampling_rate,
144
+ fft_size=1024,
145
+ hop_size=256,
146
+ win_length=None,
147
+ window="hann",
148
+ num_mels=80,
149
+ fmin=80,
150
+ fmax=7600,
151
+ eps=1e-10,
152
+ ):
153
+ """Compute log-Mel filterbank feature.
154
+ (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
155
+
156
+ Args:
157
+ audio (ndarray): Audio signal (T,).
158
+ sampling_rate (int): Sampling rate.
159
+ fft_size (int): FFT size.
160
+ hop_size (int): Hop size.
161
+ win_length (int): Window length. If set to None, it will be the same as fft_size.
162
+ window (str): Window function type.
163
+ num_mels (int): Number of mel basis.
164
+ fmin (int): Minimum frequency in mel basis calculation.
165
+ fmax (int): Maximum frequency in mel basis calculation.
166
+ eps (float): Epsilon value to avoid inf in log calculation.
167
+
168
+ Returns:
169
+ ndarray: Log Mel filterbank feature (#frames, num_mels).
170
+
171
+ """
172
+ # get amplitude spectrogram
173
+ x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
174
+ win_length=win_length, window=window, pad_mode="reflect")
175
+ spc = np.abs(x_stft).T # (#frames, #bins)
176
+
177
+ # get mel basis
178
+ fmin = 0 if fmin is None else fmin
179
+ fmax = sampling_rate / 2 if fmax is None else fmax
180
+ mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
181
+
182
+ return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
183
+
184
+
185
+ class SpeechPretrainDataset(FairseqDataset):
186
+ def __init__(
187
+ self,
188
+ manifest_path: str,
189
+ sample_rate: float,
190
+ label_paths: List[str],
191
+ label_rates: Union[List[float], float], # -1 for sequence labels
192
+ pad_list: List[str],
193
+ eos_list: List[str],
194
+ label_processors: Optional[List[Any]] = None,
195
+ max_keep_sample_size: Optional[int] = None,
196
+ min_keep_sample_size: Optional[int] = None,
197
+ max_sample_size: Optional[int] = None,
198
+ shuffle: bool = True,
199
+ pad_audio: bool = False,
200
+ normalize: bool = False,
201
+ store_labels: bool = True,
202
+ random_crop: bool = False,
203
+ single_target: bool = False,
204
+ reduction_factor: int = 1,
205
+ ):
206
+ self.audio_root, self.audio_names, inds, tot, self.sizes, self.spk_embeds = load_audio(
207
+ manifest_path, max_keep_sample_size, min_keep_sample_size
208
+ )
209
+ self.sample_rate = sample_rate
210
+ self.shuffle = shuffle
211
+ self.random_crop = random_crop
212
+
213
+ self.num_labels = len(label_paths)
214
+ self.pad_list = pad_list
215
+ self.eos_list = eos_list
216
+ self.label_processors = label_processors
217
+ self.single_target = single_target
218
+ self.label_rates = (
219
+ [label_rates for _ in range(len(label_paths))]
220
+ if isinstance(label_rates, float)
221
+ else label_rates
222
+ )
223
+ self.store_labels = store_labels
224
+ if store_labels:
225
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
226
+ else:
227
+ self.label_paths = label_paths
228
+ self.label_offsets_list = [
229
+ load_label_offset(p, inds, tot) for p in label_paths
230
+ ]
231
+ assert label_processors is None or len(label_processors) == self.num_labels
232
+ for label_path, label_rate in zip(label_paths, self.label_rates):
233
+ verify_label_lengths(
234
+ self.sizes, sample_rate, label_path, label_rate, inds, tot
235
+ )
236
+
237
+ self.max_sample_size = (
238
+ max_sample_size if max_sample_size is not None else sys.maxsize
239
+ )
240
+ self.pad_audio = pad_audio
241
+ self.normalize = normalize
242
+ self.reduction_factor = reduction_factor
243
+ logger.info(
244
+ f"pad_audio={pad_audio}, random_crop={random_crop}, reduction_factor={reduction_factor}, "
245
+ f"normalize={normalize}, max_sample_size={self.max_sample_size}"
246
+ )
247
+
248
+ def get_audio(self, index):
249
+ import soundfile as sf
250
+
251
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
252
+ wav, cur_sample_rate = sf.read(wav_path)
253
+ wav = torch.from_numpy(wav).float()
254
+ fbank = logmelfilterbank(
255
+ wav.view(-1).cpu().numpy(), 16000
256
+ )
257
+ fbank = torch.from_numpy(fbank).float()
258
+ wav = self.postprocess(wav, cur_sample_rate)
259
+ return wav, fbank
260
+
261
+ def get_label(self, index, label_idx):
262
+ if self.store_labels:
263
+ label = self.label_list[label_idx][index]
264
+ else:
265
+ with open(self.label_paths[label_idx]) as f:
266
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
267
+ f.seek(offset_s)
268
+ label = f.read(offset_e - offset_s)
269
+
270
+ if self.label_processors is not None:
271
+ label = self.label_processors[label_idx](label)
272
+ return label
273
+
274
+ def get_labels(self, index):
275
+ return [self.get_label(index, i) for i in range(self.num_labels)]
276
+
277
+ def __getitem__(self, index):
278
+ wav, fbank = self.get_audio(index)
279
+ labels = self.get_labels(index)
280
+ spkembs = get_features_or_waveform(
281
+ os.path.join(self.audio_root, self.spk_embeds[index])
282
+ )
283
+ spkembs = torch.from_numpy(spkembs).float()
284
+ return {"id": index, "source": wav, "target": fbank, "label_list": labels, 'spkembs': spkembs}
285
+
286
+ def __len__(self):
287
+ return len(self.sizes)
288
+
289
+ def crop_to_max_size(self, wav, target_size):
290
+ size = len(wav)
291
+ diff = size - target_size
292
+ if diff <= 0:
293
+ return wav, 0
294
+
295
+ start, end = 0, target_size
296
+ if self.random_crop:
297
+ start = np.random.randint(0, diff + 1)
298
+ end = size - diff + start
299
+ return wav[start:end], start
300
+
301
+ def collater(self, samples):
302
+ # target = max(sizes) -> random_crop not used
303
+ # target = max_sample_size -> random_crop used for long
304
+ samples = [s for s in samples if s["source"] is not None]
305
+ if len(samples) == 0:
306
+ return {}
307
+
308
+ audios = [s["source"] for s in samples]
309
+ audio_sizes = [len(s) for s in audios]
310
+
311
+ fbanks = [s["target"] for s in samples]
312
+ fbank_sizes = [len(s) for s in fbanks]
313
+
314
+ if self.pad_audio:
315
+ audio_size = min(max(audio_sizes), self.max_sample_size)
316
+ else:
317
+ audio_size = min(min(audio_sizes), self.max_sample_size)
318
+ collated_audios, padding_mask, audio_starts = self.collater_audio(
319
+ audios, audio_size
320
+ )
321
+
322
+ collated_fbanks = []
323
+ collated_audios_size = []
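+ # map the waveform crop window onto fbank frames proportionally (using the audio-samples-per-fbank-frame ratio)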
324
+ for i in range(len(fbanks)):
325
+ fbank_start = int(audio_starts[i] / (audio_sizes[i] / fbank_sizes[i]))
326
+ fbank_size = int(audio_size / (audio_sizes[i] / fbank_sizes[i]))
327
+ fbank_end = min(fbank_start + fbank_size, fbank_sizes[i])
328
+ collated_fbanks.append(fbanks[i][fbank_start : fbank_end])
329
+ collated_audios_size.append(audio_size)
330
+ collated_fbanks_size = [len(s) for s in collated_fbanks]
331
+ collated_fbanks = _collate_frames(collated_fbanks)
332
+ collated_fbanks_size = torch.tensor(collated_fbanks_size, dtype=torch.long)
333
+
334
+ # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
335
+ if self.reduction_factor > 1:
336
+ collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
337
+ collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
338
+ else:
339
+ collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
340
+
341
+ prev_output_tokens = torch.cat(
342
+ [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
343
+ )
344
+
345
+ # make labels for stop prediction
346
+ labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
347
+ for i, l in enumerate(fbank_sizes):
348
+ labels[i, l - 1 :] = 1.0
349
+
350
+ spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
351
+
352
+ targets_by_label = [
353
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
354
+ ]
355
+ targets_list, lengths_list, ntokens_list = self.collater_label(
356
+ targets_by_label, audio_size, audio_starts
357
+ )
358
+
359
+ net_input = {
360
+ "source": collated_audios,
361
+ "padding_mask": padding_mask,
362
+ "prev_output_tokens": prev_output_tokens,
363
+ "spkembs": spkembs,
364
+ "tgt_lengths": collated_fbanks_size_in,
365
+ }
366
+
367
+ batch = {
368
+ "id": torch.LongTensor([s["id"] for s in samples]),
369
+ "net_input": net_input,
370
+ "labels": labels,
371
+ "dec_target": collated_fbanks,
372
+ "dec_target_lengths": collated_fbanks_size,
373
+ "src_lengths": collated_audios_size,
374
+ "task_name": 'speech_pretrain',
375
+ }
376
+
377
+ if self.single_target:
378
+ batch["target_lengths"] = lengths_list[0]
379
+ batch["ntokens"] = ntokens_list[0]
380
+ batch["target"] = targets_list[0]
381
+ else:
382
+ batch["target_lengths_list"] = lengths_list
383
+ batch["ntokens_list"] = ntokens_list
384
+ batch["target_list"] = targets_list
385
+ return batch
386
+
387
+ def collater_audio(self, audios, audio_size):
388
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
389
+ padding_mask = (
390
+ torch.BoolTensor(collated_audios.shape).fill_(False)
391
+ # if self.pad_audio else None
392
+ )
393
+ audio_starts = [0 for _ in audios]
394
+ for i, audio in enumerate(audios):
395
+ diff = len(audio) - audio_size
396
+ if diff == 0:
397
+ collated_audios[i] = audio
398
+ elif diff < 0:
399
+ assert self.pad_audio
400
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
401
+ padding_mask[i, diff:] = True
402
+ else:
403
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
404
+ audio, audio_size
405
+ )
406
+ return collated_audios, padding_mask, audio_starts
407
+
408
+ def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
409
+ assert label_rate > 0
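+ # s2f converts audio-sample indices into label-frame indices (label frames per audio sample)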
410
+ s2f = label_rate / self.sample_rate
411
+ frm_starts = [int(round(s * s2f)) for s in audio_starts]
412
+ frm_size = int(round(audio_size * s2f))
413
+ if not self.pad_audio:
414
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
415
+ frm_size = min(frm_size, *rem_size)
416
+ targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
417
+ logger.debug(f"audio_starts={audio_starts}")
418
+ logger.debug(f"frame_starts={frm_starts}")
419
+ logger.debug(f"frame_size={frm_size}")
420
+
421
+ lengths = torch.LongTensor([len(t) for t in targets])
422
+ ntokens = lengths.sum().item()
423
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
424
+ return targets, lengths, ntokens
425
+
426
+ def collater_seq_label(self, targets, pad):
427
+ lengths = torch.LongTensor([len(t) for t in targets])
428
+ ntokens = lengths.sum().item()
429
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
430
+ return targets, lengths, ntokens
431
+
432
+ def collater_label(self, targets_by_label, audio_size, audio_starts):
433
+ targets_list, lengths_list, ntokens_list = [], [], []
434
+ itr = zip(targets_by_label, self.label_rates, self.pad_list)
435
+ for targets, label_rate, pad in itr:
436
+ if label_rate == -1.0:
437
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
438
+ else:
439
+ targets, lengths, ntokens = self.collater_frm_label(
440
+ targets, audio_size, audio_starts, label_rate, pad
441
+ )
442
+ targets_list.append(targets)
443
+ lengths_list.append(lengths)
444
+ ntokens_list.append(ntokens)
445
+ return targets_list, lengths_list, ntokens_list
446
+
447
+ def num_tokens(self, index):
448
+ return self.size(index)
449
+
450
+ def size(self, index):
451
+ if self.pad_audio:
452
+ return self.sizes[index]
453
+ return min(self.sizes[index], self.max_sample_size)
454
+
455
+ def ordered_indices(self):
456
+ if self.shuffle:
457
+ order = [np.random.permutation(len(self))]
458
+ else:
459
+ order = [np.arange(len(self))]
460
+
461
+ order.append(self.sizes)
462
+ return np.lexsort(order)[::-1]
463
+
464
+ def postprocess(self, wav, cur_sample_rate):
465
+ if wav.dim() == 2:
466
+ wav = wav.mean(-1)
467
+ assert wav.dim() == 1, wav.dim()
468
+
469
+ if cur_sample_rate != self.sample_rate:
470
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
471
+
472
+ if self.normalize:
473
+ with torch.no_grad():
474
+ wav = F.layer_norm(wav, wav.shape)
475
+ return wav
artst/data/speech_to_class_dataset.py ADDED
@@ -0,0 +1,260 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ import os
10
+ from typing import Any, List, Optional
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq.data import data_utils, Dictionary
17
+ from fairseq.data.fairseq_dataset import FairseqDataset
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def load_audio(manifest_path, max_keep, min_keep):
23
+ """manifest tsv: wav_path, wav_nframe, wav_class
24
+
25
+ Args
26
+ manifest_path: str
27
+ max_keep: int
28
+ min_keep: int
29
+
30
+ Return
31
+ root, names, inds, tot, sizes, classes
32
+ """
33
+ n_long, n_short = 0, 0
34
+ names, inds, sizes, classes = [], [], [], []
35
+ with open(manifest_path) as f:
36
+ root = f.readline().strip()
37
+ for ind, line in enumerate(f):
38
+ items = line.strip().split("\t")
39
+ assert len(items) >= 2, line
40
+ sz = int(items[1])
41
+ if min_keep is not None and sz < min_keep:
42
+ n_short += 1
43
+ elif max_keep is not None and sz > max_keep:
44
+ n_long += 1
45
+ else:
46
+ names.append(items[0])
47
+ if len(items) > 2:
48
+ classes.append(items[2])
49
+ inds.append(ind)
50
+ sizes.append(sz)
51
+ tot = ind + 1
52
+ logger.info(
53
+ (
54
+ f"max_keep={max_keep}, min_keep={min_keep}, "
55
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
56
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
57
+ )
58
+ )
59
+ if len(classes) == 0:
60
+ logger.warn("no classes loaded only if inference")
61
+ return root, names, inds, tot, sizes, classes
62
+
63
+
64
+ def sample_from_feature(x: np.ndarray, max_segment_length: int = 300):
65
+ """Load a segment within 300-400/51200-76800 frames or the corresponding samples from a utterance.
66
+
67
+ Args:
68
+ x (np.ndarray): feature or waveform (frames[, features]), e.g., log mel filter bank or waveform
69
+ max_segment_length (int, optional): maximum segment length. Defaults to 300.
70
+
71
+ Returns:
72
+ np.ndarray: segmented features
73
+ """
74
+ if len(x) <= max_segment_length:
75
+ return x
76
+ start = np.random.randint(0, x.shape[0] - max_segment_length)
77
+ return x[start: start + max_segment_length]
78
+
79
+
80
+ class SpeechToClassDataset(FairseqDataset):
81
+ def __init__(
82
+ self,
83
+ manifest_path: str,
84
+ sample_rate: float,
85
+ label_processors: Optional[List[Any]] = None,
86
+ max_keep_sample_size: Optional[int] = None,
87
+ min_keep_sample_size: Optional[int] = None,
88
+ shuffle: bool = True,
89
+ normalize: bool = False,
90
+ tgt_dict: Optional[Dictionary] = None,
91
+ max_length: Optional[int] = None
92
+ ):
93
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.wav_classes = load_audio(
94
+ manifest_path, max_keep_sample_size, min_keep_sample_size
95
+ )
96
+ self.sample_rate = sample_rate
97
+ self.shuffle = shuffle
98
+
99
+ self.label_processors = label_processors
100
+
101
+ self.normalize = normalize
102
+ self.tgt_dict = tgt_dict
103
+ self.max_length = max_length
104
+ logger.info(
105
+ f"max_length={max_length}, normalize={normalize}"
106
+ )
107
+
108
+ def get_audio(self, index):
109
+ import soundfile as sf
110
+
111
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
112
+ wav, cur_sample_rate = sf.read(wav_path)
113
+ if self.max_length is not None:
114
+ wav = sample_from_feature(wav, self.max_length)
115
+ wav = torch.from_numpy(wav).float()
116
+ wav = self.postprocess(wav, cur_sample_rate)
117
+ return wav
118
+
119
+ def get_label(self, index):
120
+ label = self.wav_classes[index]
121
+
122
+ if self.label_processors is not None:
123
+ label = self.label_processors(label)
124
+ return label
125
+
126
+ def __getitem__(self, index):
127
+ wav = self.get_audio(index)
128
+ label = None
129
+ if len(self.wav_classes) == len(self.audio_names):
130
+ label = self.get_label(index)
131
+ return {"id": index, "source": wav, "label": label}
132
+
133
+ def __len__(self):
134
+ return len(self.wav_sizes)
135
+
136
+ def collater(self, samples):
137
+ samples = [s for s in samples if s["source"] is not None]
138
+ if len(samples) == 0:
139
+ return {}
140
+
141
+ audios = [s["source"] for s in samples]
142
+ audio_sizes = [len(s) for s in audios]
143
+
144
+ audio_size = max(audio_sizes)
145
+ collated_audios, padding_mask = self.collater_audio(
146
+ audios, audio_size
147
+ )
148
+
149
+ decoder_label = None
150
+ decoder_target = None
151
+ decoder_target_lengths = None
152
+ if samples[0]["label"] is not None:
153
+ targets_by_label = [
154
+ [s["label"] for s in samples]
155
+ ]
156
+ targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)
157
+
158
+ decoder_label = [
159
+ (targets_list[0][i, :lengths_list[0][i]]).long()
160
+ for i in range(targets_list[0].size(0))
161
+ ]
162
+
163
+ decoder_target = data_utils.collate_tokens(
164
+ decoder_label,
165
+ self.tgt_dict.pad(),
166
+ self.tgt_dict.eos(),
167
+ left_pad=False,
168
+ move_eos_to_beginning=False,
169
+ )
170
+ decoder_target_lengths = torch.tensor(
171
+ [x.size(0) for x in decoder_label], dtype=torch.long
172
+ )
173
+ prev_output_tokens = data_utils.collate_tokens(
174
+ [torch.LongTensor([-1]) for _ in samples],
175
+ self.tgt_dict.pad(),
176
+ self.tgt_dict.eos(),
177
+ left_pad=False,
178
+ move_eos_to_beginning=True,
179
+ )
180
+
181
+ net_input = {
182
+ "source": collated_audios,
183
+ "padding_mask": padding_mask,
184
+ "prev_output_tokens": prev_output_tokens,
185
+ "task_name": "s2c",
186
+ }
187
+ batch = {
188
+ "id": torch.LongTensor([s["id"] for s in samples]),
189
+ "net_input": net_input,
190
+ "target": decoder_target,
191
+ "target_lengths": decoder_target_lengths,
192
+ "task_name": "s2c",
193
+ "ntokens": len(samples),
194
+ }
195
+
196
+ return batch
197
+
198
+ def collater_audio(self, audios, audio_size):
199
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
200
+ padding_mask = (
201
+ torch.BoolTensor(collated_audios.shape).fill_(False)
202
+ )
203
+ for i, audio in enumerate(audios):
204
+ diff = len(audio) - audio_size
205
+ if diff == 0:
206
+ collated_audios[i] = audio
207
+ elif diff < 0:
208
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
209
+ padding_mask[i, diff:] = True
210
+ else:
211
+ raise Exception("Diff should not be larger than 0")
212
+ return collated_audios, padding_mask
213
+
214
+ def collater_seq_label(self, targets, pad):
215
+ lengths = torch.LongTensor([len(t) for t in targets])
216
+ ntokens = lengths.sum().item()
217
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
218
+ return targets, lengths, ntokens
219
+
220
+ def collater_label(self, targets_by_label):
221
+ targets_list, lengths_list, ntokens_list = [], [], []
222
+ itr = zip(targets_by_label, [self.tgt_dict.pad()])
223
+ for targets, pad in itr:
224
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
225
+ targets_list.append(targets)
226
+ lengths_list.append(lengths)
227
+ ntokens_list.append(ntokens)
228
+ return targets_list, lengths_list, ntokens_list
229
+
230
+ def num_tokens(self, index):
231
+ return self.size(index)
232
+
233
+ def size(self, index):
234
+ return self.wav_sizes[index]
235
+
236
+ @property
237
+ def sizes(self):
238
+ return np.array(self.wav_sizes)
239
+
240
+ def ordered_indices(self):
241
+ if self.shuffle:
242
+ order = [np.random.permutation(len(self))]
243
+ else:
244
+ order = [np.arange(len(self))]
245
+
246
+ order.append(self.wav_sizes)
247
+ return np.lexsort(order)[::-1]
248
+
249
+ def postprocess(self, wav, cur_sample_rate):
250
+ if wav.dim() == 2:
251
+ wav = wav.mean(-1)
252
+ assert wav.dim() == 1, wav.dim()
253
+
254
+ if cur_sample_rate != self.sample_rate:
255
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
256
+
257
+ if self.normalize:
258
+ with torch.no_grad():
259
+ wav = F.layer_norm(wav, wav.shape)
260
+ return wav
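
Note: the random cropping above (sample_from_feature, used by SpeechToClassDataset.get_audio when max_length is set) can be sanity-checked in isolation. The sketch below is illustrative and not part of the commit; it assumes the repository root is on PYTHONPATH and that the module import path matches the file layout, and it only shows that long inputs are cropped to a random contiguous segment while short inputs pass through unchanged.

# minimal sketch, assuming artst/ is importable; shapes are illustrative only
import numpy as np
from artst.data.speech_to_class_dataset import sample_from_feature

wav = np.random.randn(5 * 16000)             # 5 s of fake 16 kHz audio (1-D samples)
crop = sample_from_feature(wav, 3 * 16000)   # random contiguous 3 s crop
assert crop.shape[0] == 3 * 16000

fbank = np.random.randn(500, 80)             # 500 log-mel frames
seg = sample_from_feature(fbank, 300)        # default-sized 300-frame crop
assert seg.shape == (300, 80)
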
artst/data/speech_to_speech_dataset.py ADDED
@@ -0,0 +1,280 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ import os
10
+ from typing import Any, List, Optional
11
+
12
+ import librosa
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq.data.fairseq_dataset import FairseqDataset
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ def _collate_frames(
21
+ frames: List[torch.Tensor], is_audio_input: bool = False
22
+ ):
23
+ """
24
+ Convert a list of 2D frames into a padded 3D tensor
25
+ Args:
26
+ frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
27
+ length of i-th frame and f_dim is static dimension of features
28
+ Returns:
29
+ 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
30
+ """
31
+ max_len = max(frame.size(0) for frame in frames)
32
+ if is_audio_input:
33
+ out = frames[0].new_zeros((len(frames), max_len))
34
+ else:
35
+ out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
36
+ for i, v in enumerate(frames):
37
+ out[i, : v.size(0)] = v
38
+ return out
39
+
40
+ def load_audio(manifest_path, max_keep, min_keep):
41
+ """manifest tsv: src_wav, src_nframe, tgt_wav, tgt_nframe, tgt_spkemb"""
42
+ n_long, n_short = 0, 0
43
+ src_names, tgt_names, inds, sizes, tgt_sizes, spk_embeds = [], [], [], [], [], []
44
+ with open(manifest_path) as f:
45
+ root = f.readline().strip()
46
+ for ind, line in enumerate(f):
47
+ items = line.strip().split("\t")
48
+ assert len(items) >= 2, line
49
+ sz = int(items[1])
50
+ if min_keep is not None and sz < min_keep:
51
+ n_short += 1
52
+ elif max_keep is not None and sz > max_keep:
53
+ n_long += 1
54
+ else:
55
+ src_names.append(items[0])
56
+ tgt_names.append(items[2])
57
+ tgt_sizes.append(items[3])
58
+ spk_embeds.append(items[4])
59
+ inds.append(ind)
60
+ sizes.append(sz)
61
+ tot = ind + 1
62
+ logger.info(
63
+ (
64
+ f"max_keep={max_keep}, min_keep={min_keep}, "
65
+ f"loaded {len(src_names)}, skipped {n_short} short and {n_long} long, "
66
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
67
+ )
68
+ )
69
+ return root, src_names, inds, tot, sizes, tgt_names, tgt_sizes, spk_embeds
70
+
71
+
72
+ def logmelfilterbank(
73
+ audio,
74
+ sampling_rate,
75
+ fft_size=1024,
76
+ hop_size=256,
77
+ win_length=None,
78
+ window="hann",
79
+ num_mels=80,
80
+ fmin=80,
81
+ fmax=7600,
82
+ eps=1e-10,
83
+ ):
84
+ """Compute log-Mel filterbank feature.
85
+ (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
86
+
87
+ Args:
88
+ audio (ndarray): Audio signal (T,).
89
+ sampling_rate (int): Sampling rate.
90
+ fft_size (int): FFT size.
91
+ hop_size (int): Hop size.
92
+ win_length (int): Window length. If set to None, it will be the same as fft_size.
93
+ window (str): Window function type.
94
+ num_mels (int): Number of mel basis.
95
+ fmin (int): Minimum frequency in mel basis calculation.
96
+ fmax (int): Maximum frequency in mel basis calculation.
97
+ eps (float): Epsilon value to avoid inf in log calculation.
98
+
99
+ Returns:
100
+ ndarray: Log Mel filterbank feature (#frames, num_mels).
101
+
102
+ """
103
+ # get amplitude spectrogram
104
+ x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
105
+ win_length=win_length, window=window, pad_mode="reflect")
106
+ spc = np.abs(x_stft).T # (#frames, #bins)
107
+
108
+ # get mel basis
109
+ fmin = 0 if fmin is None else fmin
110
+ fmax = sampling_rate / 2 if fmax is None else fmax
111
+ mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
112
+
113
+ return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
114
+
115
+
116
+ class SpeechToSpeechDataset(FairseqDataset):
117
+ def __init__(
118
+ self,
119
+ manifest_path: str,
120
+ sample_rate: float,
121
+ max_keep_sample_size: Optional[int] = None,
122
+ min_keep_sample_size: Optional[int] = None,
123
+ shuffle: bool = True,
124
+ normalize: bool = False,
125
+ reduction_factor: int = 1,
126
+ ):
127
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.tgt_audios, self.tgt_sizes, self.tgt_spkembs = load_audio(
128
+ manifest_path, max_keep_sample_size, min_keep_sample_size
129
+ )
130
+ self.sample_rate = sample_rate
131
+ self.shuffle = shuffle
132
+
133
+ self.normalize = normalize
134
+ self.reduction_factor = reduction_factor
135
+ logger.info(
136
+ f"reduction_factor={reduction_factor}, normalize={normalize}"
137
+ )
138
+
139
+ def get_audio(self, index):
140
+ import soundfile as sf
141
+
142
+ wav_fbank = []
143
+ for name in [self.audio_names[index], self.tgt_audios[index]]:
144
+ wav_path = os.path.join(self.audio_root, name)
145
+ wav, cur_sample_rate = sf.read(wav_path)
146
+ wav = torch.from_numpy(wav).float()
147
+ fbank = logmelfilterbank(
148
+ wav.view(-1).cpu().numpy(), 16000
149
+ )
150
+ fbank = torch.from_numpy(fbank).float()
151
+ wav = self.postprocess(wav, cur_sample_rate)
152
+ wav_fbank.append(wav)
153
+ wav_fbank.append(fbank)
154
+ src_wav, src_fbank, tgt_wav, tgt_fbank = wav_fbank
155
+ return src_wav, src_fbank, tgt_wav, tgt_fbank
156
+
157
+ def __getitem__(self, index):
158
+ src_wav, src_fbank, tgt_wav, tgt_fbank = self.get_audio(index)
159
+ spkembs = np.load(os.path.join(self.audio_root, self.tgt_spkembs[index]))
160
+ spkembs = torch.from_numpy(spkembs).float()
161
+ name = self.audio_names[index].replace("/", ".").replace(".wav", "") + "-" + self.tgt_audios[index].replace("/", ".").replace(".wav", "") + ".wav"
162
+ return {"id": index, "source": src_wav, "target": tgt_fbank, "spkembs": spkembs, "audio_name": name, "tgt_name": self.tgt_audios[index]}
163
+
164
+ def __len__(self):
165
+ return len(self.wav_sizes)
166
+
167
+ def collater(self, samples):
168
+ samples = [s for s in samples if s["source"] is not None]
169
+ if len(samples) == 0:
170
+ return {}
171
+
172
+ audios = [s["source"] for s in samples]
173
+ audio_sizes = [len(s) for s in audios]
174
+
175
+ audio_size = max(audio_sizes)
176
+ collated_audios, padding_mask = self.collater_audio(
177
+ audios, audio_size
178
+ )
179
+
180
+ fbanks = [s["target"] for s in samples]
181
+ fbank_sizes = [len(s) for s in fbanks]
182
+
183
+ collated_fbanks = _collate_frames(fbanks)
184
+ collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long)
185
+
186
+ # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
187
+ if self.reduction_factor > 1:
188
+ collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
189
+ collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
190
+ else:
191
+ collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
192
+
193
+ prev_output_tokens = torch.cat(
194
+ [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
195
+ )
196
+
197
+ # make labels for stop prediction
198
+ labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
199
+ for i, l in enumerate(fbank_sizes):
200
+ labels[i, l - 1 :] = 1.0
201
+
202
+ spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
203
+
204
+ net_input = {
205
+ "source": collated_audios,
206
+ "padding_mask": padding_mask,
207
+ "prev_output_tokens": prev_output_tokens,
208
+ "tgt_lengths": collated_fbanks_size_in,
209
+ "spkembs": spkembs,
210
+ "task_name": "s2s",
211
+ }
212
+ batch = {
213
+ "id": torch.LongTensor([s["id"] for s in samples]),
214
+ "name": [s["audio_name"] for s in samples],
215
+ "tgt_name": [s["tgt_name"] for s in samples],
216
+ "net_input": net_input,
217
+ "labels": labels,
218
+ "dec_target": collated_fbanks,
219
+ "dec_target_lengths": collated_fbanks_size,
220
+ "src_lengths": torch.LongTensor(audio_sizes),
221
+ "task_name": "s2s",
222
+ "ntokens": sum(audio_sizes),
223
+ "target": collated_fbanks,
224
+ }
225
+
226
+ return batch
227
+
228
+ def collater_audio(self, audios, audio_size):
229
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
230
+ padding_mask = (
231
+ torch.BoolTensor(collated_audios.shape).fill_(False)
232
+ )
233
+ for i, audio in enumerate(audios):
234
+ diff = len(audio) - audio_size
235
+ if diff == 0:
236
+ collated_audios[i] = audio
237
+ elif diff < 0:
238
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
239
+ padding_mask[i, diff:] = True
240
+ else:
241
+ raise Exception("Diff should not be larger than 0")
242
+ return collated_audios, padding_mask
243
+
244
+
245
+ def num_tokens(self, index):
246
+ return self.wav_sizes[index]
247
+
248
+ def size(self, index):
249
+ return self.wav_sizes[index], self.tgt_sizes[index]
250
+
251
+ @property
252
+ def sizes(self):
253
+ return np.array(self.wav_sizes)
254
+
255
+ @property
256
+ def can_reuse_epoch_itr_across_epochs(self):
257
+ """No cache dataset if dataset is large-scale. Cache dataset for small dataset."""
258
+ return True
259
+
260
+ def ordered_indices(self):
261
+ if self.shuffle:
262
+ order = [np.random.permutation(len(self))]
263
+ else:
264
+ order = [np.arange(len(self))]
265
+
266
+ order.append(self.wav_sizes)
267
+ return np.lexsort(order)[::-1]
268
+
269
+ def postprocess(self, wav, cur_sample_rate):
270
+ if wav.dim() == 2:
271
+ wav = wav.mean(-1)
272
+ assert wav.dim() == 1, wav.dim()
273
+
274
+ if cur_sample_rate != self.sample_rate:
275
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
276
+
277
+ if self.normalize:
278
+ with torch.no_grad():
279
+ wav = F.layer_norm(wav, wav.shape)
280
+ return wav
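
Note: logmelfilterbank above is the feature extractor shared by the speech-to-speech and text-to-speech datasets; it converts a 16 kHz waveform into 80-dimensional log-Mel frames using a 1024-point FFT and a 256-sample hop. The sketch below is illustrative and not part of the commit; "example.wav" is a placeholder path and the import assumes the repository root is on PYTHONPATH.

# minimal sketch; "example.wav" stands in for any 16 kHz mono file
import soundfile as sf
from artst.data.speech_to_speech_dataset import logmelfilterbank

wav, sr = sf.read("example.wav")        # wav: (T,) float array, sr expected to be 16000
fbank = logmelfilterbank(wav, sr)       # (#frames, 80) log-Mel features

# with hop_size=256 the frame rate is 16000 / 256 = 62.5 frames/s, so #frames is
# roughly len(wav) // 256 (plus one, since librosa centers the analysis windows)
print(wav.shape, fbank.shape)
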
artst/data/speech_to_text_dataset.py ADDED
@@ -0,0 +1,298 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import itertools
9
+ import logging
10
+ import os
11
+ import mmap
12
+ from typing import Any, List, Optional
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ torch.set_printoptions(profile="full")
18
+ import torch.nn.functional as F
19
+ from fairseq.data import data_utils, Dictionary
20
+ from fairseq.data.fairseq_dataset import FairseqDataset
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def load_audio(manifest_path, max_keep, min_keep):
26
+ n_long, n_short = 0, 0
27
+ names, inds, sizes = [], [], []
28
+ with open(manifest_path) as f:
29
+ root = f.readline().strip()
30
+ for ind, line in enumerate(f):
31
+ items = line.strip().split("\t")
32
+ assert len(items) >= 2, line
33
+ sz = int(items[1])
34
+ if min_keep is not None and sz < min_keep:
35
+ n_short += 1
36
+ elif max_keep is not None and sz > max_keep:
37
+ n_long += 1
38
+ else:
39
+ names.append(items[0])
40
+ inds.append(ind)
41
+ sizes.append(sz)
42
+ tot = ind + 1
43
+ logger.info(
44
+ (
45
+ f"max_keep={max_keep}, min_keep={min_keep}, "
46
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
47
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
48
+ )
49
+ )
50
+ return root, names, inds, tot, sizes
51
+
52
+
53
+ def load_label(label_path, inds, tot):
54
+ with open(label_path) as f:
55
+ labels = [line.rstrip() for line in f]
56
+ assert (
57
+ len(labels) == tot
58
+ ), f"number of labels does not match ({len(labels)} != {tot})"
59
+ labels = [labels[i] for i in inds]
60
+ return labels
61
+
62
+
63
+ def load_label_offset(label_path, inds, tot):
64
+ with open(label_path) as f:
65
+ # Hawau:
66
+ # offsets must be byte offsets, since labels are later sliced from an mmap and decoded as UTF-8
67
+ code_lengths = [len(line.encode("utf-8")) for line in f] # byte length of each line
68
+ # code_lengths = [len(line) for line in f] # character length (misaligns offsets for multi-byte UTF-8 text)
69
+ assert (
70
+ len(code_lengths) == tot
71
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
72
+ offsets = list(itertools.accumulate([0] + code_lengths))
73
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
74
+ return offsets
75
+
76
+
77
+ class SpeechToTextDataset(FairseqDataset):
78
+ def __init__(
79
+ self,
80
+ manifest_path: str,
81
+ sample_rate: float,
82
+ label_paths: List[str],
83
+ label_processors: Optional[List[Any]] = None,
84
+ max_keep_sample_size: Optional[int] = None,
85
+ min_keep_sample_size: Optional[int] = None,
86
+ shuffle: bool = True,
87
+ normalize: bool = False,
88
+ store_labels: bool = True,
89
+ tgt_dict: Optional[Dictionary] = None,
90
+ tokenizer = None,
91
+ ):
92
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes = load_audio(
93
+ manifest_path, max_keep_sample_size, min_keep_sample_size
94
+ )
95
+
96
+ self.sample_rate = sample_rate
97
+ self.shuffle = shuffle
98
+ self.tgt_dict = tgt_dict
99
+ self.tokenizer = tokenizer
100
+
101
+ self.num_labels = len(label_paths)
102
+ self.label_processors = label_processors
103
+ self.store_labels = store_labels
104
+
105
+ if store_labels:
106
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
107
+ logger.info(f"label_list: {self.label_list}")
108
+ else:
109
+ self.label_paths = label_paths
110
+ self.label_offsets_list = [
111
+ load_label_offset(p, inds, tot) for p in label_paths
112
+ ]
113
+ # logger.info(f"label_offsets_list: {self.label_offsets_list}")
114
+ assert label_processors is None or len(label_processors) == self.num_labels
115
+
116
+ self.normalize = normalize
117
+ logger.info(
118
+ f"normalize={normalize}"
119
+ )
120
+
121
+ def get_audio(self, index):
122
+ import soundfile as sf
123
+ # Hawau:
124
+ # logger.info(f"loaded_audio: {self.audio_names[index]}")
125
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
126
+ wav, cur_sample_rate = sf.read(wav_path)
127
+ wav = torch.from_numpy(wav).float()
128
+ wav = self.postprocess(wav, cur_sample_rate)
129
+ return wav
130
+
131
+ def get_label(self, index, label_idx):
132
+ if self.store_labels:
133
+ label = self.label_list[label_idx][index]
134
+ else:
135
+ # list slicing method
136
+ # with open(self.label_paths[label_idx]) as f:
137
+ # offset_s, offset_e = self.label_offsets_list[label_idx][index]
138
+ # # Hawau:
139
+ # # f.seek(offset_s)
140
+ # # label = f.read(offset_e - offset_s)
141
+ # label = f.read()[offset_s : offset_e]
142
+ # Hawau:
143
+ # mmap method
144
+ with open(self.label_paths[label_idx], encoding='utf-8') as f:
145
+ with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
146
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
147
+ label = mm[offset_s:offset_e].decode("utf-8")
148
+
149
+
150
+ # Hawau:
151
+ # logger.info(f"loaded_label: {label}")
152
+ if self.tokenizer is not None:
153
+ label = self.tokenizer.encode(label)
154
+
155
+ if self.label_processors is not None:
156
+ label = self.label_processors[label_idx](label)
157
+ # logger.info(f"processed_label: {label}")
158
+ return label
159
+
160
+ def get_labels(self, index):
161
+ return [self.get_label(index, i) for i in range(self.num_labels)]
162
+
163
+ def __getitem__(self, index):
164
+ wav = self.get_audio(index)
165
+ labels = self.get_labels(index)
166
+ return {"id": index, "source": wav, "label_list": labels}
167
+
168
+ def __len__(self):
169
+ return len(self.wav_sizes)
170
+
171
+ def collater(self, samples):
172
+ samples = [s for s in samples if s["source"] is not None]
173
+ if len(samples) == 0:
174
+ return {}
175
+
176
+ audios = [s["source"] for s in samples]
177
+ audio_sizes = [len(s) for s in audios]
178
+
179
+ audio_size = max(audio_sizes)
180
+ collated_audios, padding_mask = self.collater_audio(
181
+ audios, audio_size
182
+ )
183
+
184
+ targets_by_label = [
185
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
186
+ ]
187
+ targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)
188
+
189
+ # Hawau:
190
+ # logger.info(f'targets_list: {targets_list}')
191
+
192
+
193
+ decoder_label = [
194
+ torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
195
+ for i in range(targets_list[0].size(0))
196
+ ]
197
+
198
+ decoder_target = data_utils.collate_tokens(
199
+ decoder_label,
200
+ self.tgt_dict.pad(),
201
+ self.tgt_dict.eos(),
202
+ left_pad=False,
203
+ move_eos_to_beginning=False,
204
+ )
205
+ decoder_target_lengths = torch.tensor(
206
+ [x.size(0) for x in decoder_label], dtype=torch.long
207
+ )
208
+ prev_output_tokens = data_utils.collate_tokens(
209
+ decoder_label,
210
+ self.tgt_dict.pad(),
211
+ self.tgt_dict.eos(),
212
+ left_pad=False,
213
+ move_eos_to_beginning=True,
214
+ )
215
+
216
+ net_input = {
217
+ "source": collated_audios,
218
+ "padding_mask": padding_mask,
219
+ "prev_output_tokens": prev_output_tokens,
220
+ "task_name": "s2t",
221
+ }
222
+ batch = {
223
+ "id": torch.LongTensor([s["id"] for s in samples]),
224
+ "net_input": net_input,
225
+ "target": decoder_target,
226
+ "target_lengths": decoder_target_lengths,
227
+ "task_name": "s2t",
228
+ "ntokens": ntokens_list[0]
229
+ }
230
+
231
+ return batch
232
+
233
+ def collater_audio(self, audios, audio_size):
234
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
235
+ padding_mask = (
236
+ torch.BoolTensor(collated_audios.shape).fill_(False)
237
+ )
238
+ for i, audio in enumerate(audios):
239
+ diff = len(audio) - audio_size
240
+ if diff == 0:
241
+ collated_audios[i] = audio
242
+ elif diff < 0:
243
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
244
+ padding_mask[i, diff:] = True
245
+ else:
246
+ raise Exception("Diff should not be larger than 0")
247
+ return collated_audios, padding_mask
248
+
249
+ def collater_seq_label(self, targets, pad):
250
+ lengths = torch.LongTensor([len(t) for t in targets])
251
+ ntokens = lengths.sum().item()
252
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
253
+ return targets, lengths, ntokens
254
+
255
+ def collater_label(self, targets_by_label):
256
+ targets_list, lengths_list, ntokens_list = [], [], []
257
+ itr = zip(targets_by_label, [self.tgt_dict.pad()])
258
+
259
+ for targets, pad in itr:
260
+ # Hawau:
261
+ # logger.info(f'targets: {targets}')
262
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
263
+ targets_list.append(targets)
264
+ lengths_list.append(lengths)
265
+ ntokens_list.append(ntokens)
266
+ return targets_list, lengths_list, ntokens_list
267
+
268
+ def num_tokens(self, index):
269
+ return self.size(index)
270
+
271
+ def size(self, index):
272
+ return self.wav_sizes[index]
273
+
274
+ @property
275
+ def sizes(self):
276
+ return np.array(self.wav_sizes)
277
+
278
+ def ordered_indices(self):
279
+ if self.shuffle:
280
+ order = [np.random.permutation(len(self))]
281
+ else:
282
+ order = [np.arange(len(self))]
283
+
284
+ order.append(self.wav_sizes)
285
+ return np.lexsort(order)[::-1]
286
+
287
+ def postprocess(self, wav, cur_sample_rate):
288
+ if wav.dim() == 2:
289
+ wav = wav.mean(-1)
290
+ assert wav.dim() == 1, wav.dim()
291
+
292
+ if cur_sample_rate != self.sample_rate:
293
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
294
+
295
+ if self.normalize:
296
+ with torch.no_grad():
297
+ wav = F.layer_norm(wav, wav.shape)
298
+ return wav
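
Note: the label lookup above deliberately works in bytes. load_label_offset records byte offsets per line, and get_label slices the label file through mmap and decodes the slice as UTF-8, which matters for Arabic transcripts where byte and character counts differ. The following self-contained sketch (illustrative, not part of the commit, using a temporary file instead of a real label file) shows why the offsets must be byte-based:

# minimal sketch of the byte-offset + mmap lookup used by SpeechToTextDataset
import itertools, mmap, os, tempfile

lines = ["مرحبا بالعالم\n", "hello world\n", "نص عربي آخر\n"]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    path = f.name
    f.write("".join(lines))

byte_lens = [len(l.encode("utf-8")) for l in lines]   # character counts would misalign the Arabic lines
offsets = list(itertools.accumulate([0] + byte_lens))
spans = list(zip(offsets[:-1], offsets[1:]))

with open(path, encoding="utf-8") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
    for (s, e), expected in zip(spans, lines):
        assert mm[s:e].decode("utf-8") == expected
os.remove(path)
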
artst/data/text_dataset.py ADDED
@@ -0,0 +1,474 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from fairseq.data import FairseqDataset, data_utils
14
+
15
+
16
+ def collate(
17
+ samples,
18
+ pad_idx,
19
+ eos_idx,
20
+ vocab,
21
+ left_pad_source=False,
22
+ left_pad_target=False,
23
+ input_feeding=True,
24
+ pad_to_length=None,
25
+ ):
26
+ assert input_feeding
27
+ if len(samples) == 0:
28
+ return {}
29
+
30
+ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
31
+ return data_utils.collate_tokens(
32
+ [s[key] for s in samples],
33
+ pad_idx,
34
+ eos_idx=None, # use eos_idx of each sample instead of vocab.eos()
35
+ left_pad=left_pad,
36
+ move_eos_to_beginning=move_eos_to_beginning,
37
+ pad_to_length=pad_to_length,
38
+ )
39
+
40
+ id = torch.LongTensor([s["id"] for s in samples])
41
+ src_tokens = merge(
42
+ "source",
43
+ left_pad=left_pad_source,
44
+ pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
45
+ )
46
+ # sort by descending source length
47
+ src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
48
+ src_lengths, sort_order = src_lengths.sort(descending=True)
49
+ id = id.index_select(0, sort_order)
50
+ src_tokens = src_tokens.index_select(0, sort_order)
51
+
52
+ prev_output_tokens = None
53
+ target = None
54
+ if samples[0].get("target", None) is not None:
55
+ target = merge(
56
+ "target",
57
+ left_pad=left_pad_target,
58
+ pad_to_length=pad_to_length["target"]
59
+ if pad_to_length is not None
60
+ else None,
61
+ )
62
+ target = target.index_select(0, sort_order)
63
+ ntokens = sum(len(s["target"]) for s in samples)
64
+
65
+ if input_feeding:
66
+ # we create a shifted version of targets for feeding the
67
+ # previous output token(s) into the next decoder step
68
+ prev_output_tokens = merge(
69
+ "target",
70
+ left_pad=left_pad_target,
71
+ move_eos_to_beginning=True,
72
+ pad_to_length=pad_to_length["target"]
73
+ if pad_to_length is not None
74
+ else None,
75
+ )
76
+ prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
77
+ else:
78
+ ntokens = sum(len(s["source"]) for s in samples)
79
+
80
+ batch = {
81
+ "id": id,
82
+ "ntokens": ntokens,
83
+ "net_input": {
84
+ "src_tokens": src_tokens,
85
+ "src_lengths": src_lengths,
86
+ },
87
+ "target": target,
88
+ "nsentences": samples[0]["source"].size(0),
89
+ "sort_order": sort_order,
90
+ "task_name": 'text_pretrain',
91
+ }
92
+ if prev_output_tokens is not None:
93
+ batch["net_input"]["prev_output_tokens"] = prev_output_tokens
94
+
95
+ return batch
96
+
97
+
98
+ class TextPretrainDataset(FairseqDataset):
99
+ """
100
+ A wrapper around TokenBlockDataset for BART dataset.
101
+
102
+ Args:
103
+ dataset (TokenBlockDataset): dataset to wrap
104
+ sizes (List[int]): sentence lengths
105
+ vocab (~fairseq.data.Dictionary): vocabulary
106
+ mask_idx (int): dictionary index used for masked token
107
+ mask_whole_words: only mask whole words. This should be a byte mask
108
+ over vocab indices, indicating whether it is the beginning of a
109
+ word. We will extend any mask to encompass the whole word.
110
+ shuffle (bool, optional): shuffle the elements before batching.
111
+ Default: ``True``
112
+ seed: Seed for random number generator for reproducibility.
113
+ args: argparse arguments.
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ dataset,
119
+ sizes,
120
+ vocab,
121
+ mask_idx,
122
+ mask_whole_words,
123
+ shuffle,
124
+ seed,
125
+ args,
126
+ eos=None,
127
+ item_transform_func=None,
128
+ iid_noise_target=False,
129
+ uni_mask_idxs=None,
130
+ ):
131
+ self.dataset = dataset
132
+
133
+ self.sizes = sizes
134
+
135
+ self.vocab = vocab
136
+ self.shuffle = shuffle
137
+ self.seed = seed
138
+ if iid_noise_target:
139
+ assert isinstance(uni_mask_idxs, torch.Tensor), "if iid_noise_target is used, uni_mask_idxs must be a tensor containing the mask indices"
140
+ self.iid_noise_target = iid_noise_target
141
+ self.uni_mask_idxs = uni_mask_idxs
142
+ self.mask_idx = mask_idx
143
+ self.mask_whole_word = mask_whole_words
144
+ self.mask_ratio = args.mask
145
+ self.random_ratio = args.mask_random
146
+ self.insert_ratio = args.insert
147
+ self.rotate_ratio = args.rotate
148
+ self.permute_sentence_ratio = args.permute_sentences
149
+ self.eos = eos if eos is not None else vocab.eos()
150
+ self.item_transform_func = item_transform_func
151
+
152
+ if args.bpe != "gpt2":
153
+ self.full_stop_index = self.vocab.eos()
154
+ else:
155
+ assert args.bpe == "gpt2"
156
+ self.full_stop_index = self.vocab.index("13")
157
+
158
+ self.replace_length = args.replace_length
159
+ if self.replace_length not in [-1, 0, 1]:
160
+ raise ValueError(f"invalid arg: replace_length={self.replace_length}")
161
+ if args.mask_length not in ["subword", "word", "span-poisson"]:
162
+ raise ValueError(f"invalid arg: mask-length={args.mask_length}")
163
+ if args.mask_length == "subword" and args.replace_length not in [0, 1]:
164
+ raise ValueError(f"if using subwords, use replace-length=1 or 0")
165
+
166
+ self.mask_span_distribution = None
167
+ if args.mask_length == "span-poisson":
168
+ _lambda = args.poisson_lambda
169
+
170
+ lambda_to_the_k = 1
171
+ e_to_the_minus_lambda = math.exp(-_lambda)
172
+ k_factorial = 1
173
+ ps = []
174
+ for k in range(0, 128):
175
+ ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
176
+ lambda_to_the_k *= _lambda
177
+ k_factorial *= k + 1
178
+ if ps[-1] < 0.0000001:
179
+ break
180
+ ps = torch.FloatTensor(ps)
181
+ self.mask_span_distribution = torch.distributions.Categorical(ps)
182
+
183
+ self.epoch = 0
184
+
185
+ @property
186
+ def can_reuse_epoch_itr_across_epochs(self):
187
+ return True # only the noise changes, not item sizes
188
+
189
+ def set_epoch(self, epoch, **unused):
190
+ self.epoch = epoch
191
+
192
+ def __getitem__(self, index):
193
+ with data_utils.numpy_seed(self.seed, self.epoch, index):
194
+ tokens = self.dataset[index]
195
+ assert tokens[-1] == self.eos
196
+ source, target = tokens, tokens.clone()
197
+
198
+ if self.permute_sentence_ratio > 0.0:
199
+ source = self.permute_sentences(source, self.permute_sentence_ratio)
200
+
201
+ if self.mask_ratio > 0:
202
+ source, new_target = self.add_whole_word_mask(source, self.mask_ratio)
203
+ if new_target is not None:
204
+ target = new_target
205
+
206
+ if self.insert_ratio > 0:
207
+ source = self.add_insertion_noise(source, self.insert_ratio)
208
+
209
+ if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
210
+ source = self.add_rolling_noise(source)
211
+ # there can be additional changes to make:
212
+ if self.item_transform_func is not None:
213
+ source, target = self.item_transform_func(source, target)
214
+
215
+ assert (source >= 0).all()
216
+ assert (source[1:-1] >= 1).all()
217
+ assert (source <= len(self.vocab)).all()
218
+ assert source[0] == self.vocab.bos()
219
+ assert source[-1] == self.eos
220
+ return {
221
+ "id": index,
222
+ "source": source,
223
+ "target": target,
224
+ }
225
+
226
+ def __len__(self):
227
+ return len(self.dataset)
228
+
229
+ def permute_sentences(self, source, p=1.0):
230
+ full_stops = source == self.full_stop_index
231
+ # Pretend it ends with a full stop so last span is a sentence
232
+ full_stops[-2] = 1
233
+
234
+ # Tokens that are full stops, where the previous token is not
235
+ sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
236
+ result = source.clone()
237
+
238
+ num_sentences = sentence_ends.size(0)
239
+ num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
240
+ substitutions = torch.randperm(num_sentences)[:num_to_permute]
241
+ ordering = torch.arange(0, num_sentences)
242
+ ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
243
+
244
+ # Ignore <bos> at start
245
+ index = 1
246
+ for i in ordering:
247
+ sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
248
+ result[index : index + sentence.size(0)] = sentence
249
+ index += sentence.size(0)
250
+ return result
251
+
252
+ def word_starts(self, source):
253
+ if self.mask_whole_word is not None:
254
+ is_word_start = self.mask_whole_word.gather(0, source)
255
+ else:
256
+ is_word_start = torch.ones(source.size())
257
+ is_word_start[0] = 0
258
+ is_word_start[-1] = 0
259
+ return is_word_start
260
+
261
+ def add_whole_word_mask(self, source, p):
262
+ source_ori = source.clone()
263
+ is_word_start = self.word_starts(source)
264
+ num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
265
+ num_inserts = 0
266
+ if num_to_mask == 0:
267
+ return source
268
+
269
+ if self.mask_span_distribution is not None:
270
+ lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
271
+
272
+ # Make sure we have enough to mask
273
+ cum_length = torch.cumsum(lengths, 0)
274
+ while cum_length[-1] < num_to_mask:
275
+ lengths = torch.cat(
276
+ [
277
+ lengths,
278
+ self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
279
+ ],
280
+ dim=0,
281
+ )
282
+ cum_length = torch.cumsum(lengths, 0)
283
+
284
+ # Trim to masking budget
285
+ i = 0
286
+ while cum_length[i] < num_to_mask:
287
+ i += 1
288
+ lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
289
+ num_to_mask = i + 1
290
+ lengths = lengths[:num_to_mask]
291
+
292
+ # Handle 0-length mask (inserts) separately
293
+ lengths = lengths[lengths > 0]
294
+ num_inserts = num_to_mask - lengths.size(0)
295
+ num_to_mask -= num_inserts
296
+ if num_to_mask == 0:
297
+ return self.add_insertion_noise(source, num_inserts / source.size(0))
298
+
299
+ assert (lengths > 0).all()
300
+ else:
301
+ lengths = torch.ones((num_to_mask,)).long()
302
+ assert is_word_start[-1] == 0
303
+ word_starts = is_word_start.nonzero(as_tuple=False)
304
+ indices = word_starts[
305
+ torch.randperm(word_starts.size(0))[:num_to_mask]
306
+ ].squeeze(1)
307
+ mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
308
+
309
+ source_length = source.size(0)
310
+ assert source_length - 1 not in indices
311
+ to_keep = torch.ones(source_length, dtype=torch.bool)
312
+ is_word_start[
313
+ -1
314
+ ] = 255 # acts as a long length, so spans don't go over the end of doc
315
+ if self.replace_length == 0:
316
+ to_keep[indices] = 0
317
+ else:
318
+ # keep index, but replace it with [MASK]
319
+ source[indices] = self.mask_idx
320
+ source[indices[mask_random]] = torch.randint(
321
+ 1, len(self.vocab), size=(mask_random.sum(),)
322
+ )
323
+
324
+ if self.mask_span_distribution is not None:
325
+ assert len(lengths.size()) == 1
326
+ assert lengths.size() == indices.size()
327
+ lengths -= 1
328
+ while indices.size(0) > 0:
329
+ assert lengths.size() == indices.size()
330
+ lengths -= is_word_start[indices + 1].long()
331
+ uncompleted = lengths >= 0
332
+ indices = indices[uncompleted] + 1
333
+ mask_random = mask_random[uncompleted]
334
+ lengths = lengths[uncompleted]
335
+ if self.replace_length != -1:
336
+ # delete token
337
+ to_keep[indices] = 0
338
+ else:
339
+ # keep index, but replace it with [MASK]
340
+ source[indices] = self.mask_idx
341
+ source[indices[mask_random]] = torch.randint(
342
+ 1, len(self.vocab), size=(mask_random.sum(),)
343
+ )
344
+ else:
345
+ # A bit faster when all lengths are 1
346
+ while indices.size(0) > 0:
347
+ uncompleted = is_word_start[indices + 1] == 0
348
+ indices = indices[uncompleted] + 1
349
+ mask_random = mask_random[uncompleted]
350
+ if self.replace_length != -1:
351
+ # delete token
352
+ to_keep[indices] = 0
353
+ else:
354
+ # keep index, but replace it with [MASK]
355
+ source[indices] = self.mask_idx
356
+ source[indices[mask_random]] = torch.randint(
357
+ 1, len(self.vocab), size=(mask_random.sum(),)
358
+ )
359
+
360
+ assert source_length - 1 not in indices
361
+
362
+ if not self.iid_noise_target:
363
+ source = source[to_keep]
364
+ target = None
365
+ else:
366
+ ## Prepare source
367
+ source_mask_idx = (source == self.mask_idx).nonzero().view(-1)
368
+ source[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]
369
+ source = source[to_keep]
370
+
371
+ ## Prepare target
372
+ to_keep[source_mask_idx] = 0
373
+
374
+ # source_mask_idx: from [a, b, c, ...] to [a, b + 1, c + 2, ...]
375
+ source_mask_idx = source_mask_idx + torch.arange(source_mask_idx.size(0))
376
+ # target: source_length + mask_length
377
+ target = source_ori.new_zeros(source_mask_idx.size(0) + source_ori.size(0))
378
+ # target: [0, 0, 0, X, 0, 0, Y, ....]
379
+ target[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]
380
+
381
+ target_to_keep = to_keep.new_zeros(source_mask_idx.size(0) + source_ori.size(0))
382
+
383
+ # Copy original value to target and target_to_keep
384
+ target_to_keep[target == 0] = to_keep
385
+ target_to_keep[-1] = 0
386
+ target[target == 0] = source_ori
387
+
388
+ target = target[~target_to_keep]
389
+
390
+ if num_inserts > 0:
391
+ source = self.add_insertion_noise(source, num_inserts / source.size(0))
392
+
393
+ return source, target
394
+
395
+ def add_permuted_noise(self, tokens, p):
396
+ num_words = len(tokens)
397
+ num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
398
+ substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
399
+ tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
400
+ return tokens
401
+
402
+ def add_rolling_noise(self, tokens):
403
+ offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
404
+ tokens = torch.cat(
405
+ (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
406
+ dim=0,
407
+ )
408
+ return tokens
409
+
410
+ def add_insertion_noise(self, tokens, p):
411
+ if p == 0.0:
412
+ return tokens
413
+
414
+ num_tokens = len(tokens)
415
+ n = int(math.ceil(num_tokens * p))
416
+
417
+ noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
418
+ noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
419
+ noise_mask[noise_indices] = 1
420
+ result = torch.LongTensor(n + len(tokens)).fill_(-1)
421
+
422
+ num_random = int(math.ceil(n * self.random_ratio))
423
+ result[noise_indices[num_random:]] = self.mask_idx
424
+ result[noise_indices[:num_random]] = torch.randint(
425
+ low=1, high=len(self.vocab), size=(num_random,)
426
+ )
427
+
428
+ result[~noise_mask] = tokens
429
+
430
+ assert (result >= 0).all()
431
+ return result
432
+
433
+ def collater(self, samples, pad_to_length=None):
434
+ """Merge a list of samples to form a mini-batch.
435
+ Args:
436
+ samples (List[dict]): samples to collate
437
+ Returns:
438
+ dict: a mini-batch of data
439
+ """
440
+ return collate(
441
+ samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
442
+ )
443
+
444
+ def num_tokens(self, index):
445
+ """Return the number of tokens in a sample. This value is used to
446
+ enforce ``--max-tokens`` during batching."""
447
+ return self.sizes[index]
448
+
449
+ def size(self, index):
450
+ """Return an example's size as a float or tuple. This value is used when
451
+ filtering a dataset with ``--max-positions``."""
452
+ return self.sizes[index]
453
+
454
+ def ordered_indices(self):
455
+ """Return an ordered list of indices. Batches will be constructed based
456
+ on this order."""
457
+ if self.shuffle:
458
+ indices = np.random.permutation(len(self))
459
+ else:
460
+ indices = np.arange(len(self))
461
+ return indices[np.argsort(self.sizes[indices], kind="mergesort")]
462
+
463
+ def prefetch(self, indices):
464
+ self.src.prefetch(indices)
465
+ self.tgt.prefetch(indices)
466
+
467
+ @property
468
+ def supports_prefetch(self):
469
+ return (
470
+ hasattr(self.src, "supports_prefetch")
471
+ and self.src.supports_prefetch
472
+ and hasattr(self.tgt, "supports_prefetch")
473
+ and self.tgt.supports_prefetch
474
+ )
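
Note: for the span-poisson masking mode, the TextPretrainDataset constructor above builds a truncated Poisson distribution over span lengths by accumulating e^-λ · λ^k / k! terms until they fall below 1e-7. The sketch below is illustrative and not part of the commit; λ = 3.5 is only a placeholder value for poisson_lambda, and torch.distributions.Categorical renormalises the truncated probabilities automatically.

# minimal sketch of the span-length distribution used for span-poisson masking
import math
import torch

_lambda = 3.5                      # illustrative poisson_lambda, not a project default
lambda_to_the_k, k_factorial = 1.0, 1.0
e_to_the_minus_lambda = math.exp(-_lambda)
ps = []
for k in range(0, 128):
    ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)  # P(length = k), untruncated
    lambda_to_the_k *= _lambda
    k_factorial *= k + 1
    if ps[-1] < 1e-7:
        break

mask_span_distribution = torch.distributions.Categorical(torch.FloatTensor(ps))
lengths = mask_span_distribution.sample(sample_shape=(10,))
print(lengths)  # length-0 samples become pure insertions in add_whole_word_mask
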
artst/data/text_to_speech_dataset.py ADDED
@@ -0,0 +1,344 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import itertools
9
+ import logging
10
+ import os
11
+ from typing import Any, List, Optional
12
+ import mmap
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import librosa
19
+ from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
20
+ from fairseq.data import data_utils, Dictionary
21
+ from fairseq.data.fairseq_dataset import FairseqDataset
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ def _collate_frames(
27
+ frames: List[torch.Tensor], is_audio_input: bool = False
28
+ ):
29
+ """
30
+ Convert a list of 2D frames into a padded 3D tensor
31
+ Args:
32
+ frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
33
+ length of i-th frame and f_dim is static dimension of features
34
+ Returns:
35
+ 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
36
+ """
37
+ max_len = max(frame.size(0) for frame in frames)
38
+ if is_audio_input:
39
+ out = frames[0].new_zeros((len(frames), max_len))
40
+ else:
41
+ out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
42
+ for i, v in enumerate(frames):
43
+ out[i, : v.size(0)] = v
44
+ return out
45
+
46
+ def load_audio(manifest_path, max_keep, min_keep):
47
+ n_long, n_short = 0, 0
48
+ names, inds, sizes, spk_embeds = [], [], [], []
49
+ with open(manifest_path) as f:
50
+ root = f.readline().strip()
51
+ for ind, line in enumerate(f):
52
+ items = line.strip().split("\t")
53
+ assert len(items) == 3, line
54
+ sz = int(items[1])
55
+ if min_keep is not None and sz < min_keep:
56
+ n_short += 1
57
+ elif max_keep is not None and sz > max_keep:
58
+ n_long += 1
59
+ else:
60
+ names.append(items[0])
61
+ spk_embeds.append(items[2])
62
+ inds.append(ind)
63
+ sizes.append(sz)
64
+ tot = ind + 1
65
+ logger.info(
66
+ (
67
+ f"max_keep={max_keep}, min_keep={min_keep}, "
68
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
69
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
70
+ )
71
+ )
72
+ return root, names, inds, tot, sizes, spk_embeds
73
+
74
+
75
+ def load_label(label_path, inds, tot):
76
+ with open(label_path) as f:
77
+ labels = [line.rstrip() for line in f]
78
+ assert (
79
+ len(labels) == tot
80
+ ), f"number of labels does not match ({len(labels)} != {tot})"
81
+ labels = [labels[i] for i in inds]
82
+ return labels
83
+
84
+
85
+ def load_label_offset(label_path, inds, tot):
86
+ with open(label_path, encoding='utf-8') as f:
87
+ code_lengths = [len(line.encode("utf-8")) for line in f] #changed as in speech_to_text_dataset.py
88
+ assert (
89
+ len(code_lengths) == tot
90
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
91
+ offsets = list(itertools.accumulate([0] + code_lengths))
92
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
93
+ return offsets
94
+
95
+
96
+ def logmelfilterbank(
97
+ audio,
98
+ sampling_rate,
99
+ fft_size=1024,
100
+ hop_size=256,
101
+ win_length=None,
102
+ window="hann",
103
+ num_mels=80,
104
+ fmin=80,
105
+ fmax=7600,
106
+ eps=1e-10,
107
+ ):
108
+ """Compute log-Mel filterbank feature.
109
+ (https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
110
+
111
+ Args:
112
+ audio (ndarray): Audio signal (T,).
113
+ sampling_rate (int): Sampling rate.
114
+ fft_size (int): FFT size.
115
+ hop_size (int): Hop size.
116
+ win_length (int): Window length. If set to None, it will be the same as fft_size.
117
+ window (str): Window function type.
118
+ num_mels (int): Number of mel basis.
119
+ fmin (int): Minimum frequency in mel basis calculation.
120
+ fmax (int): Maximum frequency in mel basis calculation.
121
+ eps (float): Epsilon value to avoid inf in log calculation.
122
+
123
+ Returns:
124
+ ndarray: Log Mel filterbank feature (#frames, num_mels).
125
+
126
+ """
127
+ # get amplitude spectrogram
128
+ x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
129
+ win_length=win_length, window=window, pad_mode="reflect")
130
+ spc = np.abs(x_stft).T # (#frames, #bins)
131
+
132
+ # get mel basis
133
+ fmin = 0 if fmin is None else fmin
134
+ fmax = sampling_rate / 2 if fmax is None else fmax
135
+ mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
136
+
137
+ return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
138
+
139
+
140
+
141
+ class TextToSpeechDataset(FairseqDataset):
142
+ def __init__(
143
+ self,
144
+ manifest_path: str,
145
+ sample_rate: float,
146
+ label_paths: List[str],
147
+ label_processors: Optional[List[Any]] = None,
148
+ max_keep_sample_size: Optional[int] = None,
149
+ min_keep_sample_size: Optional[int] = None,
150
+ shuffle: bool = True,
151
+ normalize: bool = False,
152
+ store_labels: bool = True,
153
+ src_dict: Optional[Dictionary] = None,
154
+ tokenizer = None,
155
+ reduction_factor: int = 1,
156
+ inference: bool = False,
157
+ ):
158
+
159
+ self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.spk_embeds = load_audio(
160
+ manifest_path, max_keep_sample_size, min_keep_sample_size
161
+ )
162
+ self.inference = inference
163
+
164
+ self.sample_rate = sample_rate
165
+ self.shuffle = shuffle
166
+ self.src_dict = src_dict
167
+ self.tokenizer = tokenizer
168
+
169
+ self.num_labels = len(label_paths)
170
+ self.label_processors = label_processors
171
+ self.store_labels = store_labels
172
+ if store_labels:
173
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
174
+ else:
175
+ self.label_paths = label_paths
176
+ self.label_offsets_list = [
177
+ load_label_offset(p, inds, tot) for p in label_paths
178
+ ]
179
+ assert label_processors is None or len(label_processors) == self.num_labels
180
+
181
+ self.normalize = normalize
182
+ self.reduction_factor = reduction_factor
183
+ logger.info(
184
+ f"reduction_factor={reduction_factor}, normalize={normalize}"
185
+ )
186
+
187
+ def get_audio(self, index):
188
+ import soundfile as sf
189
+
190
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
191
+ wav, cur_sample_rate = sf.read(wav_path)
192
+ wav = torch.from_numpy(wav).float()
193
+ fbank = logmelfilterbank(
194
+ wav.view(-1).cpu().numpy(), 16000
195
+ )
196
+ fbank = torch.from_numpy(fbank).float()
197
+ wav = self.postprocess(wav, cur_sample_rate)
198
+ return wav, fbank
199
+
200
+ def get_label(self, index, label_idx):
201
+ if self.store_labels:
202
+ label = self.label_list[label_idx][index]
203
+ else:
204
+ # with open(self.label_paths[label_idx]) as f:
205
+ # offset_s, offset_e = self.label_offsets_list[label_idx][index]
206
+ # f.seek(offset_s)
207
+ # label = f.read(offset_e - offset_s)
208
+
209
+ # Hawau:
210
+ # mmap method
211
+ with open(self.label_paths[label_idx], encoding='utf-8') as f:
212
+ with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
213
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
214
+ label = mm[offset_s:offset_e].decode("utf-8")
215
+
216
+
217
+ if self.tokenizer is not None:
218
+ label = self.tokenizer.encode(label)
219
+
220
+ if self.label_processors is not None:
221
+ label = self.label_processors[label_idx](label)
222
+ return label
223
+
224
+ def get_labels(self, index):
225
+ return [self.get_label(index, i) for i in range(self.num_labels)]
226
+
227
+ def __getitem__(self, index):
228
+ wav, fbank = self.get_audio(index)
229
+ labels = self.get_labels(index)
230
+ spkembs = get_features_or_waveform(
231
+ os.path.join(self.audio_root, self.spk_embeds[index])
232
+ )
233
+ spkembs = torch.from_numpy(spkembs).float()
234
+
235
+ return {"id": index, "source": labels, "target": fbank, "spkembs": spkembs, "audio_name": self.audio_names[index]}
236
+
237
+
238
+ def __len__(self):
239
+ return len(self.wav_sizes)
240
+
241
+ def collater(self, samples):
242
+ samples = [s for s in samples if s["source"] is not None]
243
+ if len(samples) == 0:
244
+ return {}
245
+
246
+ fbanks = [s["target"] for s in samples]
247
+ fbank_sizes = [len(s) for s in fbanks]
248
+
249
+ collated_fbanks = _collate_frames(fbanks)
250
+ collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long)
251
+
252
+ # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
253
+ if self.reduction_factor > 1:
254
+ collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
255
+ collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
256
+ else:
257
+ collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
258
+
259
+ prev_output_tokens = torch.cat(
260
+ [collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
261
+ )
262
+
263
+ # make labels for stop prediction
264
+ labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
265
+ for i, l in enumerate(fbank_sizes):
266
+ labels[i, l - 1 :] = 1.0
267
+
268
+ spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
269
+
270
+ sources_by_label = [
271
+ [s["source"][i] for s in samples] for i in range(self.num_labels)
272
+ ]
273
+ sources_list, lengths_list, ntokens_list = self.collater_label(sources_by_label)
274
+
275
+ net_input = {
276
+ "src_tokens": sources_list[0],
277
+ "src_lengths": lengths_list[0],
278
+ "prev_output_tokens": prev_output_tokens,
279
+ "tgt_lengths": collated_fbanks_size_in,
280
+ "spkembs": spkembs,
281
+ "task_name": "t2s",
282
+ }
283
+ batch = {
284
+ "id": torch.LongTensor([s["id"] for s in samples]),
285
+ "name": [s["audio_name"] for s in samples],
286
+ "net_input": net_input,
287
+ "labels": labels,
288
+ "dec_target": collated_fbanks,
289
+ "dec_target_lengths": collated_fbanks_size,
290
+ "src_lengths": lengths_list[0],
291
+ "task_name": "t2s",
292
+ "ntokens": ntokens_list[0],
293
+ "target": collated_fbanks,
294
+ }
295
+
296
+ return batch
297
+
298
+ def collater_seq_label(self, targets, pad):
299
+ lengths = torch.LongTensor([len(t) for t in targets])
300
+ ntokens = lengths.sum().item()
301
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
302
+ return targets, lengths, ntokens
303
+
304
+ def collater_label(self, targets_by_label):
305
+ targets_list, lengths_list, ntokens_list = [], [], []
306
+ itr = zip(targets_by_label, [self.src_dict.pad()])
307
+ for targets, pad in itr:
308
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
309
+ targets_list.append(targets)
310
+ lengths_list.append(lengths)
311
+ ntokens_list.append(ntokens)
312
+ return targets_list, lengths_list, ntokens_list
313
+
314
+ def num_tokens(self, index):
315
+ return self.size(index)
316
+
317
+ def size(self, index):
318
+ return self.wav_sizes[index]
319
+
320
+ @property
321
+ def sizes(self):
322
+ return np.array(self.wav_sizes)
323
+
324
+ def ordered_indices(self):
325
+ if self.shuffle:
326
+ order = [np.random.permutation(len(self))]
327
+ else:
328
+ order = [np.arange(len(self))]
329
+
330
+ order.append(self.wav_sizes)
331
+ return np.lexsort(order)[::-1]
332
+
333
+ def postprocess(self, wav, cur_sample_rate):
334
+ if wav.dim() == 2:
335
+ wav = wav.mean(-1)
336
+ assert wav.dim() == 1, wav.dim()
337
+
338
+ if cur_sample_rate != self.sample_rate:
339
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
340
+
341
+ if self.normalize:
342
+ with torch.no_grad():
343
+ wav = F.layer_norm(wav, wav.shape)
344
+ return wav
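
Note: the collater above prepares decoder inputs in two steps shared with SpeechToSpeechDataset: the padded target spectrogram is thinned by the reduction factor (keeping every r-th frame, starting at index r-1) and then shifted right by one all-zero frame for teacher forcing, while stop labels are set to 1.0 from each utterance's last real frame onward. A small stand-alone sketch of just that tensor manipulation (shapes and lengths are illustrative, not taken from the commit):

# minimal sketch of reduction-factor thinning and the teacher-forcing shift
import torch

r = 2                                    # reduction factor
fbanks = torch.randn(4, 11, 80)          # (batch, frames, n_mels), already padded
fbank_sizes = [11, 9, 8, 5]              # true (unpadded) lengths

fbanks_in = fbanks[:, r - 1::r]                          # (B, Lmax, 80) -> (B, Lmax // r, 80)
sizes_in = torch.tensor([n // r for n in fbank_sizes])

# prepend an all-zero frame and drop the last frame -> decoder prev_output_tokens
prev_output_tokens = torch.cat(
    [fbanks_in.new_zeros(fbanks_in.size(0), 1, fbanks_in.size(2)), fbanks_in[:, :-1]], dim=1
)
assert prev_output_tokens.shape == fbanks_in.shape

# stop labels are built on the un-thinned targets: 1.0 from the last real frame on
labels = fbanks.new_zeros(fbanks.size(0), fbanks.size(1))
for i, l in enumerate(fbank_sizes):
    labels[i, l - 1:] = 1.0
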
artst/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .artst import * # noqa
2
+ from .t5_transformer_lm import * # noqa
artst/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (193 Bytes). View file
 
artst/models/__pycache__/artst.cpython-38.pyc ADDED
Binary file (37.2 kB). View file
 
artst/models/__pycache__/speecht5.cpython-38.pyc ADDED
Binary file (37 kB). View file
 
artst/models/__pycache__/t5_transformer_lm.cpython-38.pyc ADDED
Binary file (733 Bytes). View file
 
artst/models/artst.py ADDED
@@ -0,0 +1,1448 @@
1
+ # --------------------------------------------------------
2
+ # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
3
+ # Github source: https://github.com/mbzuai-nlp/ArTST
4
+ # Based on speecht5, fairseq and espnet code bases
5
+ # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ from ast import literal_eval
10
+ from typing import Dict, List, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from fairseq import utils
15
+ from fairseq.models import (
16
+ FairseqEncoderDecoderModel,
17
+ FairseqIncrementalDecoder,
18
+ register_model,
19
+ register_model_architecture,
20
+ )
21
+ from .modules.text_encoder_prenet import TextEncoderPrenet
22
+ from .modules.text_decoder_prenet import TextDecoderPrenet
23
+ from .modules.text_decoder_postnet import TextDecoderPostnet
24
+ from .modules.speech_encoder_prenet import SpeechEncoderPrenet
25
+ from .modules.speech_encoder_postnet import SpeechEncoderPostnet
26
+ from .modules.speech_decoder_prenet import SpeechDecoderPrenet
27
+ from .modules.speech_decoder_postnet import SpeechDecoderPostnet
28
+ from .modules.speaker_decoder_postnet import SpeakerDecoderPostnet
29
+ from .modules.encoder import TransformerEncoder
30
+ from .modules.decoder import TransformerDecoder
31
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
32
+ from fairseq.models.transformer import Embedding
33
+ from fairseq.modules import (
34
+ GumbelVectorQuantizer,
35
+ )
36
+ from torch import Tensor
37
+
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ DEFAULT_MAX_TEXT_POSITIONS = 450
42
+ DEFAULT_MAX_SPEECH_POSITIONS = 4000
43
+
44
+
45
+ @register_model("artst_transformer")
46
+ class ArTSTTransformerModel(FairseqEncoderDecoderModel):
47
+ """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
48
+ speech-to-text tasks. The Transformer encoder/decoder remains the same.
49
+ A trainable input subsampler is prepended to the Transformer encoder to
50
+ project inputs into the encoder dimension and to downsample the input
51
+ sequence for computational efficiency."""
52
+
53
+ def __init__(
54
+ self,
55
+ args,
56
+ encoder, decoder,
57
+ text_encoder_prenet, speech_encoder_prenet,
58
+ text_decoder_prenet, speech_decoder_prenet,
59
+ text_decoder_postnet, speech_decoder_postnet,
60
+ speaker_decoder_postnet, speech_encoder_postnet,
61
+ ):
62
+ super().__init__(encoder, decoder)
63
+
64
+ self.encoder = encoder
65
+ self.decoder = decoder
66
+
67
+ self.text_encoder_prenet = text_encoder_prenet
68
+ self.speech_encoder_prenet = speech_encoder_prenet
69
+
70
+ self.text_decoder_prenet = text_decoder_prenet
71
+ self.speech_decoder_prenet = speech_decoder_prenet
72
+
73
+ self.text_decoder_postnet = text_decoder_postnet
74
+ self.speech_decoder_postnet = speech_decoder_postnet
75
+ self.speaker_decoder_postnet = speaker_decoder_postnet
76
+
77
+ self.hubert_layer = speech_encoder_postnet
78
+
79
+ self.reduction_factor = args.reduction_factor
80
+ self.spk_embed_dim = args.spk_embed_dim
81
+
82
+ # define projection layer
83
+ self.spk_embed_integration_type = args.spk_embed_integration_type
84
+ if self.spk_embed_dim is not None and self.spk_embed_integration_type != 'pre':
85
+ if self.spk_embed_integration_type == "add":
86
+ self.projection = torch.nn.Linear(self.spk_embed_dim, args.decoder_embed_dim)
87
+ else:
88
+ self.projection = torch.nn.Linear(
89
+ args.decoder_embed_dim + self.spk_embed_dim, args.decoder_embed_dim
90
+ )
91
+
92
+ # Hawau: here we can add language embedding integration
93
+
94
+ self.use_codebook = args.use_codebook
95
+ self.codebook_prob = getattr(args, "codebook_prob", 0.5) # args.codebook_prob
96
+ if self.use_codebook:
97
+ vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
98
+ self.quantizer = GumbelVectorQuantizer(
99
+ dim=args.encoder_embed_dim,
100
+ num_vars=args.latent_vars,
101
+ temp=args.latent_temp,
102
+ groups=args.latent_groups,
103
+ combine_groups=False,
104
+ vq_dim=vq_dim,
105
+ time_first=True,
106
+ weight_proj_depth=args.quantizer_depth,
107
+ weight_proj_factor=args.quantizer_factor,
108
+ )
109
+
110
+ self.num_updates = 0
111
+
112
+ # # Follow BERT's random weight initialization (for BART)
113
+ if args.bert_init:
114
+ self.apply(init_bert_params)
115
+ self.args = args
116
+ self.prune_modules(args.modules_filter)
117
+
118
+ @staticmethod
119
+ def add_args(parser):
120
+ """Add model-specific arguments to the parser."""
121
+ # Transformer
122
+ parser.add_argument(
123
+ "--activation-fn",
124
+ type=str,
125
+ choices=utils.get_available_activation_fns(),
126
+ help="activation function to use",
127
+ )
128
+ parser.add_argument(
129
+ "--dropout", type=float, metavar="D", help="dropout probability"
130
+ )
131
+ parser.add_argument(
132
+ "--attention-dropout",
133
+ type=float,
134
+ metavar="D",
135
+ help="dropout probability for attention weights",
136
+ )
137
+ parser.add_argument(
138
+ "--activation-dropout",
139
+ "--relu-dropout",
140
+ type=float,
141
+ metavar="D",
142
+ help="dropout probability after activation in FFN.",
143
+ )
144
+ parser.add_argument(
145
+ "--encoder-embed-dim",
146
+ type=int,
147
+ metavar="N",
148
+ help="encoder embedding dimension",
149
+ )
150
+ parser.add_argument(
151
+ "--encoder-ffn-embed-dim",
152
+ type=int,
153
+ metavar="N",
154
+ help="encoder embedding dimension for FFN",
155
+ )
156
+ parser.add_argument(
157
+ "--encoder-layers", type=int, metavar="N", help="num encoder layers"
158
+ )
159
+ parser.add_argument(
160
+ "--encoder-attention-heads",
161
+ type=int,
162
+ metavar="N",
163
+ help="num encoder attention heads",
164
+ )
165
+ parser.add_argument(
166
+ "--encoder-normalize-before",
167
+ action="store_true",
168
+ help="apply layernorm before each encoder block",
169
+ )
170
+ parser.add_argument(
171
+ "--decoder-normalize-before",
172
+ action="store_true",
173
+ help="apply layernorm before each decoder block",
174
+ )
175
+ parser.add_argument(
176
+ "--decoder-embed-dim",
177
+ type=int,
178
+ metavar="N",
179
+ help="decoder embedding dimension",
180
+ )
181
+ parser.add_argument(
182
+ "--decoder-ffn-embed-dim",
183
+ type=int,
184
+ metavar="N",
185
+ help="decoder embedding dimension for FFN",
186
+ )
187
+ parser.add_argument(
188
+ "--decoder-layers", type=int, metavar="N", help="num decoder layers"
189
+ )
190
+ parser.add_argument(
191
+ "--decoder-attention-heads",
192
+ type=int,
193
+ metavar="N",
194
+ help="num decoder attention heads",
195
+ )
196
+ parser.add_argument(
197
+ "--reduction-factor",
198
+ type=int,
199
+ help="reduction factor for decoder",
200
+ )
201
+ parser.add_argument(
202
+ "--spk-embed-dim",
203
+ type=int,
204
+ help="speaker embedding dimension",
205
+ )
206
+ parser.add_argument(
207
+ "--layernorm-embedding",
208
+ action="store_true",
209
+ help="add layernorm to embedding",
210
+ )
211
+ parser.add_argument(
212
+ "--load-pretrained-encoder-from",
213
+ type=str,
214
+ metavar="STR",
215
+ help="model to take encoder weights from (for initialization)",
216
+ )
217
+ parser.add_argument(
218
+ '--freeze-encoder-updates',
219
+ type=int,
220
+ help='number of steps to freeze encoder before finetune'
221
+ )
222
+ parser.add_argument(
223
+ '--freeze-decoder-updates',
224
+ type=int,
225
+ help='number of steps to freeze decoder before finetune'
226
+ )
227
+ parser.add_argument(
228
+ '--no-freeze-encoder-layer',
229
+ type=str,
230
+ help='which encoder layer not freeze during finetune'
231
+ )
232
+ parser.add_argument(
233
+ "--share-input-output-embed",
234
+ action="store_true",
235
+ help="share decoder input and output embeddings",
236
+ )
237
+ parser.add_argument(
238
+ "--share-ctc-embed",
239
+ action="store_true",
240
+ help="share ctc embed and decoder embed",
241
+ )
242
+ parser.add_argument(
243
+ "--encoder-sliding-window-attn",
244
+ default=None,
245
+ type=int,
246
+ help="If not None but a even number, set sliding window attention to encoder's attn_mask, e.g., 4, 10, and 20",
247
+ )
248
+
249
+ # Convolutional subsampler
250
+ parser.add_argument(
251
+ "--encoder-speech-prenet",
252
+ default="conv",
253
+ type=str,
254
+ choices=["conv", "linear"],
255
+ help="The type of encoder speech prenet, e.g., conv or linear."
256
+ )
257
+ parser.add_argument(
258
+ "--conv-kernel-sizes",
259
+ default="5,5",
260
+ type=str,
261
+ help="The layer of convolution of encoder speech prenet."
262
+ )
263
+ parser.add_argument(
264
+ "--conv-channels",
265
+ default=1024,
266
+ type=int,
267
+ help="The channels of encoder speech prenet."
268
+ )
269
+ parser.add_argument(
270
+ "--subsample-stride",
271
+ default="2,2",
272
+ type=str,
273
+ help="The subsample stride for conv1dsubsample."
274
+ )
275
+ parser.add_argument(
276
+ "--spk-embed-integration-type",
277
+ type=str,
278
+ choices=["pre", "add"],
279
+ help="speaker embedding integration type"
280
+ )
281
+ parser.add_argument(
282
+ "--dprenet-dropout-rate",
283
+ default=0.5,
284
+ type=float,
285
+ help="The dropout rate of decoder speech prenet."
286
+ )
287
+
288
+ ## SE
289
+ parser.add_argument(
290
+ "--se-predict",
291
+ default=None,
292
+ choices=["masking", "target", "delta"],
293
+ help="If set, source speech inputs decoder to predict the masking/target/delta of corresponding inputs."
294
+ + "masking is [0, 1], target is predicted output, delta is difference between inputs and outputs",
295
+ )
296
+ parser.add_argument(
297
+ "--se-decoder-input",
298
+ type=str,
299
+ default="previous_target",
300
+ choices=["previous_target", "source"],
301
+ )
302
+
303
+ ## SID
304
+ parser.add_argument(
305
+ "--modules-filter",
306
+ default=None,
307
+ type=str,
308
+ help="Remove unused modules for, e.g., SID.",
309
+ )
310
+ parser.add_argument(
311
+ "--sid-pad-prenet",
312
+ action="store_true",
313
+ help="If set, the size of text dictionary is as small as for <pad> token.",
314
+ )
315
+ parser.add_argument(
316
+ "--encoder-attn-branch",
317
+ type=str,
318
+ default="identity,full",
319
+ help="encoder attention branch sliding window, e.g., 'identity,0,2,4,full'",
320
+ )
321
+ parser.add_argument(
322
+ "--encoder-block-branch",
323
+ type=str,
324
+ help="average the output of encoder, e.g., '4,5,6'",
325
+ )
326
+ parser.add_argument(
327
+ "--sid-encoder-cls",
328
+ default=None,
329
+ choices=["encoder"],
330
+ help="If set, add cls vector to the encoder input, e.g., constant vector.",
331
+ )
332
+ parser.add_argument(
333
+ "--sid-shuffle-encoder-input",
334
+ action="store_true",
335
+ help="If set, shuffle encoder input in time.",
336
+ )
337
+ parser.add_argument(
338
+ "--sid-decoder-speaker",
339
+ action="store_true",
340
+ help="If set, apply speaker decoder as transformer decoder.",
341
+ )
342
+ parser.add_argument(
343
+ "--sid-decoder-attn-dim",
344
+ default=128,
345
+ type=int,
346
+ help="Attention dimension in attensive statistics pooling of speaker decoder.",
347
+ )
348
+ parser.add_argument(
349
+ "--sid-t5-postnet",
350
+ action="store_true",
351
+ help="If set, apply TextDecoderPostnet as speaker classification.",
352
+ )
353
+ parser.add_argument(
354
+ "--sid-embed-dim",
355
+ default=128,
356
+ type=int,
357
+ help="Embedding dimension in speaker postnet for speaker identification if embed postnet.",
358
+ )
359
+ parser.add_argument(
360
+ "--sid-pooling-layer",
361
+ default="decoder",
362
+ type=str,
363
+ choices=["decoder-las", "decoder", "encoder", "encoder-cls", "encoder-speaker"],
364
+ help="The output of decoder or encoder uses as SID pooling layer over temporal dimension.",
365
+ )
366
+ parser.add_argument(
367
+ "--sid-no-pooling-bn",
368
+ action="store_true",
369
+ help="If set, not attention batchnorm.",
370
+ )
371
+ parser.add_argument(
372
+ "--sid-no-embed-postnet",
373
+ action="store_true",
374
+ help="If set, no layer between decoder output and classification layer.",
375
+ )
376
+ parser.add_argument(
377
+ "--sid-normalize-postnet",
378
+ action="store_true",
379
+ help="If set, normalize input and weight in postnet/classifier.",
380
+ )
381
+ parser.add_argument(
382
+ "--sid-softmax-type",
383
+ default="softmax",
384
+ choices=["softmax", "amsoftmax", "aamsoftmax"],
385
+ help="If using amsoftmax or aamsoftmax, the target should be given.",
386
+ )
387
+ parser.add_argument(
388
+ "--softmax-scale",
389
+ default=1.0,
390
+ type=float,
391
+ help="Scale for AMSoftmax or AAMSoftmax.",
392
+ )
393
+ parser.add_argument(
394
+ "--softmax-margin",
395
+ default=0.0,
396
+ type=float,
397
+ help="Margin for AMSoftmax or AAMSoftmax.",
398
+ )
399
+ parser.add_argument(
400
+ "--softmax-easy-margin",
401
+ action="store_true",
402
+ help="Enable easy margin for AAMSoftmax.",
403
+ )
404
+ parser.add_argument(
405
+ "--encoder-layerdrop",
406
+ type=float,
407
+ metavar="D",
408
+ help="LayerDrop probability for encoder",
409
+ )
410
+ parser.add_argument(
411
+ "--decoder-layerdrop",
412
+ type=float,
413
+ metavar="D",
414
+ help="LayerDrop probability for decoder",
415
+ )
416
+
417
+ ## Hubert
418
+ parser.add_argument(
419
+ '--feature-grad-mult',
420
+ type=float,
421
+ help='multiply feature extractor var grads by this'
422
+ )
423
+ parser.add_argument(
424
+ '--logit-temp',
425
+ type=float,
426
+ help='temperature to divide logits by'
427
+ )
428
+ parser.add_argument(
429
+ '--final-dim',
430
+ type=int,
431
+ help="project final representations and targets to this many "
432
+ "dimensions. set to encoder_embed_dim is <= 0"
433
+ )
434
+
435
+ # mask
436
+ parser.add_argument(
437
+ '--hubert-mask-length',
438
+ type=int,
439
+ help='mask length'
440
+ )
441
+ parser.add_argument(
442
+ '--mask-prob',
443
+ type=float,
444
+ help='probability of replacing a token with mask'
445
+ )
446
+ parser.add_argument(
447
+ "--mask-selection",
448
+ choices=["static", "uniform", "normal", "poisson"],
449
+ help="how to choose mask length",
450
+ )
451
+ parser.add_argument(
452
+ '--mask-other',
453
+ type=float,
454
+ help="secondary mask argument "
455
+ "(used for more complex distributions), "
456
+ "see help in compute_mask_indices"
457
+ )
458
+ parser.add_argument(
459
+ '--mask-min-space',
460
+ type=int,
461
+ help='min space between spans (if no overlap is enabled)'
462
+ )
463
+
464
+ # channel masking
465
+ parser.add_argument(
466
+ '--mask-channel-length',
467
+ type=int,
468
+ help='length of the mask for features (channels)'
469
+ )
470
+ parser.add_argument(
471
+ '--mask-channel-prob',
472
+ type=float,
473
+ help="probability of replacing a feature with 0"
474
+ )
475
+ parser.add_argument(
476
+ "--mask-channel-selection",
477
+ choices=["static", "uniform", "normal", "poisson"],
478
+ help="how to choose mask length for channel masking",
479
+ )
480
+ parser.add_argument(
481
+ '--mask-channel-other',
482
+ type=float,
483
+ help="secondary mask argument "
484
+ "(used for more complex distributions), "
485
+ "see help in compute_mask_indices"
486
+ )
487
+ parser.add_argument(
488
+ '--mask-channel-min-space',
489
+ type=int,
490
+ help='min space between spans (if no overlap is enabled)'
491
+ )
492
+
493
+ # abs positional embeddings
494
+ parser.add_argument(
495
+ '--conv-pos',
496
+ type=int,
497
+ help='number of filters for convolutional positional embeddings'
498
+ )
499
+ parser.add_argument(
500
+ '--conv-pos-groups',
501
+ type=int,
502
+ help='number of groups for convolutional positional embedding'
503
+ )
504
+
505
+ # codebook related
506
+ parser.add_argument(
507
+ "--use-codebook",
508
+ action="store_true",
509
+ help="whether to use codebook",
510
+ )
511
+ parser.add_argument(
512
+ "--codebook-prob",
513
+ type=float,
514
+ help="probability to use codebook",
515
+ )
516
+ parser.add_argument(
517
+ "--latent-vars",
518
+ type=int,
519
+ help="number of latent variables V in each group of the codebook",
520
+ )
521
+ parser.add_argument(
522
+ "--latent-groups",
523
+ type=int,
524
+ help="number of groups G of latent variables in the codebook",
525
+ )
526
+ parser.add_argument(
527
+ "--latent-dim",
528
+ type=int,
529
+ help="if > 0, uses this dimensionality for latent variables. "
530
+ "otherwise uses final_dim / latent_groups",
531
+ )
532
+ parser.add_argument(
533
+ "--latent-temp",
534
+ type=literal_eval,
535
+ help="temperature for latent variable sampling. "
536
+ "can be tuple of 3 values (start, end, decay)",
537
+ )
538
+ parser.add_argument(
539
+ "--quantizer-depth",
540
+ type=int,
541
+ help="number of quantizer layers",
542
+ )
543
+ parser.add_argument(
544
+ "--quantizer-factor",
545
+ type=int,
546
+ help="number of quantizer layers",
547
+ )
548
+ parser.add_argument(
549
+ "--get-code-distribution",
550
+ action='store_true',
551
+ help="whether to get the code distribution (for test)",
552
+ )
553
+
554
+ # relative pos enc
555
+ parser.add_argument(
556
+ "--relative-position-embedding",
557
+ action='store_true',
558
+ help="whether to use relative position embedding",
559
+ )
560
+ parser.add_argument(
561
+ "--num-buckets",
562
+ type=int,
563
+ default=320,
564
+ help="num of buckets for relative position embedding",
565
+ )
566
+ parser.add_argument(
567
+ "--max-distance",
568
+ type=int,
569
+ default=1280,
570
+ help="max distance for relative position embedding",
571
+ )
572
+ parser.add_argument(
573
+ "--encoder-max-relative-position",
574
+ type=int,
575
+ help="max distance for relative position embedding in encoder",
576
+ )
577
+ parser.add_argument(
578
+ "--decoder-max-relative-position",
579
+ type=int,
580
+ help="max distance for relative position embedding in decoder",
581
+ )
582
+
583
+ # hubert feature extractor
584
+ parser.add_argument(
585
+ "--conv-feature-layers",
586
+ type=str,
587
+ help= "string describing convolutional feature extraction "
588
+ "layers in form of a python list that contains "
589
+ "[(dim, kernel_size, stride), ...]",
590
+ )
591
+ parser.add_argument(
592
+ "--conv-bias",
593
+ action='store_true',
594
+ help="include bias in conv encoder",
595
+ )
596
+ parser.add_argument(
597
+ "--extractor-mode",
598
+ choices=["default", "layer_norm"],
599
+ help="mode for feature extractor. default has a single group "
600
+ "norm with d groups in the first conv block, whereas layer_norm "
601
+ "has layer norms in every block (meant to use with normalize=True)"
602
+ )
603
+
604
+ # others
605
+ parser.add_argument(
606
+ "--bert-init",
607
+ action='store_true',
608
+ help="initilize as bert",
609
+ )
610
+ parser.add_argument(
611
+ "--unb-enc-layer",
612
+ type=int,
613
+ default=-1,
614
+ help="which layer's output is used as the input of decoder",
615
+ )
616
+
617
+ # Encoder, Decoder
618
+ @classmethod
619
+ def build_encoder(cls, args, dictionary=None, embed_tokens=None):
620
+ return TransformerEncoder(args, dictionary, embed_tokens)
621
+
622
+ @classmethod
623
+ def build_decoder(cls, args):
624
+ return TransformerDecoder(args)
625
+
626
+ # Encoder Prenet
627
+ @classmethod
628
+ def build_text_encoder_prenet(cls, embed_tokens, args):
629
+ return TextEncoderPrenet(embed_tokens, args)
630
+
631
+ @classmethod
632
+ def build_speech_encoder_prenet(cls, args):
633
+ return SpeechEncoderPrenet(args)
634
+
635
+ # Decoder Prenet
636
+ @classmethod
637
+ def build_text_decoder_prenet(cls, embed_tokens, args):
638
+ return TextDecoderPrenet(embed_tokens, args)
639
+
640
+ @classmethod
641
+ def build_speech_decoder_prenet(cls, odim, args):
642
+ return SpeechDecoderPrenet(odim, args)
643
+
644
+ # Decoder Postnet
645
+ @classmethod
646
+ def build_text_decoder_postnet(cls, embed_tokens, dictionary, args):
647
+ return TextDecoderPostnet(embed_tokens, dictionary, args)
648
+
649
+ @classmethod
650
+ def build_speaker_decoder_postnet(cls, embed_dim, class_num, args):
651
+ return SpeakerDecoderPostnet(embed_dim, class_num, args)
652
+
653
+ @classmethod
654
+ def build_speech_decoder_postnet(cls, odim, args):
655
+ return SpeechDecoderPostnet(odim, args)
656
+
657
+ @classmethod
658
+ def build_speech_encoder_postnet(cls, dictionaries, args):
659
+ return SpeechEncoderPostnet(dictionaries, args)
660
+
661
+ @classmethod
662
+ def build_model(cls, args, task):
663
+ """Build a new model instance."""
664
+
665
+ # make sure all arguments are present in older models
666
+ base_architecture(args)
667
+
668
+ def build_embedding(dictionary, embed_dim, max_num_embeddings=None):
669
+ num_embeddings = len(dictionary)
670
+ if max_num_embeddings is not None and isinstance(max_num_embeddings, int):
671
+ num_embeddings = min(num_embeddings, max_num_embeddings)
672
+ padding_idx = dictionary.pad()
673
+ return Embedding(num_embeddings, embed_dim, padding_idx)
674
+
675
+ if hasattr(args, "sid_pad_prenet") and args.sid_pad_prenet:
676
+ max_num_embeddings = 3 # <pad> at index 2
677
+ else:
678
+ max_num_embeddings = None
679
+
680
+ text_decoder_embed_tokens = build_embedding(
681
+ task.dicts["text"], args.decoder_embed_dim, max_num_embeddings
682
+ )
683
+
684
+ if args.share_input_output_embed:
685
+ text_encoder_embed_tokens = text_decoder_embed_tokens
686
+ else:
687
+ text_encoder_embed_tokens = build_embedding(
688
+ task.dicts["text"], args.encoder_embed_dim
689
+ )
690
+
691
+ speech_odim = args.speech_odim
692
+ if "text" in task.dicts:
693
+ encoder = cls.build_encoder(args, task.dicts["text"], text_encoder_embed_tokens)
694
+ else:
695
+ encoder = cls.build_encoder(args)
696
+ decoder = cls.build_decoder(args)
697
+
698
+ text_encoder_prenet = cls.build_text_encoder_prenet(text_encoder_embed_tokens, args)
699
+ speech_encoder_prenet = cls.build_speech_encoder_prenet(args)
700
+
701
+ text_decoder_prenet = cls.build_text_decoder_prenet(text_decoder_embed_tokens, args)
702
+ if getattr(args, "sid_pooling_layer", None) == "decoder-las":
703
+ speech_decoder_prenet = cls.build_speech_encoder_prenet(args)
704
+ else:
705
+ speech_decoder_prenet = cls.build_speech_decoder_prenet(speech_odim, args)
706
+
707
+ text_decoder_postnet = cls.build_text_decoder_postnet(text_decoder_embed_tokens, task.dicts['text'], args)
708
+ speech_decoder_postnet = cls.build_speech_decoder_postnet(speech_odim, args)
709
+
710
+ if getattr(args, "sid_t5_postnet", False):
711
+ speaker_decoder_postnet = None
712
+ else:
713
+ if task.t5_task == "s2c":
714
+ speaker_decoder_postnet = cls.build_speaker_decoder_postnet(args.sid_embed_dim, len(task.dicts['text']), args)
715
+ else:
716
+ speaker_decoder_postnet = None
717
+
718
+ if "hubert" in task.dicts:
719
+ speech_encoder_postnet = cls.build_speech_encoder_postnet(task.dicts['hubert'], args)
720
+ else:
721
+ speech_encoder_postnet = None
722
+
723
+ return cls(
724
+ args,
725
+ encoder, decoder,
726
+ text_encoder_prenet, speech_encoder_prenet,
727
+ text_decoder_prenet, speech_decoder_prenet,
728
+ text_decoder_postnet, speech_decoder_postnet,
729
+ speaker_decoder_postnet, speech_encoder_postnet,
730
+ )
731
+
732
+ def get_normalized_probs(
733
+ self,
734
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
735
+ log_probs: bool,
736
+ sample: Optional[Dict[str, Tensor]] = None,
737
+ ):
738
+ # net_output['encoder_out'] is a (B, T, D) tensor
739
+ lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
740
+ lprobs.batch_first = True
741
+ return lprobs
742
+
743
+ def get_normalized_probs_for_ctc(self, net_output, log_probs):
744
+ """Get normalized probabilities (or log probs) from a net's output."""
745
+
746
+ logits = net_output["encoder_out_for_ctc"][0]
747
+ if log_probs:
748
+ return utils.log_softmax(logits.float(), dim=-1)
749
+ else:
750
+ return utils.softmax(logits.float(), dim=-1)
751
+
752
+ def get_logits(self, net_output, is_masked=True):
753
+ if is_masked:
754
+ logits_list = net_output["logit_m_list"]
755
+ else:
756
+ logits_list = net_output["logit_u_list"]
757
+ logits_list = [x.float() for x in logits_list if x is not None]
758
+ return logits_list
759
+
760
+ def get_targets(self, sample, net_output, is_masked=True):
761
+ if "logit_m_list" in net_output:
762
+ logits_list = self.get_logits(net_output, is_masked)
763
+ targets_list = [
764
+ x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list
765
+ ]
766
+ return targets_list
767
+ else:
768
+ return sample["target"]
769
+
770
+ def get_extra_losses(self, net_output):
771
+ extra_losses = []
772
+ names = []
773
+
774
+ if "features_pen" in net_output:
775
+ extra_losses.append(net_output["features_pen"])
776
+ names.append("features_pen")
777
+
778
+ if "prob_perplexity" in net_output:
779
+ extra_losses.append(
780
+ (net_output["num_vars"] - net_output["prob_perplexity"])
781
+ / net_output["num_vars"]
782
+ )
783
+ names.append("prob_perplexity")
784
+
785
+ return extra_losses, names
786
+
787
+ def forward(self, source=None, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_lengths=None, spkembs=None, target_list=None, task_name=None, padding_mask=None, only_hubert=False, only_ctc=False, feature_only=False, tgt_enc_layer=None, mask=True):
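The features_pen and prob_perplexity terms collected above feed auxiliary losses; the prob_perplexity one is a codebook diversity penalty of the form (num_vars - prob_perplexity) / num_vars. A minimal sketch of that quantity (the helper name and the 320-entry codebook are assumptions for illustration):

import torch

def diversity_penalty(avg_probs, num_vars):
    # avg_probs: batch-averaged softmax over codebook entries; the penalty is ~0 when
    # all codes are used uniformly and approaches 1 when the codebook collapses.
    prob_perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7)))
    return (num_vars - prob_perplexity) / num_vars

uniform = torch.full((320,), 1.0 / 320)
collapsed = torch.zeros(320); collapsed[0] = 1.0
print(diversity_penalty(uniform, 320))    # close to 0
print(diversity_penalty(collapsed, 320))  # close to (320 - 1) / 320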
788
+ """
789
+ The forward method inherited from the base class has a **kwargs
790
+ argument in its input, which is not supported in torchscript. This
791
+ method overwrites the forward method definition without **kwargs.
792
+ """
793
+ assert source is not None or src_tokens is not None
794
+ # padding_mask is not none only when input is waveform
795
+ if source is None and padding_mask is None and not feature_only:
796
+ input_type = 'text'
797
+ else:
798
+ input_type = 'speech'
799
+
800
+ if prev_output_tokens is not None and len(prev_output_tokens.size()) == 2:
801
+ output_type = 'text'
802
+ codebook_out = {}
803
+ else:
804
+ output_type = 'speech'
805
+
806
+ if task_name is not None and task_name == "s2c":
807
+ if target_list is not None and target_list.size(1) == 1 and not getattr(self.args, "sid_t5_postnet", False):
808
+ sid_target = F.one_hot(target_list.squeeze(1), num_classes=self.speaker_decoder_postnet.class_num)
809
+ else:
810
+ sid_target = None
811
+ target_list = None
812
+
813
+ # Encoder Prenet
814
+ if input_type == 'text':
815
+ encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
816
+ else:
817
+ if target_list is not None:
818
+ encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, require_feat_pen=True, target_list=target_list, padding_mask=padding_mask, mask=mask)
819
+ encoder_input, features_pen, mask_indices, target_list = encoder_input
820
+ else:
821
+ encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=self.training)
822
+ # shuffle a batch of inputs of encoder
823
+ if self.training and hasattr(self.args, "sid_shuffle_encoder_input") and getattr(self.args, "sid_shuffle_encoder_input", False):
824
+ shuffle_index = torch.randperm(encoder_padding_mask.size(1), device=encoder_padding_mask.device)
825
+ encoder_input = torch.index_select(encoder_input, 1, shuffle_index)
826
+ encoder_padding_mask = torch.index_select(encoder_padding_mask, 1, shuffle_index)
827
+ if getattr(self.args, "sid_encoder_cls", None) == "encoder":
828
+ prev_output_tokens = torch.zeros_like(prev_output_tokens)
829
+ encoder_input, encoder_padding_mask = self._integrate_with_speaker_cls(prev_output_tokens, encoder_input, encoder_padding_mask)
830
+
831
+ # Encoder: T x B x C
832
+ encoder_output = self.encoder(encoder_input, encoder_padding_mask, tgt_layer=tgt_enc_layer)
833
+
834
+ if task_name is not None and task_name == 'speech_pretrain' and feature_only:
835
+ return encoder_output["encoder_out"][0].transpose(0, 1)
836
+
837
+ if task_name is not None and task_name == 's2c':
838
+ if self.args.sid_pooling_layer == "encoder":
839
+ return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1).mean(1), sid_target), None
840
+ elif self.args.sid_pooling_layer == "encoder-cls":
841
+ return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1)[:,0], sid_target), None
842
+ elif self.args.sid_pooling_layer == "encoder-speaker" or getattr(self.args, "sid_decoder_speaker", False):
843
+ return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1), sid_target), None
844
+
845
+ if target_list is not None:
846
+ hubert_results = self.hubert_layer(
847
+ encoder_output["encoder_out"][0].transpose(0, 1),
848
+ encoder_padding_mask,
849
+ mask_indices,
850
+ target_list
851
+ )
852
+
853
+ hubert_results['features_pen'] = features_pen
854
+
855
+ if "decoder_input" in encoder_output and encoder_output["decoder_input"][0] is not None:
856
+ # Change the encoder output to decoder input once set unb-enc-layer
857
+ encoder_output["encoder_out"] = encoder_output["decoder_input"]
858
+
859
+ if self.use_codebook:
860
+ q = self.quantizer(encoder_output["encoder_out"][0].transpose(0, 1))
861
+
862
+ # q["x"]: B x T x C
863
+ # Sample indices according to the codebook prob
864
+ random_idx = torch.randperm(q["x"].size(1))[:int(q["x"].size(1) * self.codebook_prob)]
865
+ # Make weight for q
866
+ q_w = q["x"].new_zeros(q["x"].size(1))
867
+ q_w[random_idx] = 1.0
868
+ # Combine quantized codes and encoder output
869
+ encoder_output["encoder_out"][0] = (
870
+ q_w.view(-1, 1) * q["x"] + (- q_w + 1).view(-1, 1) * encoder_output["encoder_out"][0].transpose(0, 1)
871
+ ).transpose(0, 1)
872
+
873
+ # encoder_output["encoder_out"][0] = q["x"].transpose(0, 1)
874
+ if output_type == 'speech':
875
+ hubert_results["prob_perplexity"] = q["prob_perplexity"]
876
+ hubert_results["code_perplexity"] = q["code_perplexity"]
877
+ hubert_results["num_vars"] = q["num_vars"]
878
+ hubert_results["temp"] = q["temp"]
879
+ elif output_type == 'text':
880
+ codebook_out["prob_perplexity"] = q["prob_perplexity"]
881
+ codebook_out["code_perplexity"] = q["code_perplexity"]
882
+ codebook_out["num_vars"] = q["num_vars"]
883
+ codebook_out["temp"] = q["temp"]
884
+
885
+ if only_hubert and target_list is not None:
886
+ return hubert_results, None
887
+
888
+ if only_ctc and task_name is not None and task_name == "s2t":
889
+ return None, encoder_output
890
+ elif not self.training and prev_output_tokens is None and task_name == "s2t" and task_name is not None:
891
+ return encoder_output
892
+
893
+ # Decoder Prenet
894
+ if output_type == 'text':
895
+ # _ is the incremental state
896
+ prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens)
897
+ if task_name is not None and task_name == 's2c':
898
+ prev_output_tokens = torch.zeros_like(prev_output_tokens)
899
+ else:
900
+ # integrate speaker embedding
901
+ if self.spk_embed_integration_type == "pre" and self.spk_embed_dim is not None:
902
+ # Decoder Prenet
903
+ prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths, spkembs)
904
+ else:
905
+ if self.spk_embed_dim is not None:
906
+ encoder_output["encoder_out"] = [self._integrate_with_spk_embed(
907
+ encoder_output["encoder_out"][0].transpose(0, 1), spkembs
908
+ ).transpose(0, 1)]
909
+
910
+ prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths)
911
+
912
+ # BART Sequence Classification: cat <pad> + feature before decoder
913
+ if task_name is not None and task_name == 's2c' and self.args.sid_pooling_layer == "decoder-las":
914
+ decoder_feat_input, decoder_feat_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
915
+ prev_output_tokens, tgt_mask = self._integrate_with_speaker_cls((prev_output_tokens, tgt_mask), decoder_feat_input, decoder_feat_mask, cls_first=False)
916
+
917
+ # SE predict masking to corresponding inputs and source speech replaces the prev_output_tokens as the input of decoder
918
+ if task_name is not None and task_name == "s2s" and getattr(self.args, "se_decoder_input", "previous_target") == "source":
919
+ prev_output_tokens, tgt_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
920
+
921
+ # Decoder
922
+ decoder_output, extra = self.decoder(prev_output_tokens, tgt_mask, encoder_output,
923
+ full_context_alignment=getattr(self.args, "decoder_full_context_alignment", False),
924
+ alignment_layer=(-1 if target_list is None and output_type == 'speech' else None))
925
+ # Decoder Postnet
926
+ if task_name is not None and task_name == 's2c':
927
+ if not getattr(self.args, "sid_t5_postnet", False):
928
+ if self.args.sid_pooling_layer == "decoder":
929
+ return self.speaker_decoder_postnet(decoder_output.mean(1), sid_target), None
930
+ elif self.args.sid_pooling_layer == "decoder-las":
931
+ indices = (tgt_mask.eq(False).float().sum(1) - 1.0).type(torch.int64)
932
+ indices = indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, decoder_output.size(2))
933
+ return self.speaker_decoder_postnet(decoder_output.gather(1, indices), sid_target), None
934
+ else:
935
+ return (self.text_decoder_postnet(decoder_output), None), encoder_output
936
+
937
+ # SE predict: masking, target, delta. Ensure reduction factor 1
938
+ if task_name is not None and task_name == 's2s' and getattr(self.args, "se_predict", None) is not None:
939
+ assert self.reduction_factor == 1, f"{self.reduction_factor} != 1"
940
+ before_outs, after_outs, logits = self.speech_decoder_postnet(decoder_output)
941
+ se_predict = getattr(self.args, "se_predict")
942
+ if se_predict == "masking":
943
+ before_outs = torch.sigmoid(before_outs) * src_tokens
944
+ after_outs = torch.sigmoid(after_outs) * src_tokens
945
+ return before_outs, after_outs, logits, extra['attn'][0]
946
+ elif se_predict == "target":
947
+ return before_outs, after_outs, logits, extra['attn'][0]
948
+ elif se_predict == "delta":
949
+ before_outs = before_outs - src_tokens
950
+ after_outs = after_outs - src_tokens
951
+ return before_outs, after_outs, logits, extra['attn'][0]
952
+ else:
953
+ raise ValueError(f"{se_predict} not in [masking, target, delta]")
954
+
955
+ if task_name is not None and task_name == 's2t':
956
+ #return self.text_decoder_postnet(decoder_output), None
957
+ return (self.text_decoder_postnet(decoder_output), None), encoder_output
958
+ if output_type == 'text':
959
+ return (self.text_decoder_postnet(decoder_output), None), codebook_out, encoder_output
960
+ else:
961
+ if target_list is not None:
962
+ return hubert_results, (self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],))
963
+ else:
964
+ return self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],)
965
+
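One step of forward above that is easy to miss is the use_codebook branch, which swaps the encoder states at a random subset of time steps for their quantized counterparts. A standalone sketch of that mixing, assuming (B, T, C) tensors (the helper name and shapes are illustrative):

import torch

def mix_with_codes(enc_btc, quant_btc, codebook_prob=0.5):
    # Pick a random fraction of time steps and substitute the quantized codes
    # for the encoder states at those positions; the rest pass through unchanged.
    T = enc_btc.size(1)
    random_idx = torch.randperm(T)[: int(T * codebook_prob)]
    w = enc_btc.new_zeros(T)
    w[random_idx] = 1.0
    w = w.view(1, -1, 1)
    return w * quant_btc + (1.0 - w) * enc_btc

enc = torch.randn(2, 10, 768)
quant = torch.randn(2, 10, 768)
mixed = mix_with_codes(enc, quant, 0.5)  # half the frames come from the quantizer
print(mixed.shape)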
966
+ def _integrate_with_speaker_cls(self, pad_input, encoder_input, encoder_padding_mask=None, cls_first=True):
967
+ """
968
+ encoder_input: [B, T, C]
969
+ encoder_padding_mask: [B, T]
970
+ """
971
+ if hasattr(self, "text_decoder_prenet"):
972
+ if isinstance(pad_input, tuple):
973
+ repeat_cls_vector, repeat_cls_mask = pad_input
974
+ else:
975
+ repeat_cls_vector, repeat_cls_mask, _ = self.text_decoder_prenet(pad_input)
976
+
977
+ if encoder_padding_mask is not None:
978
+ bsz = encoder_input.size(0)
979
+ tsz = encoder_input.size(1)
980
+ encoder_padding_mask = encoder_input.new_zeros((bsz, tsz)) == 1.0
981
+ if repeat_cls_mask is None:
982
+ mask_size = (encoder_padding_mask.size(0), 1)
983
+ mask_type = encoder_padding_mask.dtype
984
+ repeat_cls_mask = encoder_padding_mask.new_zeros(mask_size) == 1.0
985
+ ret_encoder_padding_mask = torch.cat([repeat_cls_mask, encoder_padding_mask], dim=1)
986
+
987
+ if cls_first:
988
+ ret_encoder_input = torch.cat([repeat_cls_vector, encoder_input], dim=1)
989
+ else:
990
+ ret_encoder_input = torch.cat([encoder_input, encoder_input[:,-1:,:]], dim=1)
991
+ mask_size = (encoder_padding_mask.size(0), 1)
992
+ mask_type = encoder_padding_mask.dtype
993
+ repeat_cls_mask_ = encoder_padding_mask.new_ones(mask_size) == 1.0
994
+ encoder_padding_mask_ = torch.cat([encoder_padding_mask, repeat_cls_mask_], dim=1)
995
+ indices = encoder_padding_mask.eq(False).float().sum(1).type(torch.int64).unsqueeze(1)
996
+ indices_mask = torch.zeros_like(ret_encoder_padding_mask).scatter(1, indices, 1.0)
997
+ ret_encoder_input = ret_encoder_input * (1.0 - encoder_padding_mask_.type(ret_encoder_input.dtype).unsqueeze(2)) \
998
+ + repeat_cls_vector * indices_mask.type(repeat_cls_vector.dtype).unsqueeze(2)
999
+
1000
+ return ret_encoder_input, ret_encoder_padding_mask
1001
+
1002
+ def _integrate_with_spk_embed(self, hs, spembs):
1003
+ """Integrate speaker embedding with hidden states.
1004
+ Args:
1005
+ hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
1006
+ spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
1007
+ Returns:
1008
+ Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)
1009
+ """
1010
+ if self.spk_embed_integration_type == "add":
1011
+ # apply projection and then add to hidden states
1012
+ spembs = self.projection(F.normalize(spembs))
1013
+ hs = hs + spembs.unsqueeze(1)
1014
+ elif self.spk_embed_integration_type == "concat":
1015
+ # concat hidden states with spk embeds and then apply projection
1016
+ spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
1017
+ hs = self.projection(torch.cat([hs, spembs], dim=-1))
1018
+ else:
1019
+ raise NotImplementedError("support only add or concat.")
1020
+
1021
+ return hs
1022
+
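To make the two integration modes above concrete, here is a hedged sketch with dummy shapes; the two Linear layers stand in for self.projection and are not the trained weights:

import torch
import torch.nn.functional as F

B, T, adim, spk_dim = 2, 50, 768, 512
hs = torch.randn(B, T, adim)
spembs = torch.randn(B, spk_dim)

# "add": project the normalized speaker embedding to adim and add it to every frame
proj_add = torch.nn.Linear(spk_dim, adim)
hs_add = hs + proj_add(F.normalize(spembs)).unsqueeze(1)

# "concat": tile the speaker embedding along time, concatenate, and project back to adim
proj_cat = torch.nn.Linear(adim + spk_dim, adim)
tiled = F.normalize(spembs).unsqueeze(1).expand(-1, T, -1)
hs_cat = proj_cat(torch.cat([hs, tiled], dim=-1))

assert hs_add.shape == hs_cat.shape == (B, T, adim)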
1023
+ def load_state_dict(
1024
+ self,
1025
+ state_dict,
1026
+ strict=True,
1027
+ model_cfg=None,
1028
+ args=None,
1029
+ ):
1030
+ """NOT STRICT Copies parameters and buffers from *state_dict* into this module and
1031
+ its descendants.
1032
+
1033
+ Overrides the method in :class:`nn.Module`. Compared with that method
1034
+ this additionally "upgrades" *state_dicts* from old checkpoints.
1035
+ """
1036
+ # self.prune_modules(model_cfg.modules_filter)
1037
+ model_dict_size = self.text_decoder_postnet.output_projection.out_features
1038
+ ckpt_dict_size = state_dict["text_decoder_postnet.output_projection.weight"].size(0)
1039
+ if model_dict_size != ckpt_dict_size:
1040
+ # reset dictionary-related modules, such as embedding table and encoder ctc embed
1041
+ logger.warn(f"not equal dictionary between model and checkpoint: {model_dict_size} vs {ckpt_dict_size}")
1042
+ logger.info(f"reset model dictionary with size of {model_dict_size}")
1043
+ removed_keys = [
1044
+ key for key in state_dict.keys() if any(
1045
+ key.startswith(previ) for previ in [
1046
+ "encoder.proj", "text_encoder_prenet", "text_decoder_prenet", "text_decoder_postnet"
1047
+ ]
1048
+ )
1049
+ ]
1050
+ for key in removed_keys:
1051
+ state_dict.pop(key, None)
1052
+ logger.info(f"removed loaded checkpoint: {key}")
1053
+ for m in self._modules.keys():
1054
+ m_state_dict = {
1055
+ key.replace(f"{m}.", ""): value for key, value in state_dict.items() if key.startswith(f"{m}.")
1056
+ }
1057
+ if hasattr(self, m):
1058
+ self._modules[m].load_state_dict(m_state_dict, False)
1059
+ return self
1060
+
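The per-module loop at the end of load_state_dict boils down to the following non-strict loading pattern (a sketch; the helper name and the toy model are illustrative):

import torch

def load_submodules_nonstrict(model, state_dict):
    # For each direct child module, collect the checkpoint keys that belong to it,
    # strip the "<child>." prefix, and load with strict=False so missing or extra
    # keys (e.g., after the dictionary reset above) do not raise.
    for name, child in model.named_children():
        prefix = f"{name}."
        sub_sd = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        child.load_state_dict(sub_sd, strict=False)

toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
ckpt = {"0.weight": torch.zeros(4, 4)}  # partial checkpoint: bias key deliberately missing
load_submodules_nonstrict(toy, ckpt)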
1061
+ def prune_modules(self, modules_filter=None):
1062
+ """Prune unused modules for specific tasks."""
1063
+ if modules_filter is None:
1064
+ return
1065
+ elif modules_filter == "s2c":
1066
+ if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
1067
+ if hasattr(self, "speech_decoder_prenet") and getattr(self.args, "sid_pooling_layer", None) != "decoder-las":
1068
+ del self.speech_decoder_prenet
1069
+ if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
1070
+ if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
1071
+ if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
1072
+ if hasattr(self.encoder, "proj"): self.encoder.proj = None
1073
+ if hasattr(self, "projection"): del self.projection
1074
+ if hasattr(self, "quantizer"): del self.quantizer
1075
+ if getattr(self.args, "sid_pooling_layer", "decoder").startswith("encoder") or getattr(self.args, "sid_decoder_speaker", False):
1076
+ if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
1077
+ if hasattr(self.decoder, "layers"): del self.decoder.layers
1078
+ if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
1079
+ if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
1080
+ elif modules_filter == "s2s":
1081
+ if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
1082
+ if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
1083
+ if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
1084
+ if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
1085
+ if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
1086
+ if hasattr(self.encoder, "proj"): self.encoder.proj = None
1087
+ if hasattr(self, "projection"): del self.projection
1088
+ if hasattr(self, "quantizer"): del self.quantizer
1089
+ elif modules_filter == "t2s":
1090
+ if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
1091
+ if hasattr(self, "speech_encoder_prenet"): del self.speech_encoder_prenet
1092
+ if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
1093
+ if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
1094
+ if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
1095
+ if hasattr(self.encoder, "proj"): self.encoder.proj = None
1096
+ if hasattr(self, "projection"): del self.projection
1097
+ if hasattr(self, "quantizer"): del self.quantizer
1098
+ elif modules_filter == "s3prl":
1099
+ # remain the encoder and the pre/post net
1100
+ if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
1101
+ if hasattr(self.decoder, "layers"): del self.decoder.layers
1102
+ if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
1103
+ if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
1104
+ if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
1105
+ if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
1106
+ if hasattr(self, "speech_decoder_prenet"): del self.speech_decoder_prenet
1107
+ if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
1108
+ if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
1109
+ if hasattr(self.encoder, "proj"): self.encoder.proj = None
1110
+ if hasattr(self, "projection"): del self.projection
1111
+ if hasattr(self, "quantizer"): del self.quantizer
1112
+
1113
+ def forward_encoder_torchscript(self, net_input: Dict[str, Tensor]):
1114
+ """A TorchScript-compatible version of forward.
1115
+
1116
+ Encoders which use additional arguments may want to override
1117
+ this method for TorchScript compatibility.
1118
+ """
1119
+ if torch.jit.is_scripting():
1120
+ return self.forward_encoder(
1121
+ source=net_input["source"],
1122
+ padding_mask=net_input["padding_mask"]
1123
+ )
1124
+ else:
1125
+ return self.forward_encoder_non_torchscript(net_input)
1126
+
1127
+ @torch.jit.unused
1128
+ def forward_encoder_non_torchscript(self, net_input: Dict[str, Tensor]):
1129
+ encoder_input = {
1130
+ k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
1131
+ }
1132
+ return self.forward_encoder(**encoder_input)
1133
+
1134
+ def forward_encoder(self, source, padding_mask=None):
1135
+ # Encoder Prenet
1136
+ encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=False)
1137
+
1138
+ # Encoder
1139
+ encoder_output = self.encoder(encoder_input, encoder_padding_mask)
1140
+
1141
+ return encoder_output
1142
+
1143
+ def forward_text_encoder(self, src_tokens):
1144
+ # Text Encoder Prenet
1145
+ encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
1146
+
1147
+ # Encoder
1148
+ encoder_output = self.encoder(encoder_input, encoder_padding_mask)
1149
+
1150
+ return encoder_output
1151
+
1152
+ def forward_decoder(self, tokens, encoder_out, incremental_state):
1153
+ # Decoder Prenet
1154
+ prev_output_tokens, tgt_mask, incremental_state = self.text_decoder_prenet(tokens, incremental_state)
1155
+
1156
+ # Decoder
1157
+ decoder_output, extra = self.decoder(
1158
+ prev_output_tokens,
1159
+ tgt_mask,
1160
+ encoder_out=encoder_out,
1161
+ incremental_state=incremental_state,
1162
+ )
1163
+
1164
+ # Decoder Postnet
1165
+ return self.text_decoder_postnet(decoder_output), extra
1166
+
1167
+ def set_num_updates(self, num_updates):
1168
+ """Set the number of parameters updates."""
1169
+ super().set_num_updates(num_updates)
1170
+ self.num_updates = num_updates
1171
+
1172
+ def generate_class(self, source, prev_output_tokens, **kwargs):
1173
+ encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
1174
+
1175
+ prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens, {})
1176
+ prev_output_tokens = torch.zeros_like(prev_output_tokens) # s2c use zero vector as [CLS]
1177
+
1178
+ decoder_output, extra = self.decoder(
1179
+ prev_output_tokens,
1180
+ tgt_mask,
1181
+ encoder_out=encoder_out,
1182
+ )
1183
+
1184
+ decoder_out, embed = self.speaker_decoder_postnet(decoder_output.mean(1))
1185
+
1186
+ pred_class = decoder_out.argmax(1)
1187
+ return pred_class
1188
+
1189
+ def generate_speech(self, source=None, src_tokens=None, spkembs=None, **kwargs):
1190
+ assert source is not None or src_tokens is not None
1191
+
1192
+ threshold = kwargs.get("threshold", 0.5)
1193
+ minlenratio = kwargs.get("threshold", 0.0)
1194
+
1195
+ if source is None:
1196
+ assert src_tokens.size(0) == 1
1197
+ encoder_out = self.forward_text_encoder(src_tokens)
1198
+ maxlenratio = kwargs.get("threshold", 20.0)
1199
+ else:
1200
+ assert source.size(0) == 1
1201
+ encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
1202
+ maxlenratio = kwargs.get("threshold", 10.0)
1203
+
1204
+ if spkembs is not None and self.spk_embed_integration_type != "pre":
1205
+ encoder_out["encoder_out"] = [self._integrate_with_spk_embed(
1206
+ encoder_out["encoder_out"][0].transpose(0, 1), spkembs
1207
+ ).transpose(0, 1)]
1208
+ spkembs = None
1209
+
1210
+ maxlen = int(encoder_out["encoder_out"][0].size(0) * maxlenratio / self.reduction_factor)
1211
+ minlen = int(encoder_out["encoder_out"][0].size(0) * minlenratio / self.reduction_factor)
1212
+
1213
+ idx = 0
1214
+ ys = encoder_out["encoder_out"][0].new_zeros(1, 1, self.speech_decoder_postnet.odim)
1215
+ outs, probs = [], []
1216
+
1217
+ # forward decoder step-by-step
1218
+ if isinstance(self.decoder, FairseqIncrementalDecoder):
1219
+ incremental_states = {}
1220
+ else:
1221
+ incremental_states = None
1222
+ attns = []
1223
+ while True:
1224
+ # update index
1225
+ idx += 1
1226
+ # calculate output and stop prob at idx-th step
1227
+ decoder_in, _ = self.speech_decoder_prenet(ys, spkembs=spkembs)
1228
+ z, extra = self.decoder(decoder_in[:,-1:], None, encoder_out, incremental_states, alignment_layer=-1)
1229
+ outs += [self.speech_decoder_postnet.feat_out(z[0, -1]).view(self.reduction_factor, self.speech_decoder_postnet.odim)] # [(r, odim), ...]
1230
+ probs += [torch.sigmoid(self.speech_decoder_postnet.prob_out(z[0, -1]))] # [(r), ...]
1231
+
1232
+ # update next inputs
1233
+ ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.speech_decoder_postnet.odim)), dim=1) # (1, idx + 1, odim)
1234
+ attns.append(torch.stack([att_l[0] for att_l in extra['attn'][0]], dim=0))
1235
+ # check whether to finish generation
1236
+ if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
1237
+ # check minimum length
1238
+ if idx < minlen:
1239
+ continue
1240
+ outs = (torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)) # (L, odim) -> (1, L, odim) -> (1, odim, L)
1241
+ if self.speech_decoder_postnet.postnet is not None:
1242
+ outs = outs + self.speech_decoder_postnet.postnet(outs) # (1, odim, L)
1243
+ outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
1244
+ probs = torch.cat(probs, dim=0)
1245
+ attn = torch.cat(attns, dim=2)
1246
+ break
1247
+
1248
+ if outs.size(0) == maxlen:
1249
+ logging.warning("output length reaches maximum length")
1250
+ return outs, probs, attn
1251
+
1252
+
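As a concrete reading of the decoding loop above: with a 300-frame encoder output, reduction_factor 2 and maxlenratio 10.0 (all illustrative numbers), generation runs for at most 1500 steps and stops earlier once any stop probability crosses the threshold, provided the minimum length has been reached:

import torch

enc_len, reduction_factor = 300, 2
threshold, maxlenratio, minlenratio = 0.5, 10.0, 0.0
maxlen = int(enc_len * maxlenratio / reduction_factor)  # 1500
minlen = int(enc_len * minlenratio / reduction_factor)  # 0

probs = torch.tensor([0.1, 0.7])         # stop probabilities for one decoder step (r frames)
stop = int(sum(probs >= threshold)) > 0  # True here, so the loop would terminate
print(maxlen, minlen, stop)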
1253
+ @register_model_architecture(model_name="artst_transformer", arch_name="artst_transformer")
1254
+ def base_architecture(args):
1255
+ # Transformer
1256
+ args.bert_init = getattr(args, "bert_init", False)
1257
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
1258
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
1259
+ args.encoder_layers = getattr(args, "encoder_layers", 12)
1260
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
1261
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
1262
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
1263
+ args.decoder_ffn_embed_dim = getattr(
1264
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
1265
+ )
1266
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
1267
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
1268
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
1269
+ args.dropout = getattr(args, "dropout", 0.1)
1270
+ args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
1271
+ args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
1272
+ args.activation_fn = getattr(args, "activation_fn", "gelu")
1273
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
1274
+ args.decoder_output_dim = getattr(
1275
+ args, "decoder_output_dim", args.decoder_embed_dim
1276
+ )
1277
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
1278
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
1279
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
1280
+ args.max_text_positions = getattr(args, "max_text_positions", DEFAULT_MAX_TEXT_POSITIONS)
1281
+ args.max_speech_positions = getattr(args, "max_speech_positions", DEFAULT_MAX_SPEECH_POSITIONS)
1282
+
1283
+ # Espnet related, including prenet, postnet
1284
+ args.eprenet_conv_layers = getattr(args, "eprenet_conv_layers", 0)
1285
+ args.eprenet_conv_filts = getattr(args, "eprenet_conv_filts", 0)
1286
+ args.eprenet_conv_chans = getattr(args, "eprenet_conv_chans", 0)
1287
+ args.use_batch_norm = getattr(args, "use_batch_norm", True)
1288
+ args.eprenet_dropout_rate = getattr(args, "eprenet_dropout_rate", 0.0)
1289
+ args.enc_use_scaled_pos_enc = getattr(args, "enc_use_scaled_pos_enc", True)
1290
+ args.dec_use_scaled_pos_enc = getattr(args, "dec_use_scaled_pos_enc", True)
1291
+ args.postnet_layers = getattr(args, "postnet_layers", 5)
1292
+ args.postnet_chans = getattr(args, "postnet_chans", 256)
1293
+ args.postnet_filts = getattr(args, "postnet_filts", 5)
1294
+ args.postnet_dropout_rate = getattr(args, "postnet_dropout_rate", 0.5)
1295
+ args.dprenet_dropout_rate = getattr(args, "dprenet_dropout_rate", 0.5)
1296
+ args.dprenet_layers = getattr(args, "dprenet_layers", 2)
1297
+ args.dprenet_units = getattr(args, "dprenet_units", 256)
1298
+ args.initial_encoder_alpha = getattr(args, "initial_encoder_alpha", 1.0)
1299
+ args.initial_decoder_alpha = getattr(args, "initial_decoder_alpha", 1.0)
1300
+ args.spk_embed_integration_type = getattr(args, "spk_embed_integration_type", "pre")
1301
+ args.spk_embed_dim = getattr(args, "spk_embed_dim", 512)
1302
+ args.encoder_reduction_factor = getattr(args, "encoder_reduction_factor", 1)
1303
+ args.reduction_factor = getattr(args, "reduction_factor", 2)
1304
+ args.transformer_enc_positional_dropout_rate = getattr(args, "transformer_enc_positional_dropout_rate", 0.1)
1305
+ args.transformer_dec_positional_dropout_rate = getattr(args, "transformer_dec_positional_dropout_rate", 0.1)
1306
+ args.layer_norm_eps = getattr(args, "layer_norm_eps", 1e-5)
1307
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
1308
+ # Convolutional subsampler
1309
+ args.encoder_speech_prenet = getattr(args, "encoder_speech_prenet", "conv")
1310
+ args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
1311
+ args.conv_channels = getattr(args, "conv_channels", 1024)
1312
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
1313
+
+     args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
+     args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
+     args.no_token_positional_embeddings = getattr(
+         args, "no_token_positional_embeddings", False
+     )
+     args.adaptive_input = getattr(args, "adaptive_input", False)
+     args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
+     args.share_input_output_embed = getattr(args, "share_input_output_embed", False)
+     args.share_ctc_embed = getattr(args, "share_ctc_embed", False)
+     args.freeze_encoder_updates = getattr(args, "freeze_encoder_updates", 0)
+     args.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0)
+     args.no_freeze_encoder_layer = getattr(args, "no_freeze_encoder_layer", None)
+
+     ## sid
+     args.sid_embed_dim = getattr(args, "sid_embed_dim", 128)
+     args.sid_pooling_layer = getattr(args, "sid_pooling_layer", "decoder")
+     args.softmax_scale = getattr(args, "softmax_scale", 1)
+     args.softmax_margin = getattr(args, "softmax_margin", 0)
+     args.softmax_easy_margin = getattr(args, "softmax_easy_margin", False)
+     args.modules_filter = getattr(args, "modules_filter", None)
+
+     ## Hubert
+     args.conv_pos = getattr(args, "conv_pos", 128)
+     args.conv_pos_groups = getattr(args, "conv_pos_groups", 16)
+     args.target_glu = getattr(args, "target_glu", False)
+     args.logit_temp = getattr(args, "logit_temp", 0.1)
+     args.final_dim = getattr(args, "final_dim", 256)
+     args.untie_final_proj = getattr(args, "untie_final_proj", True)
+     args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.1)
+     args.use_sent_enc_layer = getattr(args, "use_sent_enc_layer", True)
+     # hubert feature extractor
+     args.extractor_mode = getattr(args, "extractor_mode", "default")
+     args.conv_feature_layers = getattr(args, "conv_feature_layers", "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")
+     args.conv_bias = getattr(args, "conv_bias", False)
+     # mask
+     args.hubert_mask_length = getattr(args, "hubert_mask_length", 10)
+     args.mask_prob = getattr(args, "mask_prob", 0.0)
+     args.mask_selection = getattr(args, "mask_selection", "static")
+     args.mask_other = getattr(args, "mask_other", 0)
+     args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
+     args.mask_min_space = getattr(args, "mask_min_space", 1)
+     # channel mask
+     args.mask_channel_length = getattr(args, "mask_channel_length", 10)
+     args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.0)
+     args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
+     args.mask_channel_other = getattr(args, "mask_channel_other", 0)
+     args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
+     args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1)
+     # loss computation
+     args.skip_masked = getattr(args, "skip_masked", False)
+     args.skip_nomask = getattr(args, "skip_nomask", False)
+     # conv Pos
+     args.use_conv_pos = getattr(args, "use_conv_pos", False)
+     args.use_sinc_pos = getattr(args, "use_sinc_pos", False)
+
+     # codebook
+     args.use_codebook = getattr(args, "use_codebook", False)
+     args.latent_vars = getattr(args, "latent_vars", 100)
+     args.latent_groups = getattr(args, "latent_groups", 2)
+     args.latent_dim = getattr(args, "latent_dim", 0)
+     args.latent_temp = getattr(args, "latent_temp", (2, 0.5, 0.999995))
+     args.quantizer_depth = getattr(args, "quantizer_depth", 1)
+     args.quantizer_factor = getattr(args, "quantizer_factor", 3)
+     args.codebook_prob = getattr(args, "codebook_prob", 0.5)
+
+     # Relative pos embed
+     args.relative_position_embedding = getattr(args, "relative_position_embedding", False)
+     args.num_buckets = getattr(args, "num_buckets", 320)
+     args.max_distance = getattr(args, "max_distance", 1280)
+     args.encoder_max_relative_position = getattr(args, "encoder_max_relative_position", 160)
+     args.decoder_max_relative_position = getattr(args, "decoder_max_relative_position", 160)
+
+ @register_model_architecture("artst_transformer", "artst_transformer_base")
+ def artst_transformer_base(args):
+     args.use_conv_pos = getattr(args, "use_conv_pos", True)
+     args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
+     args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+     args.layer_norm_first = getattr(args, "layer_norm_first", False)
+     args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
+     args.dropout = getattr(args, "dropout", 0.1)
+     args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+     args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+     args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.05)
+     args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.05)
+     args.mask_prob = getattr(args, "mask_prob", 0.80)
+     base_architecture(args)
+
+ @register_model_architecture("artst_transformer", "artst_transformer_large")
+ def artst_transformer_large(args):
+     args.use_conv_pos = getattr(args, "use_conv_pos", True)
+     args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
+     args.layer_norm_first = getattr(args, "layer_norm_first", True)
+     args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
+     args.dropout = getattr(args, "dropout", 0.0)
+     args.activation_dropout = getattr(args, "activation_dropout", 0.0)
+     args.attention_dropout = getattr(args, "attention_dropout", 0.0)
+     args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
+     args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
+     args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
+     args.encoder_layers = getattr(args, "encoder_layers", 24)
+     args.decoder_layers = getattr(args, "decoder_layers", 6)
+     args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
+     args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
+     args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
+     args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0)
+     args.extractor_mode = getattr(args, "extractor_mode", "layer_norm")
+     args.final_dim = getattr(args, "final_dim", 768)
+     args.mask_prob = getattr(args, "mask_prob", 0.80)
+     base_architecture(args)
+
+ @register_model_architecture("artst_transformer", "artst_transformer_base_asr")
+ def artst_transformer_base_asr(args):
+     args.use_conv_pos = getattr(args, "use_conv_pos", True)
+     args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
+     args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
+     args.layer_norm_first = getattr(args, "layer_norm_first", False)
+     args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
+     args.dropout = getattr(args, "dropout", 0.1)
+     args.activation_dropout = getattr(args, "activation_dropout", 0.1)
+     args.attention_dropout = getattr(args, "attention_dropout", 0.1)
+     args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0)
+     args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1)
+     args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.1)
+     args.mask_prob = getattr(args, "mask_prob", 0.75)
+     args.mask_selection = getattr(args, "mask_selection", "static")
+     args.mask_channel_length = getattr(args, "mask_channel_length", 64)
+     args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
+     args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
+     args.max_text_positions = getattr(args, "max_text_positions", 600)
+     base_architecture(args)
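
Note on the three @register_model_architecture entries above: each preset only pre-fills defaults on args via getattr (so any value the user has already set wins) and then calls base_architecture, which fills in the remaining hyperparameters the same way; the registered name is what fairseq selects at training time, typically via fairseq-train's --arch flag (e.g. --arch artst_transformer_base_asr). Below is a minimal sketch of that behaviour, not part of this commit. It assumes fairseq and this package are importable and that the earlier, un-shown part of base_architecture also only fills defaults via getattr; the dropout override of 0.3 is an arbitrary illustration value.

    from argparse import Namespace

    # Importing the module runs the @register_model_architecture decorators
    # and also binds the preset functions as ordinary module-level names.
    from artst.models.artst import artst_transformer_base_asr

    args = Namespace(dropout=0.3)      # explicit user setting, should survive untouched
    artst_transformer_base_asr(args)   # fills every remaining field with its ASR preset

    print(args.dropout)             # 0.3  -- getattr keeps the user-supplied value
    print(args.mask_prob)           # 0.75 -- preset default from the ASR config above
    print(args.max_text_positions)  # 600  -- preset default from the ASR config above
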
artst/models/modules/__init__.py ADDED
File without changes
artst/models/modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (150 Bytes).
artst/models/modules/__pycache__/decoder.cpython-38.pyc ADDED
Binary file (8.71 kB).
artst/models/modules/__pycache__/encoder.cpython-38.pyc ADDED
Binary file (8.84 kB).
artst/models/modules/__pycache__/multihead_attention.cpython-38.pyc ADDED
Binary file (10.6 kB).
artst/models/modules/__pycache__/speaker_decoder_postnet.cpython-38.pyc ADDED
Binary file (6.16 kB).
artst/models/modules/__pycache__/speech_decoder_postnet.cpython-38.pyc ADDED
Binary file (2.05 kB).
artst/models/modules/__pycache__/speech_decoder_prenet.cpython-38.pyc ADDED
Binary file (3.54 kB).
artst/models/modules/__pycache__/speech_encoder_postnet.cpython-38.pyc ADDED
Binary file (4.04 kB).
artst/models/modules/__pycache__/speech_encoder_prenet.cpython-38.pyc ADDED
Binary file (10.2 kB).